Use a signal handler instead of a pipe for child death notification

This commit is contained in:
Kovid Goyal 2022-06-13 20:40:04 +05:30
parent 7e3bd8586f
commit cf667b8c47
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -21,10 +21,13 @@ from typing import (
from kitty.constants import kitty_exe, running_in_kitty from kitty.constants import kitty_exe, running_in_kitty
from kitty.entry_points import main as main_entry_point from kitty.entry_points import main as main_entry_point
from kitty.fast_data_types import ( from kitty.fast_data_types import (
establish_controlling_tty, get_options, safe_pipe, set_options CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options,
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
set_options
) )
from kitty.options.types import Options from kitty.options.types import Options
from kitty.shm import SharedMemory from kitty.shm import SharedMemory
from kitty.types import SignalInfo
from kitty.utils import log_error from kitty.utils import log_error
if TYPE_CHECKING: if TYPE_CHECKING:
@ -32,7 +35,7 @@ if TYPE_CHECKING:
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
TIMEOUT = 4.0 TIMEOUT = 5.0
class PrewarmProcessFailed(Exception): class PrewarmProcessFailed(Exception):
@ -288,7 +291,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
raise SystemExit(0) raise SystemExit(0)
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int, int]: def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
sz = pos = 0 sz = pos = 0
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm: with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
data = shm.read_data_with_size() data = shm.read_data_with_size()
@ -312,13 +315,10 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int,
os.close(ready_fd_read) os.close(ready_fd_read)
poll = select.poll() poll = select.poll()
poll.register(r, select.POLLIN) poll.register(r, select.POLLIN)
for (fd, event) in poll.poll(): tuple(poll.poll())
if event & select.POLLIN: return child_pid, ready_fd_write
os.read(r, 1)
return child_pid, r, ready_fd_write
else:
raise ValueError('Child process pipe failed')
# child process # child process
remove_signal_handlers()
os.close(r) os.close(r)
os.close(ready_fd_write) os.close(ready_fd_write)
for fd in all_non_child_fds: for fd in all_non_child_fds:
@ -330,7 +330,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int,
sys.__stdout__.flush() sys.__stdout__.flush()
sys.__stderr__.flush() sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
os.write(w, b'1') # this will be closed on process exit and thereby used to detect child death os.close(w)
if shm.unlink_on_exit: if shm.unlink_on_exit:
child_main(cmd, ready_fd_read) child_main(cmd, ready_fd_read)
else: else:
@ -349,12 +349,13 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
os.set_blocking(notify_child_death_fd, False) os.set_blocking(notify_child_death_fd, False)
os.set_blocking(stdin_fd, False) os.set_blocking(stdin_fd, False)
os.set_blocking(stdout_fd, False) os.set_blocking(stdout_fd, False)
signal_read_fd = install_signal_handlers(signal.SIGCHLD)[0]
poll = select.poll() poll = select.poll()
poll.register(stdin_fd, select.POLLIN) poll.register(stdin_fd, select.POLLIN)
poll.register(signal_read_fd, select.POLLIN)
input_buf = output_buf = child_death_buf = b'' input_buf = output_buf = child_death_buf = b''
child_ready_fds: Dict[int, int] = {} child_ready_fds: Dict[int, int] = {}
child_death_fds: Dict[int, int] = {} child_pid_map: Dict[int, int] = {}
child_id_map: Dict[int, int] = {}
child_id_counter = count() child_id_counter = count()
self_pid = os.getpid() self_pid = os.getpid()
# runpy issues a warning when running modules that have already been # runpy issues a warning when running modules that have already been
@ -366,8 +367,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
yield notify_child_death_fd yield notify_child_death_fd
yield stdin_fd yield stdin_fd
yield stdout_fd yield stdout_fd
# the signal fds are closed by remove_signal_handlers()
yield from child_ready_fds.values() yield from child_ready_fds.values()
yield from child_death_fds.keys()
def check_event(event: int, err_msg: str) -> None: def check_event(event: int, err_msg: str) -> None:
if event & select.POLLHUP: if event & select.POLLHUP:
@ -377,7 +378,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
def handle_input(event: int) -> None: def handle_input(event: int) -> None:
nonlocal input_buf, output_buf nonlocal input_buf, output_buf
check_event(event, 'Polling of STDIN failed') check_event(event, 'Polling of input pipe failed')
if not (event & select.POLLIN): if not (event & select.POLLIN):
return return
d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE) d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE)
@ -392,31 +393,29 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
reload_kitty_config(payload) reload_kitty_config(payload)
elif cmd == 'ready': elif cmd == 'ready':
child_id = int(payload) child_id = int(payload)
cfd = child_ready_fds.pop(child_id) cfd = child_ready_fds.pop(child_id, None)
if cfd is not None: if cfd is not None:
os.close(cfd) os.close(cfd)
elif cmd == 'quit': elif cmd == 'quit':
raise SystemExit(0) raise SystemExit(0)
elif cmd == 'fork': elif cmd == 'fork':
try: try:
child_pid, child_death_fd, ready_fd_write = fork(payload, get_all_non_child_fds()) child_pid, ready_fd_write = fork(payload, get_all_non_child_fds())
except Exception as e: except Exception as e:
es = str(e).replace('\n', ' ') es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode() output_buf += f'ERR:{es}\n'.encode()
else: else:
if os.getpid() == self_pid: if os.getpid() == self_pid:
child_id = next(child_id_counter) child_id = next(child_id_counter)
child_id_map[child_id] = child_pid child_pid_map[child_pid] = child_id
child_ready_fds[child_id] = ready_fd_write child_ready_fds[child_id] = ready_fd_write
child_death_fds[child_death_fd] = child_id
poll.register(child_death_fd, select.POLLIN)
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode() output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
elif cmd == 'echo': elif cmd == 'echo':
output_buf += f'{payload}\n'.encode() output_buf += f'{payload}\n'.encode()
def handle_output(event: int) -> None: def handle_output(event: int) -> None:
nonlocal output_buf nonlocal output_buf
check_event(event, 'Polling of STDOUT failed') check_event(event, 'Polling of output pipe failed')
if not (event & select.POLLOUT): if not (event & select.POLLOUT):
return return
if output_buf: if output_buf:
@ -429,7 +428,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
def handle_notify_child_death(event: int) -> None: def handle_notify_child_death(event: int) -> None:
nonlocal child_death_buf nonlocal child_death_buf
check_event(event, 'Polling of notify child death fd failed') check_event(event, 'Polling of notify child death pipe failed')
if not (event & select.POLLOUT): if not (event & select.POLLOUT):
return return
if child_death_buf: if child_death_buf:
@ -440,17 +439,33 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
if not child_death_buf: if not child_death_buf:
poll.unregister(notify_child_death_fd) poll.unregister(notify_child_death_fd)
def handle_child_death(dead_child_fd: int, dead_child_id: int) -> None: def handle_child_death(dead_child_id: int, dead_child_pid: int) -> None:
nonlocal child_death_buf nonlocal child_death_buf
poll.unregister(dead_child_fd)
del child_death_fds[dead_child_fd]
xfd = child_ready_fds.pop(dead_child_id, None) xfd = child_ready_fds.pop(dead_child_id, None)
if xfd is not None: if xfd is not None:
os.close(xfd) os.close(xfd)
dead_child_pid = child_id_map.pop(dead_child_id, None) child_death_buf += f'{dead_child_pid}\n'.encode()
if dead_child_pid is not None:
wait_for_child_death(dead_child_pid) def handle_signals(event: int) -> None:
child_death_buf += f'{dead_child_pid}\n'.encode() check_event(event, 'Polling of signal pipe failed')
if not event & select.POLLIN:
return
def handle_signal(siginfo: SignalInfo) -> None:
if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED):
return
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
except ChildProcessError:
pid = 0
if not pid:
break
child_id = child_pid_map.pop(pid, None)
if child_id is not None:
handle_child_death(child_id, pid)
read_signals(signal_read_fd, handle_signal)
try: try:
while True: while True:
@ -463,12 +478,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
handle_input(event) handle_input(event)
elif q == stdout_fd: elif q == stdout_fd:
handle_output(event) handle_output(event)
elif q == signal_read_fd:
handle_signals(event)
elif q == notify_child_death_fd: elif q == notify_child_death_fd:
handle_notify_child_death(event) handle_notify_child_death(event)
else:
dead_child_id = child_death_fds.get(q)
if dead_child_id is not None and event & select.POLLHUP:
handle_child_death(q, dead_child_id)
except (KeyboardInterrupt, EOFError, BrokenPipeError): except (KeyboardInterrupt, EOFError, BrokenPipeError):
if os.getpid() == self_pid: if os.getpid() == self_pid:
raise SystemExit(1) raise SystemExit(1)
@ -480,6 +493,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
raise raise
finally: finally:
if os.getpid() == self_pid: if os.getpid() == self_pid:
remove_signal_handlers()
for fmd in child_ready_fds.values(): for fmd in child_ready_fds.values():
with suppress(OSError): with suppress(OSError):
os.close(fmd) os.close(fmd)