From 9c30cd88918bb05d6a5aa17485f2af7b932f74bf Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Wed, 13 Jul 2022 09:36:12 +0530 Subject: [PATCH] Use a process supervisor for socket workers This simplifies the code and also allows SIGTSTP to work as the worker process is no longer in an orphaned process group. --- kitty/child.c | 3 +- kitty/fast_data_types.pyi | 2 +- kitty/prewarm.py | 401 +++++++++++++++++--------------------- 3 files changed, 182 insertions(+), 224 deletions(-) diff --git a/kitty/child.c b/kitty/child.c index a9e660b69..c3610f75c 100644 --- a/kitty/child.c +++ b/kitty/child.c @@ -194,10 +194,9 @@ establish_controlling_tty(PyObject *self UNUSED, PyObject *args) { if (stdin_fd > -1 && safe_dup2(tfd, stdin_fd) == -1) fail(); if (stdout_fd > -1 && safe_dup2(tfd, stdout_fd) == -1) fail(); if (stderr_fd > -1 && safe_dup2(tfd, stderr_fd) == -1) fail(); - cleanup(); #undef cleanup #undef fail - Py_RETURN_NONE; + return PyLong_FromLong(tfd); } static PyMethodDef module_methods[] = { diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index b0570b433..ab17d6057 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None: pass -def establish_controlling_tty(tty_name: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None: +def establish_controlling_tty(tty_name: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> int: pass diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 541823ca3..066224789 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -47,6 +47,9 @@ def restore_python_signal_handlers() -> None: signal.signal(signal.SIGPIPE, signal.SIG_IGN) signal.signal(signal.SIGUSR1, signal.SIG_DFL) signal.signal(signal.SIGCHLD, signal.SIG_DFL) + signal.signal(signal.SIGTSTP, signal.SIG_DFL) + signal.signal(signal.SIGTTIN, signal.SIG_DFL) + signal.signal(signal.SIGTTOU, signal.SIG_DFL) def print_error(*a: Any) -> None: @@ -354,7 +357,7 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl if tty_name: sys.__stdout__.flush() sys.__stderr__.flush() - establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) + open(establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()), 'w').close() os.close(w) if shm.unlink_on_exit: child_main(cmd, ready_fd_read) @@ -377,201 +380,217 @@ def verify_socket_creds(conn: socket.socket) -> bool: return uid == os.geteuid() and gid == os.getegid() -class SocketChild: - - def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll): - self.registered = True - self.poll = poll - self.addr = addr - self.conn = conn - self.winsize = 8 - self.poll.register(self.conn.fileno(), select.POLLIN) - self.input_buf = self.output_buf = b'' - self.fds: List[int] = [] - self.child_id = -1 +class SocketChildData: + def __init__(self) -> None: self.cwd = self.tty_name = '' - self.env: Dict[str, str] = {} self.argv: List[str] = [] - self.stdin = self.stdout = self.stderr = -1 - self.pid = -1 - self.closed = False - self.launch_msg_read = False + self.env: Dict[str, str] = {} - def unregister_from_poll(self) -> None: - if self.registered: - fd = self.conn.fileno() - if fd > -1: - self.poll.unregister(self.conn.fileno()) - self.registered = False - def read(self) -> None: - import fcntl - import termios - msg = self.conn.recv(io.DEFAULT_BUFFER_SIZE) +def fork_socket_child(child_data: SocketChildData, tty_fd: int, stdio_fds: Dict[str, int], free_non_child_resources: Callable[[], None]) -> int: + # see https://www.gnu.org/software/libc/manual/html_node/Launching-Jobs.html + child_pid = safer_fork() + if child_pid: + return child_pid + # child process + os.setpgid(0, 0) + os.tcsetpgrp(tty_fd, os.getpgid(0)) + restore_python_signal_handlers() + # the std streams fds are closed in free_non_child_resources() + for which in ('stdin', 'stdout', 'stderr'): + fd = stdio_fds[which] if stdio_fds[which] > -1 else tty_fd + os.dup2(fd, getattr(sys, which).fileno()) + free_non_child_resources() + child_main({'cwd': child_data.cwd, 'env': child_data.env, 'argv': child_data.argv}) + + +def fork_socket_child_supervisor(conn: socket.socket, free_non_child_resources: Callable[[], None]) -> None: + import array + import fcntl + import termios + global is_zygote + if safer_fork(): + conn.close() + return + is_zygote = False + os.setsid() + restore_python_signal_handlers() + free_non_child_resources() + # See https://www.gnu.org/software/libc/manual/html_node/Initializing-the-Shell.html + signal_read_fd = install_signal_handlers( + signal.SIGCHLD, signal.SIGUSR1, signal.SIGINT, signal.SIGTSTP, signal.SIGTTIN, signal.SIGTTOU, signal.SIGQUIT + )[0] + poll = select.poll() + poll.register(signal_read_fd, select.POLLIN) + from_socket_buf = b'' + to_socket_buf = b'' + keep_going = True + child_pid = -1 + socket_fd = conn.fileno() + launch_msg_read = False + os.set_blocking(socket_fd, False) + received_fds: List[int] = [] + stdio_positions = dict.fromkeys(('stdin', 'stdout', 'stderr'), -1) + stdio_fds = dict.fromkeys(('stdin', 'stdout', 'stderr'), -1) + winsize = 8 + exit_after_write = False + child_data = SocketChildData() + + def handle_signal(siginfo: SignalInfo) -> None: + nonlocal to_socket_buf, exit_after_write, child_pid + if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED, CLD_STOPPED): + return + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG | os.WUNTRACED) + except ChildProcessError: + pid = 0 + if not pid: + break + if pid != child_pid: + continue + to_socket_buf += struct.pack('q', status) + if not os.WIFSTOPPED(status): + exit_after_write = True + child_pid = -1 + + def write_to_socket() -> None: + nonlocal keep_going, to_socket_buf, keep_going + buf = memoryview(to_socket_buf) + while buf: + try: + n = os.write(socket_fd, buf) + except OSError: + n = 0 + if n == 0: + keep_going = False + return + buf = buf[n:] + to_socket_buf = bytes(buf) + if exit_after_write and not to_socket_buf: + keep_going = False + + def read_winsize() -> None: + nonlocal from_socket_buf + msg = conn.recv(io.DEFAULT_BUFFER_SIZE) if not msg: return - self.input_buf += msg - data = memoryview(self.input_buf) - while len(data) >= self.winsize: - record, data = data[:self.winsize], data[self.winsize:] - with open(os.open(self.tty_name, os.O_RDWR | os.O_CLOEXEC | os.O_NOCTTY, 0), 'rb') as f: + from_socket_buf += msg + data = memoryview(from_socket_buf) + while len(data) >= winsize: + record, data = data[:winsize], data[winsize:] + with open(os.open(os.ctermid(), os.O_RDWR | os.O_CLOEXEC | os.O_NOCTTY, 0), 'rb') as f: fcntl.ioctl(f.fileno(), termios.TIOCSWINSZ, record) - self.input_buf = bytes(data) + from_socket_buf = bytes(data) - def read_launch_msg(self) -> bool: - import array - fds = array.array("i") # Array of ints + def read_launch_msg() -> bool: + nonlocal keep_going, from_socket_buf, launch_msg_read, winsize try: - msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024) + msg, ancdata, flags, addr = conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024) except OSError as e: if e.errno == errno.ENOMEM: # macOS does this when no ancilliary data is present - msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE) + msg, ancdata, flags, addr = conn.recvmsg(io.DEFAULT_BUFFER_SIZE) else: raise for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds = array.array("i") # Array of ints # Append data, ignoring any truncated integers at the end. fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) - self.fds += list(fds) + received_fds.extend(fds) + if not msg: return False - self.input_buf += msg - while (idx := self.input_buf.find(b'\0')) > -1: - line = self.input_buf[:idx].decode('utf-8') - self.input_buf = self.input_buf[idx+1:] + from_socket_buf += msg + while (idx := from_socket_buf.find(b'\0')) > -1: + line = from_socket_buf[:idx].decode('utf-8') + from_socket_buf = from_socket_buf[idx+1:] cmd, _, payload = line.partition(':') if cmd == 'finish': - self.launch_msg_read = True - for x in self.fds: - os.set_inheritable(x, x is not self.fds[0]) + for x in received_fds: + os.set_inheritable(x, True) os.set_blocking(x, True) - if self.stdin > -1: - self.stdin = self.fds[self.stdin] - if self.stdout > -1: - self.stdout = self.fds[self.stdout] - if self.stderr > -1: - self.stderr = self.fds[self.stderr] - del self.fds[:] + for k, pos in stdio_positions.items(): + if pos > -1: + stdio_fds[k] = received_fds[pos] + del received_fds[:] return True elif cmd == 'cwd': - self.cwd = payload + child_data.cwd = payload elif cmd == 'env': k, _, v = payload.partition('=') - self.env[k] = v + child_data.env[k] = v elif cmd == 'argv': - self.argv.append(payload) - elif cmd == 'stdin': - self.stdin = int(payload) - elif cmd == 'stdout': - self.stdout = int(payload) - elif cmd == 'stderr': - self.stderr = int(payload) + child_data.argv.append(payload) + elif cmd in stdio_positions: + stdio_positions[cmd] = int(payload) elif cmd == 'tty_name': - self.tty_name = payload + child_data.tty_name = payload elif cmd == 'winsize': - self.winsize = int(payload) - + winsize = int(payload) return False - def fork(self, free_non_child_resources: Callable[[], None]) -> None: - global is_zygote - r, w = safe_pipe() - self.pid = safer_fork() - if self.pid > 0: - # master process - os.close(w) - if self.stdin > -1: - os.close(self.stdin) - self.stdin = -1 - if self.stdout > -1: - os.close(self.stdout) - self.stdout = -1 - if self.stderr > -1: - os.close(self.stderr) - self.stderr = -1 - poll = select.poll() - poll.register(r, select.POLLIN) - tuple(poll.poll()) - os.close(r) - self.handle_creation() - return - # child process - is_zygote = False - os.close(r) - os.setsid() - restore_python_signal_handlers() - if self.tty_name: - sys.__stdout__.flush() - sys.__stderr__.flush() - establish_controlling_tty( - self.tty_name, - sys.__stdin__.fileno() if self.stdin < 0 else -1, - sys.__stdout__.fileno() if self.stdout < 0 else -1, - sys.__stderr__.fileno() if self.stderr < 0 else -1) - # the std streams fds are closed in free_non_child_resources(), see - # SocketChild.close() - if self.stdin > -1: - os.dup2(self.stdin, sys.__stdin__.fileno()) - if self.stdout > -1: - os.dup2(self.stdout, sys.__stdout__.fileno()) - if self.stderr > -1: - os.dup2(self.stderr, sys.__stderr__.fileno()) - os.close(w) - free_non_child_resources() - child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv}) - raise SystemExit(0) + def free_non_child_resources2() -> None: + for fd in received_fds: + os.close(fd) + for k, v in tuple(stdio_fds.items()): + if v > -1: + os.close(v) + stdio_fds[k] = -1 + conn.close() - def handle_stop(self, status: int) -> None: - if self.closed: - return - try: - self.conn.sendall(struct.pack('q', status)) - except OSError as e: - print_error(f'Failed to send exit status of socket child with error: {e}') + def launch_child() -> None: + nonlocal to_socket_buf, child_pid + sys.__stdout__.flush() + sys.__stderr__.flush() + tty_fd = establish_controlling_tty(child_data.tty_name) + child_pid = fork_socket_child(child_data, tty_fd, stdio_fds, free_non_child_resources2) + if child_pid: + # this is also done in the child process, but we dont + # know when, so do it here as well + os.setpgid(child_pid, child_pid) + os.tcsetpgrp(tty_fd, child_pid) + for fd in stdio_fds.values(): + if fd > -1: + os.close(fd) + os.close(tty_fd) + else: + raise SystemExit('fork_socket_child() returned in the child process') + to_socket_buf += struct.pack('q', child_pid) - def handle_death(self, status: int) -> None: - if self.closed: - return - try: - self.conn.sendall(struct.pack('q', status)) - except OSError as e: - print_error(f'Failed to send exit status of socket child with error: {e}') + def read_from_socket() -> None: + nonlocal launch_msg_read + if launch_msg_read: + read_winsize() + else: + if read_launch_msg(): + launch_msg_read = True + launch_child() - def handle_creation(self) -> bool: - if self.closed: - return False - try: - self.conn.sendall(struct.pack('q', self.pid)) - except OSError as e: - print_error(f'Failed to send pid of socket child with error: {e}') - return False - return True - - def close(self) -> None: - if self.closed: - return - self.unregister_from_poll() - self.closed = True - if is_zygote: + try: + while keep_going: + poll.register(socket_fd, select.POLLIN | (select.POLLOUT if to_socket_buf else 0)) + for fd, event in poll.poll(): + if event & error_events: + keep_going = False + break + if fd == socket_fd: + if event & select.POLLOUT: + write_to_socket() + if event & select.POLLIN: + read_from_socket() + elif fd == signal_read_fd and event & select.POLLIN: + read_signals(signal_read_fd, handle_signal) + finally: + if child_pid: # supervisor process with suppress(OSError): - self.conn.shutdown(socket.SHUT_RDWR) - with suppress(OSError): - self.conn.close() - for x in self.fds: - os.close(x) - del self.fds[:] - if self.stdin > -1: - os.close(self.stdin) - self.stdin = -1 - if self.stdout > -1: - os.close(self.stdout) - self.stdout = -1 - if self.stderr > -1: - os.close(self.stderr) - self.stderr = -1 - __del__ = close + conn.shutdown(socket.SHUT_RDWR) + with suppress(OSError): + conn.close() + + raise SystemExit(0) def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None: @@ -591,19 +610,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: input_buf = output_buf = child_death_buf = b'' child_ready_fds: Dict[int, int] = {} child_pid_map: Dict[int, int] = {} - socket_pid_map: Dict[int, SocketChild] = {} - socket_children: Dict[int, SocketChild] = {} child_id_counter = count() # runpy issues a warning when running modules that have already been # imported. Ignore it. warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy') prewarm() - def remove_socket_child(sc: SocketChild) -> None: - socket_children.pop(sc.conn.fileno(), None) - socket_pid_map.pop(sc.pid, None) - sc.close() - def get_all_non_child_fds() -> Iterator[int]: yield notify_child_death_fd yield stdin_fd @@ -616,8 +628,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: if fd > -1: os.close(fd) unix_socket.close() - for sc in tuple(socket_children.values()): - remove_socket_child(sc) def check_event(event: int, err_msg: str) -> None: if event & select.POLLHUP: @@ -711,27 +721,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: pid = 0 if not pid: break - if os.WIFSTOPPED(status): - sc = socket_pid_map.get(pid) - if sc is not None: - try: - sc.handle_stop(status) - except Exception: - import traceback - traceback.print_exc() - return child_id = child_pid_map.pop(pid, None) - if child_id is None: - sc = socket_pid_map.get(pid) - if sc is not None: - try: - sc.handle_death(status) - except Exception: - import traceback - traceback.print_exc() - finally: - remove_socket_child(sc) - else: + if child_id is not None: handle_child_death(child_id, pid) read_signals(signal_read_fd, handle_signal) @@ -743,35 +734,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: print_error('Connection attempted with invalid credentials ignoring') conn.close() return - sc = SocketChild(conn, addr, poll) - socket_children[sc.conn.fileno()] = sc - - def handle_socket_input(fd: int, event: int) -> None: - scq = socket_children.get(q) - if scq is None: - return - if event & select.POLLIN: - if scq.launch_msg_read: - scq.read() - else: - try: - if scq.read_launch_msg(): - scq.fork(free_non_child_resources) - socket_pid_map[scq.pid] = scq - scq.child_id = next(child_id_counter) - except OSError: - if is_zygote: - remove_socket_child(scq) - import traceback - tb = traceback.format_exc() - print_error(f'Failed to fork socket child with error: {tb}') - else: - raise - if is_zygote and (event & error_events): - if event & select.POLLHUP: - scq.unregister_from_poll() - else: - remove_socket_child(scq) + fork_socket_child_supervisor(conn, free_non_child_resources) keep_type_checker_happy = True try: @@ -791,8 +754,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: handle_notify_child_death(event) elif q == unix_socket.fileno(): handle_socket_client(event) - else: - handle_socket_input(q, event) except (KeyboardInterrupt, EOFError, BrokenPipeError): if is_zygote: raise SystemExit(1) @@ -808,8 +769,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: for fmd in child_ready_fds.values(): with suppress(OSError): os.close(fmd) - for sc in tuple(socket_children.values()): - remove_socket_child(sc) def get_socket_name(unix_socket: socket.socket) -> str: