From 9fcb8e5b6e2e419cfb0dc23e191f78b34a7aa373 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 12 Jun 2022 20:26:20 +0530 Subject: [PATCH] Close unneeded fds in forked children --- kitty/prewarm.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 258274e64..e943ac41d 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -14,7 +14,8 @@ from dataclasses import dataclass from importlib import import_module from itertools import count from typing import ( - IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Tuple, Union, cast + IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Sequence, Tuple, + Union, cast ) from kitty.constants import kitty_exe, running_in_kitty @@ -268,7 +269,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: raise SystemExit(0) -def fork(shm_address: str) -> Tuple[int, int, int]: +def fork(shm_address: str, all_non_child_fds: Sequence[int]) -> Tuple[int, int, int]: sz = pos = 0 with SharedMemory(name=shm_address, unlink_on_exit=True) as shm: data = shm.read_data_with_size() @@ -301,6 +302,8 @@ def fork(shm_address: str) -> Tuple[int, int, int]: # child process os.close(r) os.close(ready_fd_write) + for fd in all_non_child_fds: + os.close(fd) os.setsid() signal.signal(signal.SIGUSR1, signal.SIG_DFL) tty_name = cmd.get('tty_name') @@ -340,6 +343,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy') prewarm() + def get_all_non_child_fds() -> Sequence[int]: + return [notify_child_death_fd, stdin_fd, stdout_fd] + list(child_ready_fds.values()) + list(child_death_fds) + def check_event(event: int, err_msg: str) -> None: if event & select.POLLHUP: raise SystemExit(0) @@ -365,13 +371,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: child_id = int(payload) cfd = child_ready_fds.pop(child_id) if cfd is not None: - os.write(cfd, b'1') os.close(cfd) elif cmd == 'quit': raise SystemExit(0) elif cmd == 'fork': try: - child_pid, child_death_fd, ready_fd_write = fork(payload) + child_pid, child_death_fd, ready_fd_write = fork(payload, get_all_non_child_fds()) except Exception as e: es = str(e).replace('\n', ' ') output_buf += f'ERR:{es}\n'.encode()