diff --git a/kittens/diff/highlight.py b/kittens/diff/highlight.py index 0e13407fb..968b195b3 100644 --- a/kittens/diff/highlight.py +++ b/kittens/diff/highlight.py @@ -4,7 +4,10 @@ import concurrent import os import re -from typing import IO, Dict, Iterable, List, Optional, Tuple, Union, cast +from concurrent.futures import ProcessPoolExecutor +from typing import ( + IO, Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast +) from pygments import highlight # type: ignore from pygments.formatter import Formatter # type: ignore @@ -136,10 +139,22 @@ def highlight_for_diff(path: str, aliases: Dict[str, str]) -> DiffHighlight: return ans +process_pool_executor: Optional[ProcessPoolExecutor] = None + + +def get_highlight_processes() -> Iterator[int]: + if process_pool_executor is None: + return + for pid in process_pool_executor._processes: + yield pid + + def highlight_collection(collection: Collection, aliases: Optional[Dict[str, str]] = None) -> Union[str, Dict[str, DiffHighlight]]: + global process_pool_executor jobs = {} ans: Dict[str, DiffHighlight] = {} with get_process_pool_executor(prefer_fork=True) as executor: + process_pool_executor = executor for path, item_type, other_path in collection: if item_type != 'rename': for p in (path, other_path): @@ -159,8 +174,9 @@ def highlight_collection(collection: Collection, aliases: Optional[Dict[str, str def main() -> None: # kitty +runpy "from kittens.diff.highlight import main; main()" file - from .options.types import defaults import sys + + from .options.types import defaults initialize_highlighter() with open(sys.argv[-1]) as f: highlighted = highlight_data(f.read(), f.name, defaults.syntax_aliases) diff --git a/kittens/diff/main.py b/kittens/diff/main.py index a10ff8c70..32a84f5aa 100644 --- a/kittens/diff/main.py +++ b/kittens/diff/main.py @@ -13,7 +13,7 @@ from contextlib import suppress from functools import partial from gettext import gettext as _ from typing import ( - Any, DefaultDict, Dict, Iterable, List, Optional, Tuple, Union + Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple, Union ) from kitty.cli import CONFIG_HELP, parse_args @@ -44,7 +44,8 @@ from .search import BadRegex, Search try: from .highlight import ( - DiffHighlight, highlight_collection, initialize_highlighter + DiffHighlight, get_highlight_processes, highlight_collection, + initialize_highlighter ) has_highlighter = True DiffHighlight @@ -54,6 +55,10 @@ except ImportError: def highlight_collection(collection: 'Collection', aliases: Optional[Dict[str, str]] = None) -> Union[str, Dict[str, 'DiffHighlight']]: return '' + def get_highlight_processes() -> Iterator[int]: + if has_highlighter: + yield -1 + INITIALIZING, COLLECTED, DIFFED, COMMAND, MESSAGE = range(5) @@ -90,14 +95,18 @@ class DiffHandler(Handler): if self.current_context_count < 0: self.current_context_count = self.original_context_count = self.opts.num_context_lines self.highlighting_done = False + self.doing_background_work = '' self.restore_position: Optional[Reference] = None for key_def, action in self.opts.key_definitions.items(): self.add_shortcut(action, key_def) + def terminate(self, return_code: int = 0) -> None: + self.quit_loop(return_code) + def perform_action(self, action: KeyAction) -> None: func, args = action if func == 'quit': - self.quit_loop(0) + self.terminate() return if self.state <= DIFFED: if func == 'scroll_by': @@ -130,6 +139,7 @@ class DiffHandler(Handler): def create_collection(self) -> None: def collect_done(collection: Collection) -> None: + self.doing_background_work = '' self.collection = collection self.state = COLLECTED self.generate_diff() @@ -139,13 +149,15 @@ class DiffHandler(Handler): self.asyncio_loop.call_soon_threadsafe(collect_done, collection) self.asyncio_loop.run_in_executor(None, collect, self.left, self.right) + self.doing_background_work = 'collecting' def generate_diff(self) -> None: def diff_done(diff_map: Union[str, Dict[str, Patch]]) -> None: + self.doing_background_work = '' if isinstance(diff_map, str): self.report_traceback_on_exit = diff_map - self.quit_loop(1) + self.terminate(1) return self.state = DIFFED self.diff_map = diff_map @@ -163,7 +175,7 @@ class DiffHandler(Handler): initialize_highlighter(self.opts.pygments_style) except StyleNotFound as e: self.report_traceback_on_exit = str(e) - self.quit_loop(1) + self.terminate(1) return self.syntax_highlight() @@ -172,13 +184,15 @@ class DiffHandler(Handler): self.asyncio_loop.call_soon_threadsafe(diff_done, diff_map) self.asyncio_loop.run_in_executor(None, diff, self.collection, self.current_context_count) + self.doing_background_work = 'diffing' def syntax_highlight(self) -> None: def highlighting_done(hdata: Union[str, Dict[str, 'DiffHighlight']]) -> None: + self.doing_background_work = '' if isinstance(hdata, str): self.report_traceback_on_exit = hdata - self.quit_loop(1) + self.terminate(1) return set_highlight_data(hdata) self.render_diff() @@ -189,6 +203,7 @@ class DiffHandler(Handler): self.asyncio_loop.call_soon_threadsafe(highlighting_done, result) self.asyncio_loop.run_in_executor(None, highlight, self.collection, self.opts.syntax_aliases) + self.doing_background_work = 'highlighting' def calculate_statistics(self) -> None: self.added_count = self.collection.added_count @@ -532,10 +547,10 @@ class DiffHandler(Handler): self.draw_screen() def on_interrupt(self) -> None: - self.quit_loop(1) + self.terminate(1) def on_eot(self) -> None: - self.quit_loop(1) + self.terminate(1) OPTIONS = partial('''\ @@ -639,9 +654,10 @@ def main(args: List[str]) -> None: for message in showwarning.warnings: from kitty.utils import safe_print safe_print(message, file=sys.stderr) - highlight_processes = getattr(highlight_collection, 'processes', ()) - terminate_processes(tuple(highlight_processes)) - terminate_processes(tuple(worker_processes)) + if handler.doing_background_work == 'highlighting': + terminate_processes(tuple(get_highlight_processes())) + elif handler.doing_background_work == 'diffing': + terminate_processes(tuple(worker_processes)) if loop.return_code != 0: if handler.report_traceback_on_exit: print(handler.report_traceback_on_exit, file=sys.stderr)