From dbb084da7ae091fff939cb8e724576642d775cef Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 3 Jul 2022 21:42:34 +0530 Subject: [PATCH] Cleanup closing of socket child --- kitty/prewarm.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/kitty/prewarm.py b/kitty/prewarm.py index bc1991c6f..36bd99bc6 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -362,8 +362,11 @@ class SocketClosed(Exception): class SocketChild: - def __init__(self, conn: socket.socket, addr: bytes): + def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll): self.fd = conn.fileno() + poll.register(self.fd, select.POLLIN) + self.registered = True + self.poll = poll self.addr = addr self.conn = conn self.input_buf = self.output_buf = b'' @@ -375,6 +378,11 @@ class SocketChild: self.stdin = self.stdout = self.stderr = -1 self.pid = -1 + def unregister_from_poll(self) -> None: + if self.registered: + self.poll.unregister(self.fd) + self.registered = False + def read(self) -> bool: import array fds = array.array("i") # Array of ints @@ -511,6 +519,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy') prewarm() + def remove_socket_child(sc: SocketChild) -> None: + socket_children.pop(sc.fd, None) + sc.unregister_from_poll() + socket_pid_map.pop(sc.pid, None) + sc.conn.close() + def get_all_non_child_fds() -> Iterator[int]: yield notify_child_death_fd yield stdin_fd @@ -620,9 +634,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: break child_id = child_pid_map.pop(pid, None) if child_id is None: - sc = socket_pid_map.pop(pid, None) + sc = socket_pid_map.get(pid) if sc is not None: - sc.handle_death(status) + try: + sc.handle_death(status) + finally: + remove_socket_child(sc) else: handle_child_death(child_id, pid) @@ -631,9 +648,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: def handle_socket_client(event: int) -> None: check_event(event, 'UNIX socket fd listener failed') conn, addr = unix_socket.accept() - sc = SocketChild(conn, addr) + sc = SocketChild(conn, addr, poll) socket_children[sc.fd] = sc - poll.register(sc.fd, select.POLLIN) def handle_socket_launch(fd: int, event: int) -> None: scq = socket_children.get(q) @@ -642,20 +658,21 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: if event & select.POLLIN: try: if scq.read(): - poll.unregister(scq.fd) + scq.unregister_from_poll() scq.fork(get_all_non_child_fds()) socket_pid_map[scq.pid] = scq scq.child_id = next(child_id_counter) except SocketClosed: - socket_children.pop(q, None) + if is_zygote: + remove_socket_child(scq) except OSError as e: if is_zygote: - socket_children.pop(q, None) + remove_socket_child(scq) print_error(f'Failed to fork socket child with error: {e}') else: raise - if event & error_events: - socket_children.pop(q, None) + if is_zygote and (event & error_events): + remove_socket_child(scq) keep_type_checker_happy = True try: