Use a signal handler instead of a pipe for child death notification
This commit is contained in:
parent
7e3bd8586f
commit
cf667b8c47
@ -21,10 +21,13 @@ from typing import (
|
||||
from kitty.constants import kitty_exe, running_in_kitty
|
||||
from kitty.entry_points import main as main_entry_point
|
||||
from kitty.fast_data_types import (
|
||||
establish_controlling_tty, get_options, safe_pipe, set_options
|
||||
CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options,
|
||||
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
|
||||
set_options
|
||||
)
|
||||
from kitty.options.types import Options
|
||||
from kitty.shm import SharedMemory
|
||||
from kitty.types import SignalInfo
|
||||
from kitty.utils import log_error
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -32,7 +35,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
|
||||
TIMEOUT = 4.0
|
||||
TIMEOUT = 5.0
|
||||
|
||||
|
||||
class PrewarmProcessFailed(Exception):
|
||||
@ -288,7 +291,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int, int]:
|
||||
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||
sz = pos = 0
|
||||
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
||||
data = shm.read_data_with_size()
|
||||
@ -312,13 +315,10 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int,
|
||||
os.close(ready_fd_read)
|
||||
poll = select.poll()
|
||||
poll.register(r, select.POLLIN)
|
||||
for (fd, event) in poll.poll():
|
||||
if event & select.POLLIN:
|
||||
os.read(r, 1)
|
||||
return child_pid, r, ready_fd_write
|
||||
else:
|
||||
raise ValueError('Child process pipe failed')
|
||||
tuple(poll.poll())
|
||||
return child_pid, ready_fd_write
|
||||
# child process
|
||||
remove_signal_handlers()
|
||||
os.close(r)
|
||||
os.close(ready_fd_write)
|
||||
for fd in all_non_child_fds:
|
||||
@ -330,7 +330,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int,
|
||||
sys.__stdout__.flush()
|
||||
sys.__stderr__.flush()
|
||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||
os.write(w, b'1') # this will be closed on process exit and thereby used to detect child death
|
||||
os.close(w)
|
||||
if shm.unlink_on_exit:
|
||||
child_main(cmd, ready_fd_read)
|
||||
else:
|
||||
@ -349,12 +349,13 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
os.set_blocking(notify_child_death_fd, False)
|
||||
os.set_blocking(stdin_fd, False)
|
||||
os.set_blocking(stdout_fd, False)
|
||||
signal_read_fd = install_signal_handlers(signal.SIGCHLD)[0]
|
||||
poll = select.poll()
|
||||
poll.register(stdin_fd, select.POLLIN)
|
||||
poll.register(signal_read_fd, select.POLLIN)
|
||||
input_buf = output_buf = child_death_buf = b''
|
||||
child_ready_fds: Dict[int, int] = {}
|
||||
child_death_fds: Dict[int, int] = {}
|
||||
child_id_map: Dict[int, int] = {}
|
||||
child_pid_map: Dict[int, int] = {}
|
||||
child_id_counter = count()
|
||||
self_pid = os.getpid()
|
||||
# runpy issues a warning when running modules that have already been
|
||||
@ -366,8 +367,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
yield notify_child_death_fd
|
||||
yield stdin_fd
|
||||
yield stdout_fd
|
||||
# the signal fds are closed by remove_signal_handlers()
|
||||
yield from child_ready_fds.values()
|
||||
yield from child_death_fds.keys()
|
||||
|
||||
def check_event(event: int, err_msg: str) -> None:
|
||||
if event & select.POLLHUP:
|
||||
@ -377,7 +378,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
|
||||
def handle_input(event: int) -> None:
|
||||
nonlocal input_buf, output_buf
|
||||
check_event(event, 'Polling of STDIN failed')
|
||||
check_event(event, 'Polling of input pipe failed')
|
||||
if not (event & select.POLLIN):
|
||||
return
|
||||
d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE)
|
||||
@ -392,31 +393,29 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
reload_kitty_config(payload)
|
||||
elif cmd == 'ready':
|
||||
child_id = int(payload)
|
||||
cfd = child_ready_fds.pop(child_id)
|
||||
cfd = child_ready_fds.pop(child_id, None)
|
||||
if cfd is not None:
|
||||
os.close(cfd)
|
||||
elif cmd == 'quit':
|
||||
raise SystemExit(0)
|
||||
elif cmd == 'fork':
|
||||
try:
|
||||
child_pid, child_death_fd, ready_fd_write = fork(payload, get_all_non_child_fds())
|
||||
child_pid, ready_fd_write = fork(payload, get_all_non_child_fds())
|
||||
except Exception as e:
|
||||
es = str(e).replace('\n', ' ')
|
||||
output_buf += f'ERR:{es}\n'.encode()
|
||||
else:
|
||||
if os.getpid() == self_pid:
|
||||
child_id = next(child_id_counter)
|
||||
child_id_map[child_id] = child_pid
|
||||
child_pid_map[child_pid] = child_id
|
||||
child_ready_fds[child_id] = ready_fd_write
|
||||
child_death_fds[child_death_fd] = child_id
|
||||
poll.register(child_death_fd, select.POLLIN)
|
||||
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
||||
elif cmd == 'echo':
|
||||
output_buf += f'{payload}\n'.encode()
|
||||
|
||||
def handle_output(event: int) -> None:
|
||||
nonlocal output_buf
|
||||
check_event(event, 'Polling of STDOUT failed')
|
||||
check_event(event, 'Polling of output pipe failed')
|
||||
if not (event & select.POLLOUT):
|
||||
return
|
||||
if output_buf:
|
||||
@ -429,7 +428,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
|
||||
def handle_notify_child_death(event: int) -> None:
|
||||
nonlocal child_death_buf
|
||||
check_event(event, 'Polling of notify child death fd failed')
|
||||
check_event(event, 'Polling of notify child death pipe failed')
|
||||
if not (event & select.POLLOUT):
|
||||
return
|
||||
if child_death_buf:
|
||||
@ -440,18 +439,34 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
if not child_death_buf:
|
||||
poll.unregister(notify_child_death_fd)
|
||||
|
||||
def handle_child_death(dead_child_fd: int, dead_child_id: int) -> None:
|
||||
def handle_child_death(dead_child_id: int, dead_child_pid: int) -> None:
|
||||
nonlocal child_death_buf
|
||||
poll.unregister(dead_child_fd)
|
||||
del child_death_fds[dead_child_fd]
|
||||
xfd = child_ready_fds.pop(dead_child_id, None)
|
||||
if xfd is not None:
|
||||
os.close(xfd)
|
||||
dead_child_pid = child_id_map.pop(dead_child_id, None)
|
||||
if dead_child_pid is not None:
|
||||
wait_for_child_death(dead_child_pid)
|
||||
child_death_buf += f'{dead_child_pid}\n'.encode()
|
||||
|
||||
def handle_signals(event: int) -> None:
|
||||
check_event(event, 'Polling of signal pipe failed')
|
||||
if not event & select.POLLIN:
|
||||
return
|
||||
|
||||
def handle_signal(siginfo: SignalInfo) -> None:
|
||||
if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED):
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
pid, status = os.waitpid(-1, os.WNOHANG)
|
||||
except ChildProcessError:
|
||||
pid = 0
|
||||
if not pid:
|
||||
break
|
||||
child_id = child_pid_map.pop(pid, None)
|
||||
if child_id is not None:
|
||||
handle_child_death(child_id, pid)
|
||||
|
||||
read_signals(signal_read_fd, handle_signal)
|
||||
|
||||
try:
|
||||
while True:
|
||||
if output_buf:
|
||||
@ -463,12 +478,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
handle_input(event)
|
||||
elif q == stdout_fd:
|
||||
handle_output(event)
|
||||
elif q == signal_read_fd:
|
||||
handle_signals(event)
|
||||
elif q == notify_child_death_fd:
|
||||
handle_notify_child_death(event)
|
||||
else:
|
||||
dead_child_id = child_death_fds.get(q)
|
||||
if dead_child_id is not None and event & select.POLLHUP:
|
||||
handle_child_death(q, dead_child_id)
|
||||
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
||||
if os.getpid() == self_pid:
|
||||
raise SystemExit(1)
|
||||
@ -480,6 +493,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
raise
|
||||
finally:
|
||||
if os.getpid() == self_pid:
|
||||
remove_signal_handlers()
|
||||
for fmd in child_ready_fds.values():
|
||||
with suppress(OSError):
|
||||
os.close(fmd)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user