diff --git a/kitty/child.py b/kitty/child.py index 0596fca16..3ea941e33 100644 --- a/kitty/child.py +++ b/kitty/child.py @@ -194,11 +194,24 @@ class ProcessDesc(TypedDict): cmdline: Optional[Sequence[str]] +def is_prewarmable(argv: Sequence[str]) -> bool: + if len(argv) < 3 or os.path.basename(argv[0]) != 'kitty': + return False + if argv[1][:1] not in '@+': + return False + if argv[1][0] == '@': + return True + if argv[1] == '+': + return argv[2] != 'open' + return argv[1] != '+open' + + class Child: child_fd: Optional[int] = None pid: Optional[int] = None forked = False + is_prewarmed = False def __init__( self, @@ -262,14 +275,16 @@ class Child: self.forked = True master, slave = openpty() stdin, self.stdin = self.stdin, None - ready_read_fd, ready_write_fd = os.pipe() - remove_cloexec(ready_read_fd) - if stdin is not None: - stdin_read_fd, stdin_write_fd = os.pipe() - remove_cloexec(stdin_read_fd) - else: - stdin_read_fd = stdin_write_fd = -1 - env = tuple(f'{k}={v}' for k, v in self.final_env().items()) + self.is_prewarmed = is_prewarmable(self.argv) + if not self.is_prewarmed: + ready_read_fd, ready_write_fd = os.pipe() + remove_cloexec(ready_read_fd) + if stdin is not None: + stdin_read_fd, stdin_write_fd = os.pipe() + remove_cloexec(stdin_read_fd) + else: + stdin_read_fd = stdin_write_fd = -1 + env = tuple(f'{k}={v}' for k, v in self.final_env().items()) argv = list(self.argv) exe = argv[0] if is_macos and exe == shell_path: @@ -290,24 +305,33 @@ class Child: argv[0] = (f'-{exe.split("/")[-1]}') self.final_exe = which(exe) or exe self.final_argv0 = argv[0] - pid = fast_data_types.spawn( - self.final_exe, self.cwd, tuple(argv), env, master, slave, - stdin_read_fd, stdin_write_fd, ready_read_fd, ready_write_fd, tuple(handled_signals)) + if self.is_prewarmed: + fe = self.final_env() + self.prewarmed_child = fast_data_types.get_boss().prewarm(slave, self.argv, self.cwd, fe, stdin) + pid = self.prewarmed_child.child_process_pid + else: + pid = fast_data_types.spawn( + self.final_exe, self.cwd, tuple(argv), env, master, slave, stdin_read_fd, stdin_write_fd, + ready_read_fd, ready_write_fd, tuple(handled_signals)) os.close(slave) self.pid = pid self.child_fd = master - if stdin is not None: - os.close(stdin_read_fd) - fast_data_types.thread_write(stdin_write_fd, stdin) - os.close(ready_read_fd) - self.terminal_ready_fd = ready_write_fd + if not self.is_prewarmed: + if stdin is not None: + os.close(stdin_read_fd) + fast_data_types.thread_write(stdin_write_fd, stdin) + os.close(ready_read_fd) + self.terminal_ready_fd = ready_write_fd if self.child_fd is not None: remove_blocking(self.child_fd) return pid def mark_terminal_ready(self) -> None: - os.close(self.terminal_ready_fd) - self.terminal_ready_fd = -1 + if self.is_prewarmed: + fast_data_types.get_boss().prewarm.mark_child_as_ready(self.prewarmed_child.child_id) + else: + os.close(self.terminal_ready_fd) + self.terminal_ready_fd = -1 def cmdline_of_pid(self, pid: int) -> List[str]: try: diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 079973e79..bb2830bf9 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -15,7 +15,6 @@ from typing import ( IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast ) -from kitty.child import remove_cloexec from kitty.constants import kitty_exe from kitty.entry_points import main as main_entry_point from kitty.fast_data_types import ( @@ -27,9 +26,7 @@ if TYPE_CHECKING: from _typeshed import ReadableBuffer, WriteableBuffer -hangup_events = select.POLLHUP -error_events = select.POLLERR | select.POLLNVAL -basic_events = hangup_events | error_events +error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP class PrewarmProcessFailed(Exception): @@ -98,11 +95,11 @@ class PrewarmProcess: os.set_blocking(self.write_to_process_fd, False) os.set_blocking(self.read_from_process_fd, False) self.poll = select.poll() - self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events) + self.poll.register(self.process.stdout.fileno(), select.POLLIN) def poll_to_send(self, yes: bool = True) -> None: if yes: - self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events) + self.poll.register(self.write_to_process_fd, select.POLLOUT) else: self.poll.unregister(self.write_to_process_fd) @@ -143,7 +140,7 @@ class PrewarmProcess: st = time.monotonic() while time.monotonic() - st < 2: for (fd, event) in self.poll.poll(0.2): - if event & basic_events: + if event & error_events: raise PrewarmProcessFailed('Failed doing I/O with prewarm process') if fd == self.read_from_process_fd and event & select.POLLIN: d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE) @@ -171,8 +168,8 @@ class PrewarmProcess: while time.monotonic() - st < timeout and output_buf: self.poll_to_send(bool(output_buf)) for (fd, event) in self.poll.poll(0.2): - if event & basic_events: - raise PrewarmProcessFailed('Failed doing I/O with prewarm process') + if event & error_events: + raise PrewarmProcessFailed('Failed doing I/O with prewarm process: {event}') if fd == self.write_to_process_fd and event & select.POLLOUT: n = os.write(self.write_to_process_fd, output_buf) output_buf = output_buf[n:] @@ -198,7 +195,8 @@ def prewarm() -> None: reload_kitty_config() from kittens.runner import all_kitten_names for kitten in all_kitten_names(): - import_module(f'kittens.{kitten}.main') + with suppress(Exception): + import_module(f'kittens.{kitten}.main') import_module('kitty.complete') @@ -249,12 +247,8 @@ class MemoryViewReadWrapper(io.TextIOWrapper): def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: cwd = cmd.get('cwd') if cwd: - try: + with suppress(OSError): os.chdir(cwd) - except OSError: - with suppress(OSError): - os.chdir('/') - os.setsid() env = cmd.get('env') if env is not None: os.environ.clear() @@ -264,13 +258,8 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: sys.argv = list(argv) poll = select.poll() poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP) - poll.poll() + tuple(poll.poll()) os.close(ready_fd) - tty_name = cmd.get('tty_name') - if tty_name: - sys.__stdout__.flush() - sys.__stderr__.flush() - establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) main_entry_point() raise SystemExit(0) @@ -285,6 +274,8 @@ def fork(shm_address: str, ready_fd: int) -> int: pos = shm.tell() shm.unlink_on_exit = False + r, w = os.pipe() + os.set_inheritable(r, False) try: child_pid = os.fork() except OSError: @@ -293,8 +284,23 @@ def fork(shm_address: str, ready_fd: int) -> int: pass if child_pid: # master process + os.close(w) + try: + poll = select.poll() + poll.register(r, select.POLLIN | select.POLLHUP | select.POLLERR) + tuple(poll.poll()) + finally: + os.close(r) return child_pid # child process + os.setsid() + tty_name = cmd.get('tty_name') + if tty_name: + sys.__stdout__.flush() + sys.__stderr__.flush() + establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) + os.write(w, b'1') + os.close(w) if shm.unlink_on_exit: child_main(cmd, ready_fd) else: @@ -321,15 +327,16 @@ def main(notify_child_death_fd: int) -> None: stdout_fd = sys.__stdout__.fileno() os.set_blocking(stdout_fd, False) poll = select.poll() - poll.register(stdin_fd, select.POLLIN | basic_events) - poll.register(read_signal_fd, select.POLLIN | basic_events) + poll.register(stdin_fd, select.POLLIN) + poll.register(read_signal_fd, select.POLLIN) input_buf = output_buf = child_death_buf = b'' child_ready_fds: Dict[int, int] = {} child_id_map: Dict[int, int] = {} self_pid = os.getpid() + os.setsid() def check_event(event: int, err_msg: str) -> None: - if event & hangup_events: + if event & select.POLLHUP: raise SystemExit(0) if event & error_events: raise SystemExit(err_msg) @@ -357,10 +364,10 @@ def main(notify_child_death_fd: int) -> None: os.write(cfd, b'1') os.close(cfd) elif cmd == 'fork': - read_fd, write_fd = safe_pipe(False) - remove_cloexec(read_fd) + r, w = os.pipe() + os.set_inheritable(w, False) try: - child_pid = fork(payload, read_fd) + child_pid = fork(payload, r) except Exception as e: es = str(e).replace('\n', ' ') output_buf += f'ERR:{es}\n'.encode() @@ -368,11 +375,11 @@ def main(notify_child_death_fd: int) -> None: if os.getpid() == self_pid: child_id = len(child_id_map) + 1 child_id_map[child_id] = child_pid - child_ready_fds[child_id] = write_fd + child_ready_fds[child_id] = w output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode() finally: if os.getpid() == self_pid: - os.close(read_fd) + os.close(r) elif cmd == 'echo': output_buf += f'{payload}\n'.encode() @@ -428,14 +435,16 @@ def main(notify_child_death_fd: int) -> None: break if matched_child_id > -1: del child_id_map[matched_child_id] - child_ready_fds.pop(matched_child_id, None) + ofd = child_ready_fds.pop(matched_child_id, None) + if ofd is not None: + os.close(ofd) child_death_buf += f'{pid}\n'.encode() try: while True: if output_buf: - poll.register(stdout_fd, select.POLLOUT | basic_events) + poll.register(stdout_fd, select.POLLOUT) if child_death_buf: - poll.register(notify_child_death_fd, select.POLLOUT | basic_events) + poll.register(notify_child_death_fd, select.POLLOUT) for (q, event) in poll.poll(): if q == stdin_fd: handle_input(event)