Use a signal handler instead of a pipe for child death notification
This commit is contained in:
parent
7e3bd8586f
commit
cf667b8c47
@ -21,10 +21,13 @@ from typing import (
|
|||||||
from kitty.constants import kitty_exe, running_in_kitty
|
from kitty.constants import kitty_exe, running_in_kitty
|
||||||
from kitty.entry_points import main as main_entry_point
|
from kitty.entry_points import main as main_entry_point
|
||||||
from kitty.fast_data_types import (
|
from kitty.fast_data_types import (
|
||||||
establish_controlling_tty, get_options, safe_pipe, set_options
|
CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options,
|
||||||
|
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
|
||||||
|
set_options
|
||||||
)
|
)
|
||||||
from kitty.options.types import Options
|
from kitty.options.types import Options
|
||||||
from kitty.shm import SharedMemory
|
from kitty.shm import SharedMemory
|
||||||
|
from kitty.types import SignalInfo
|
||||||
from kitty.utils import log_error
|
from kitty.utils import log_error
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -32,7 +35,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
|
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
|
||||||
TIMEOUT = 4.0
|
TIMEOUT = 5.0
|
||||||
|
|
||||||
|
|
||||||
class PrewarmProcessFailed(Exception):
|
class PrewarmProcessFailed(Exception):
|
||||||
@ -288,7 +291,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
|||||||
raise SystemExit(0)
|
raise SystemExit(0)
|
||||||
|
|
||||||
|
|
||||||
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int, int]:
|
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||||
sz = pos = 0
|
sz = pos = 0
|
||||||
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
||||||
data = shm.read_data_with_size()
|
data = shm.read_data_with_size()
|
||||||
@ -312,13 +315,10 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int,
|
|||||||
os.close(ready_fd_read)
|
os.close(ready_fd_read)
|
||||||
poll = select.poll()
|
poll = select.poll()
|
||||||
poll.register(r, select.POLLIN)
|
poll.register(r, select.POLLIN)
|
||||||
for (fd, event) in poll.poll():
|
tuple(poll.poll())
|
||||||
if event & select.POLLIN:
|
return child_pid, ready_fd_write
|
||||||
os.read(r, 1)
|
|
||||||
return child_pid, r, ready_fd_write
|
|
||||||
else:
|
|
||||||
raise ValueError('Child process pipe failed')
|
|
||||||
# child process
|
# child process
|
||||||
|
remove_signal_handlers()
|
||||||
os.close(r)
|
os.close(r)
|
||||||
os.close(ready_fd_write)
|
os.close(ready_fd_write)
|
||||||
for fd in all_non_child_fds:
|
for fd in all_non_child_fds:
|
||||||
@ -330,7 +330,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int,
|
|||||||
sys.__stdout__.flush()
|
sys.__stdout__.flush()
|
||||||
sys.__stderr__.flush()
|
sys.__stderr__.flush()
|
||||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||||
os.write(w, b'1') # this will be closed on process exit and thereby used to detect child death
|
os.close(w)
|
||||||
if shm.unlink_on_exit:
|
if shm.unlink_on_exit:
|
||||||
child_main(cmd, ready_fd_read)
|
child_main(cmd, ready_fd_read)
|
||||||
else:
|
else:
|
||||||
@ -349,12 +349,13 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
os.set_blocking(notify_child_death_fd, False)
|
os.set_blocking(notify_child_death_fd, False)
|
||||||
os.set_blocking(stdin_fd, False)
|
os.set_blocking(stdin_fd, False)
|
||||||
os.set_blocking(stdout_fd, False)
|
os.set_blocking(stdout_fd, False)
|
||||||
|
signal_read_fd = install_signal_handlers(signal.SIGCHLD)[0]
|
||||||
poll = select.poll()
|
poll = select.poll()
|
||||||
poll.register(stdin_fd, select.POLLIN)
|
poll.register(stdin_fd, select.POLLIN)
|
||||||
|
poll.register(signal_read_fd, select.POLLIN)
|
||||||
input_buf = output_buf = child_death_buf = b''
|
input_buf = output_buf = child_death_buf = b''
|
||||||
child_ready_fds: Dict[int, int] = {}
|
child_ready_fds: Dict[int, int] = {}
|
||||||
child_death_fds: Dict[int, int] = {}
|
child_pid_map: Dict[int, int] = {}
|
||||||
child_id_map: Dict[int, int] = {}
|
|
||||||
child_id_counter = count()
|
child_id_counter = count()
|
||||||
self_pid = os.getpid()
|
self_pid = os.getpid()
|
||||||
# runpy issues a warning when running modules that have already been
|
# runpy issues a warning when running modules that have already been
|
||||||
@ -366,8 +367,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
yield notify_child_death_fd
|
yield notify_child_death_fd
|
||||||
yield stdin_fd
|
yield stdin_fd
|
||||||
yield stdout_fd
|
yield stdout_fd
|
||||||
|
# the signal fds are closed by remove_signal_handlers()
|
||||||
yield from child_ready_fds.values()
|
yield from child_ready_fds.values()
|
||||||
yield from child_death_fds.keys()
|
|
||||||
|
|
||||||
def check_event(event: int, err_msg: str) -> None:
|
def check_event(event: int, err_msg: str) -> None:
|
||||||
if event & select.POLLHUP:
|
if event & select.POLLHUP:
|
||||||
@ -377,7 +378,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
|
|
||||||
def handle_input(event: int) -> None:
|
def handle_input(event: int) -> None:
|
||||||
nonlocal input_buf, output_buf
|
nonlocal input_buf, output_buf
|
||||||
check_event(event, 'Polling of STDIN failed')
|
check_event(event, 'Polling of input pipe failed')
|
||||||
if not (event & select.POLLIN):
|
if not (event & select.POLLIN):
|
||||||
return
|
return
|
||||||
d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE)
|
d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE)
|
||||||
@ -392,31 +393,29 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
reload_kitty_config(payload)
|
reload_kitty_config(payload)
|
||||||
elif cmd == 'ready':
|
elif cmd == 'ready':
|
||||||
child_id = int(payload)
|
child_id = int(payload)
|
||||||
cfd = child_ready_fds.pop(child_id)
|
cfd = child_ready_fds.pop(child_id, None)
|
||||||
if cfd is not None:
|
if cfd is not None:
|
||||||
os.close(cfd)
|
os.close(cfd)
|
||||||
elif cmd == 'quit':
|
elif cmd == 'quit':
|
||||||
raise SystemExit(0)
|
raise SystemExit(0)
|
||||||
elif cmd == 'fork':
|
elif cmd == 'fork':
|
||||||
try:
|
try:
|
||||||
child_pid, child_death_fd, ready_fd_write = fork(payload, get_all_non_child_fds())
|
child_pid, ready_fd_write = fork(payload, get_all_non_child_fds())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es = str(e).replace('\n', ' ')
|
es = str(e).replace('\n', ' ')
|
||||||
output_buf += f'ERR:{es}\n'.encode()
|
output_buf += f'ERR:{es}\n'.encode()
|
||||||
else:
|
else:
|
||||||
if os.getpid() == self_pid:
|
if os.getpid() == self_pid:
|
||||||
child_id = next(child_id_counter)
|
child_id = next(child_id_counter)
|
||||||
child_id_map[child_id] = child_pid
|
child_pid_map[child_pid] = child_id
|
||||||
child_ready_fds[child_id] = ready_fd_write
|
child_ready_fds[child_id] = ready_fd_write
|
||||||
child_death_fds[child_death_fd] = child_id
|
|
||||||
poll.register(child_death_fd, select.POLLIN)
|
|
||||||
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
||||||
elif cmd == 'echo':
|
elif cmd == 'echo':
|
||||||
output_buf += f'{payload}\n'.encode()
|
output_buf += f'{payload}\n'.encode()
|
||||||
|
|
||||||
def handle_output(event: int) -> None:
|
def handle_output(event: int) -> None:
|
||||||
nonlocal output_buf
|
nonlocal output_buf
|
||||||
check_event(event, 'Polling of STDOUT failed')
|
check_event(event, 'Polling of output pipe failed')
|
||||||
if not (event & select.POLLOUT):
|
if not (event & select.POLLOUT):
|
||||||
return
|
return
|
||||||
if output_buf:
|
if output_buf:
|
||||||
@ -429,7 +428,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
|
|
||||||
def handle_notify_child_death(event: int) -> None:
|
def handle_notify_child_death(event: int) -> None:
|
||||||
nonlocal child_death_buf
|
nonlocal child_death_buf
|
||||||
check_event(event, 'Polling of notify child death fd failed')
|
check_event(event, 'Polling of notify child death pipe failed')
|
||||||
if not (event & select.POLLOUT):
|
if not (event & select.POLLOUT):
|
||||||
return
|
return
|
||||||
if child_death_buf:
|
if child_death_buf:
|
||||||
@ -440,18 +439,34 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
if not child_death_buf:
|
if not child_death_buf:
|
||||||
poll.unregister(notify_child_death_fd)
|
poll.unregister(notify_child_death_fd)
|
||||||
|
|
||||||
def handle_child_death(dead_child_fd: int, dead_child_id: int) -> None:
|
def handle_child_death(dead_child_id: int, dead_child_pid: int) -> None:
|
||||||
nonlocal child_death_buf
|
nonlocal child_death_buf
|
||||||
poll.unregister(dead_child_fd)
|
|
||||||
del child_death_fds[dead_child_fd]
|
|
||||||
xfd = child_ready_fds.pop(dead_child_id, None)
|
xfd = child_ready_fds.pop(dead_child_id, None)
|
||||||
if xfd is not None:
|
if xfd is not None:
|
||||||
os.close(xfd)
|
os.close(xfd)
|
||||||
dead_child_pid = child_id_map.pop(dead_child_id, None)
|
|
||||||
if dead_child_pid is not None:
|
|
||||||
wait_for_child_death(dead_child_pid)
|
|
||||||
child_death_buf += f'{dead_child_pid}\n'.encode()
|
child_death_buf += f'{dead_child_pid}\n'.encode()
|
||||||
|
|
||||||
|
def handle_signals(event: int) -> None:
|
||||||
|
check_event(event, 'Polling of signal pipe failed')
|
||||||
|
if not event & select.POLLIN:
|
||||||
|
return
|
||||||
|
|
||||||
|
def handle_signal(siginfo: SignalInfo) -> None:
|
||||||
|
if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED):
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
pid, status = os.waitpid(-1, os.WNOHANG)
|
||||||
|
except ChildProcessError:
|
||||||
|
pid = 0
|
||||||
|
if not pid:
|
||||||
|
break
|
||||||
|
child_id = child_pid_map.pop(pid, None)
|
||||||
|
if child_id is not None:
|
||||||
|
handle_child_death(child_id, pid)
|
||||||
|
|
||||||
|
read_signals(signal_read_fd, handle_signal)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if output_buf:
|
if output_buf:
|
||||||
@ -463,12 +478,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
handle_input(event)
|
handle_input(event)
|
||||||
elif q == stdout_fd:
|
elif q == stdout_fd:
|
||||||
handle_output(event)
|
handle_output(event)
|
||||||
|
elif q == signal_read_fd:
|
||||||
|
handle_signals(event)
|
||||||
elif q == notify_child_death_fd:
|
elif q == notify_child_death_fd:
|
||||||
handle_notify_child_death(event)
|
handle_notify_child_death(event)
|
||||||
else:
|
|
||||||
dead_child_id = child_death_fds.get(q)
|
|
||||||
if dead_child_id is not None and event & select.POLLHUP:
|
|
||||||
handle_child_death(q, dead_child_id)
|
|
||||||
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
||||||
if os.getpid() == self_pid:
|
if os.getpid() == self_pid:
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
@ -480,6 +493,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if os.getpid() == self_pid:
|
if os.getpid() == self_pid:
|
||||||
|
remove_signal_handlers()
|
||||||
for fmd in child_ready_fds.values():
|
for fmd in child_ready_fds.values():
|
||||||
with suppress(OSError):
|
with suppress(OSError):
|
||||||
os.close(fmd)
|
os.close(fmd)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user