Close unneeded fds in forked children

This commit is contained in:
Kovid Goyal 2022-06-12 20:26:20 +05:30
parent a1a637c7f1
commit 9fcb8e5b6e
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -14,7 +14,8 @@ from dataclasses import dataclass
from importlib import import_module
from itertools import count
from typing import (
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Tuple, Union, cast
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Sequence, Tuple,
Union, cast
)
from kitty.constants import kitty_exe, running_in_kitty
@ -268,7 +269,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
raise SystemExit(0)
def fork(shm_address: str) -> Tuple[int, int, int]:
def fork(shm_address: str, all_non_child_fds: Sequence[int]) -> Tuple[int, int, int]:
sz = pos = 0
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
data = shm.read_data_with_size()
@ -301,6 +302,8 @@ def fork(shm_address: str) -> Tuple[int, int, int]:
# child process
os.close(r)
os.close(ready_fd_write)
for fd in all_non_child_fds:
os.close(fd)
os.setsid()
signal.signal(signal.SIGUSR1, signal.SIG_DFL)
tty_name = cmd.get('tty_name')
@ -340,6 +343,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
prewarm()
def get_all_non_child_fds() -> Sequence[int]:
return [notify_child_death_fd, stdin_fd, stdout_fd] + list(child_ready_fds.values()) + list(child_death_fds)
def check_event(event: int, err_msg: str) -> None:
if event & select.POLLHUP:
raise SystemExit(0)
@ -365,13 +371,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
child_id = int(payload)
cfd = child_ready_fds.pop(child_id)
if cfd is not None:
os.write(cfd, b'1')
os.close(cfd)
elif cmd == 'quit':
raise SystemExit(0)
elif cmd == 'fork':
try:
child_pid, child_death_fd, ready_fd_write = fork(payload)
child_pid, child_death_fd, ready_fd_write = fork(payload, get_all_non_child_fds())
except Exception as e:
es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode()