From cf667b8c4734f937f2f1bd90a4d1627b3aa5e5e5 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 13 Jun 2022 20:40:04 +0530 Subject: [PATCH] Use a signal handler instead of a pipe for child death notification --- kitty/prewarm.py | 78 ++++++++++++++++++++++++++++-------------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 7160b8eb5..c75d0f786 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -21,10 +21,13 @@ from typing import ( from kitty.constants import kitty_exe, running_in_kitty from kitty.entry_points import main as main_entry_point from kitty.fast_data_types import ( - establish_controlling_tty, get_options, safe_pipe, set_options + CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options, + install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe, + set_options ) from kitty.options.types import Options from kitty.shm import SharedMemory +from kitty.types import SignalInfo from kitty.utils import log_error if TYPE_CHECKING: @@ -32,7 +35,7 @@ if TYPE_CHECKING: error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP -TIMEOUT = 4.0 +TIMEOUT = 5.0 class PrewarmProcessFailed(Exception): @@ -288,7 +291,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: raise SystemExit(0) -def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int, int]: +def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: sz = pos = 0 with SharedMemory(name=shm_address, unlink_on_exit=True) as shm: data = shm.read_data_with_size() @@ -312,13 +315,10 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int, os.close(ready_fd_read) poll = select.poll() poll.register(r, select.POLLIN) - for (fd, event) in poll.poll(): - if event & select.POLLIN: - os.read(r, 1) - return child_pid, r, ready_fd_write - else: - raise ValueError('Child process pipe failed') + tuple(poll.poll()) + return child_pid, ready_fd_write # child process + remove_signal_handlers() os.close(r) os.close(ready_fd_write) for fd in all_non_child_fds: @@ -330,7 +330,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int, sys.__stdout__.flush() sys.__stderr__.flush() establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) - os.write(w, b'1') # this will be closed on process exit and thereby used to detect child death + os.close(w) if shm.unlink_on_exit: child_main(cmd, ready_fd_read) else: @@ -349,12 +349,13 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: os.set_blocking(notify_child_death_fd, False) os.set_blocking(stdin_fd, False) os.set_blocking(stdout_fd, False) + signal_read_fd = install_signal_handlers(signal.SIGCHLD)[0] poll = select.poll() poll.register(stdin_fd, select.POLLIN) + poll.register(signal_read_fd, select.POLLIN) input_buf = output_buf = child_death_buf = b'' child_ready_fds: Dict[int, int] = {} - child_death_fds: Dict[int, int] = {} - child_id_map: Dict[int, int] = {} + child_pid_map: Dict[int, int] = {} child_id_counter = count() self_pid = os.getpid() # runpy issues a warning when running modules that have already been @@ -366,8 +367,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: yield notify_child_death_fd yield stdin_fd yield stdout_fd + # the signal fds are closed by remove_signal_handlers() yield from child_ready_fds.values() - yield from child_death_fds.keys() def check_event(event: int, err_msg: str) -> None: if event & select.POLLHUP: @@ -377,7 +378,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: def handle_input(event: int) -> None: nonlocal input_buf, output_buf - check_event(event, 'Polling of STDIN failed') + check_event(event, 'Polling of input pipe failed') if not (event & select.POLLIN): return d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE) @@ -392,31 +393,29 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: reload_kitty_config(payload) elif cmd == 'ready': child_id = int(payload) - cfd = child_ready_fds.pop(child_id) + cfd = child_ready_fds.pop(child_id, None) if cfd is not None: os.close(cfd) elif cmd == 'quit': raise SystemExit(0) elif cmd == 'fork': try: - child_pid, child_death_fd, ready_fd_write = fork(payload, get_all_non_child_fds()) + child_pid, ready_fd_write = fork(payload, get_all_non_child_fds()) except Exception as e: es = str(e).replace('\n', ' ') output_buf += f'ERR:{es}\n'.encode() else: if os.getpid() == self_pid: child_id = next(child_id_counter) - child_id_map[child_id] = child_pid + child_pid_map[child_pid] = child_id child_ready_fds[child_id] = ready_fd_write - child_death_fds[child_death_fd] = child_id - poll.register(child_death_fd, select.POLLIN) output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode() elif cmd == 'echo': output_buf += f'{payload}\n'.encode() def handle_output(event: int) -> None: nonlocal output_buf - check_event(event, 'Polling of STDOUT failed') + check_event(event, 'Polling of output pipe failed') if not (event & select.POLLOUT): return if output_buf: @@ -429,7 +428,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: def handle_notify_child_death(event: int) -> None: nonlocal child_death_buf - check_event(event, 'Polling of notify child death fd failed') + check_event(event, 'Polling of notify child death pipe failed') if not (event & select.POLLOUT): return if child_death_buf: @@ -440,17 +439,33 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: if not child_death_buf: poll.unregister(notify_child_death_fd) - def handle_child_death(dead_child_fd: int, dead_child_id: int) -> None: + def handle_child_death(dead_child_id: int, dead_child_pid: int) -> None: nonlocal child_death_buf - poll.unregister(dead_child_fd) - del child_death_fds[dead_child_fd] xfd = child_ready_fds.pop(dead_child_id, None) if xfd is not None: os.close(xfd) - dead_child_pid = child_id_map.pop(dead_child_id, None) - if dead_child_pid is not None: - wait_for_child_death(dead_child_pid) - child_death_buf += f'{dead_child_pid}\n'.encode() + child_death_buf += f'{dead_child_pid}\n'.encode() + + def handle_signals(event: int) -> None: + check_event(event, 'Polling of signal pipe failed') + if not event & select.POLLIN: + return + + def handle_signal(siginfo: SignalInfo) -> None: + if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED): + return + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + pid = 0 + if not pid: + break + child_id = child_pid_map.pop(pid, None) + if child_id is not None: + handle_child_death(child_id, pid) + + read_signals(signal_read_fd, handle_signal) try: while True: @@ -463,12 +478,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: handle_input(event) elif q == stdout_fd: handle_output(event) + elif q == signal_read_fd: + handle_signals(event) elif q == notify_child_death_fd: handle_notify_child_death(event) - else: - dead_child_id = child_death_fds.get(q) - if dead_child_id is not None and event & select.POLLHUP: - handle_child_death(q, dead_child_id) except (KeyboardInterrupt, EOFError, BrokenPipeError): if os.getpid() == self_pid: raise SystemExit(1) @@ -480,6 +493,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: raise finally: if os.getpid() == self_pid: + remove_signal_handlers() for fmd in child_ready_fds.values(): with suppress(OSError): os.close(fmd)