diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 35ca8da43..2d6a88c64 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -468,6 +468,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: input_buf = output_buf = child_death_buf = b'' child_ready_fds: Dict[int, int] = {} child_pid_map: Dict[int, int] = {} + socket_pid_map: Dict[int, SocketChild] = {} socket_children: Dict[int, SocketChild] = {} child_id_counter = count() self_pid = os.getpid() @@ -584,10 +585,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: break child_id = child_pid_map.pop(pid, None) if child_id is None: - for sc in socket_children.values(): - if sc.pid == pid: - sc.handle_death(status) - break + sc = socket_pid_map.pop(pid, None) + if sc is not None: + sc.handle_death(status) else: handle_child_death(child_id, pid) @@ -624,8 +624,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: try: if scq.read(): poll.unregister(scq.fd) - scq.conn.shutdown(socket.SHUT_RD) scq.fork(get_all_non_child_fds()) + socket_pid_map[scq.pid] = scq scq.child_id = next(child_id_counter) except SocketClosed: socket_children.pop(q)