kitty/kitty/prewarm.py

879 lines
31 KiB
Python

#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import errno
import fcntl
import io
import json
import os
import select
import signal
import socket
import struct
import sys
import termios
import time
import traceback
import warnings
from contextlib import suppress
from dataclasses import dataclass
from importlib import import_module
from itertools import count
from typing import (
IO, TYPE_CHECKING, Any, Callable, Dict, Iterator, List, NoReturn, Optional,
Tuple, TypeVar, Union, cast
)
from kitty.constants import kitty_exe, running_in_kitty
from kitty.entry_points import main as main_entry_point
from kitty.fast_data_types import (
CLD_EXITED, CLD_KILLED, CLD_STOPPED, get_options, getpeereid,
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
set_options, set_use_os_log
)
from kitty.options.types import Options
from kitty.shm import SharedMemory
from kitty.types import SignalInfo
from kitty.utils import log_error, random_unix_socket, safer_fork
if TYPE_CHECKING:
from _typeshed import ReadableBuffer, WriteableBuffer
# Poll event bits that this module treats as a failed or closed file descriptor
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
# Default timeout, in seconds, for I/O with the prewarm process. A longer
# value is used when the CI environment variable is set to 'true'.
TIMEOUT = 15.0 if os.environ.get('CI') == 'true' else 5.0
def restore_python_signal_handlers() -> None:
    """Remove kitty's signal handlers and reinstall a fixed set of Python-level dispositions."""
    remove_signal_handlers()
    # Applied in this exact order: SIGINT raises KeyboardInterrupt again,
    # SIGPIPE is ignored, SIGUSR1 and SIGCHLD go back to their defaults.
    dispositions = (
        (signal.SIGINT, signal.default_int_handler),
        (signal.SIGPIPE, signal.SIG_IGN),
        (signal.SIGUSR1, signal.SIG_DFL),
        (signal.SIGCHLD, signal.SIG_DFL),
    )
    for signum, handler in dispositions:
        signal.signal(signum, handler)
def print_error(*a: Any) -> None:
    """Log the given values via log_error(), prefixed with 'Prewarm zygote:'."""
    prefix = 'Prewarm zygote:'
    log_error(prefix, *a)
class PrewarmProcessFailed(Exception):
    """Raised when communication with the prewarm process fails or times out."""
@dataclass
class Child:
    """Record of a child created by the prewarm process."""

    # identifier assigned to the child by the prewarm process
    child_id: int
    # operating system pid of the child
    child_process_pid: int
def wait_for_child_death(child_pid: int, timeout: float = 1, options: int = 0) -> Optional[int]:
    """Poll, without blocking in waitpid(), for the exit of child_pid.

    Returns the wait status once the child has been reaped, 0 if there is no
    such child to wait for, and None if the child has not exited within
    timeout seconds. A falsy timeout means a single non-blocking check.
    """
    started_at = time.monotonic()
    wait_flags = options | os.WNOHANG
    while True:
        # With a truthy timeout, give up once the deadline has passed
        if timeout and not (time.monotonic() - started_at < timeout):
            return None
        try:
            reaped_pid, wait_status = os.waitpid(child_pid, wait_flags)
        except ChildProcessError:
            return 0
        if reaped_pid == child_pid:
            return wait_status
        if not timeout:
            return None
        time.sleep(0.01)
class PrewarmProcess:
    """Parent-side handle for talking to the prewarm (zygote) process.

    Commands are written as newline terminated text lines to the prewarm
    process and its replies are read back as newline terminated lines.
    """

    def __init__(
        self,
        prewarm_process_pid: int,
        to_prewarm_stdin: int,
        from_prewarm_stdout: int,
        from_prewarm_death_notify: int,
        unix_socket_name: str,
    ) -> None:
        # children that have been forked but not yet marked as ready
        self.children: Dict[int, Child] = {}
        self.worker_pid = prewarm_process_pid
        self.from_prewarm_death_notify = from_prewarm_death_notify
        self.write_to_process_fd = to_prewarm_stdin
        self.read_from_process_fd = from_prewarm_stdout
        self.poll = select.poll()
        self.poll.register(self.read_from_process_fd, select.POLLIN)
        self.unix_socket_name = unix_socket_name

    def socket_env_var(self) -> str:
        """Return 'euid:egid:socket_name' describing the prewarm socket."""
        return f'{os.geteuid()}:{os.getegid()}:{self.unix_socket_name}'

    def take_from_worker_fd(self, create_file: bool = False) -> int:
        """Give up ownership of the death notification fd.

        With create_file the fd is made blocking and wrapped in a text file
        stored as self.from_worker, and -1 is returned. Otherwise the raw fd
        is returned. In both cases this object no longer tracks the fd.
        """
        if create_file:
            os.set_blocking(self.from_prewarm_death_notify, True)
            self.from_worker = open(self.from_prewarm_death_notify, mode='r', closefd=True)
            self.from_prewarm_death_notify = -1
            return -1
        ans, self.from_prewarm_death_notify = self.from_prewarm_death_notify, -1
        return ans

    def __del__(self) -> None:
        """Close all fds still owned and reap the prewarm process, killing it if needed."""
        if self.write_to_process_fd > -1:
            safe_close(self.write_to_process_fd)
            self.write_to_process_fd = -1
        if self.from_prewarm_death_notify > -1:
            safe_close(self.from_prewarm_death_notify)
            self.from_prewarm_death_notify = -1
        if self.read_from_process_fd > -1:
            safe_close(self.read_from_process_fd)
            self.read_from_process_fd = -1
        if hasattr(self, 'from_worker'):
            self.from_worker.close()
            del self.from_worker
        if self.worker_pid > 0:
            if wait_for_child_death(self.worker_pid) is None:
                log_error('Prewarm process failed to quit gracefully, killing it')
                os.kill(self.worker_pid, signal.SIGKILL)
                os.waitpid(self.worker_pid, 0)

    def poll_to_send(self, yes: bool = True) -> None:
        """Register (yes=True) or unregister the write fd for POLLOUT events."""
        if yes:
            self.poll.register(self.write_to_process_fd, select.POLLOUT)
        else:
            self.poll.unregister(self.write_to_process_fd)

    def reload_kitty_config(self, opts: Optional[Options] = None) -> None:
        """Tell the prewarm process to reload its config from the paths and overrides in opts.

        When opts is None the current global options are used. Nothing is
        sent if the write fd has already been closed.
        """
        if opts is None:
            opts = get_options()
        data = json.dumps({'paths': opts.config_paths, 'overrides': opts.config_overrides})
        if self.write_to_process_fd > -1:
            self.send_to_prewarm_process(f'reload_kitty_config:{data}\n')

    def __call__(
        self,
        tty_fd: int,
        argv: List[str],
        cwd: str = '',
        env: Optional[Dict[str, str]] = None,
        stdin_data: Optional[Union[str, bytes]] = None,
        timeout: float = TIMEOUT,
    ) -> Child:
        """Ask the prewarm process to fork a child and wait for its reply.

        The launch description (tty name, cwd, argv, env and optional stdin
        data) is passed via a shared memory segment whose name is sent in a
        'fork:' command. Returns the new Child on a 'CHILD:' reply.

        Raises PrewarmProcessFailed on an 'ERR:' reply, on a poll error
        event, or if no reply arrives within timeout seconds.
        """
        tty_name = os.ttyname(tty_fd)
        if isinstance(stdin_data, str):
            stdin_data = stdin_data.encode()
        if env is None:
            env = dict(os.environ)
        cmd: Dict[str, Union[int, List[str], str, Dict[str, str]]] = {
            'tty_name': tty_name, 'cwd': cwd or os.getcwd(), 'argv': argv, 'env': env,
        }
        total_size = 0
        if stdin_data is not None:
            cmd['stdin_size'] = len(stdin_data)
            total_size += len(stdin_data)
        data = json.dumps(cmd).encode()
        total_size += len(data) + SharedMemory.num_bytes_for_size
        with SharedMemory(size=total_size, unlink_on_exit=True) as shm:
            shm.write_data_with_size(data)
            if stdin_data:
                shm.write(stdin_data)
            shm.flush()
            self.send_to_prewarm_process(f'fork:{shm.name}\n')
            input_buf = b''
            st = time.monotonic()
            while time.monotonic() - st < timeout:
                for (fd, event) in self.poll.poll(2):
                    if event & error_events:
                        raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
                    if fd == self.read_from_process_fd and event & select.POLLIN:
                        d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
                        input_buf += d
                        while (idx := input_buf.find(b'\n')) > -1:
                            line = input_buf[:idx].decode()
                            input_buf = input_buf[idx+1:]
                            if line.startswith('CHILD:'):
                                _, cid, pid = line.split(':')
                                child = self.add_child(int(cid), int(pid))
                                # on success the segment is not unlinked by this process
                                shm.unlink_on_exit = False
                                return child
                            if line.startswith('ERR:'):
                                raise PrewarmProcessFailed(line.split(':', 1)[-1])
        raise PrewarmProcessFailed('Timed out waiting for I/O with prewarm process')

    def add_child(self, child_id: int, pid: int) -> Child:
        """Record a newly forked child and return it."""
        self.children[child_id] = c = Child(child_id, pid)
        return c

    def send_to_prewarm_process(self, output_buf: Union[str, bytes] = b'', timeout: float = TIMEOUT) -> None:
        """Write output_buf to the prewarm process, waiting at most timeout seconds.

        Raises PrewarmProcessFailed on a poll error event or if the data
        could not be fully written before the timeout. Sending empty data is
        a no-op.
        """
        if isinstance(output_buf, str):
            output_buf = output_buf.encode()
        if not output_buf:
            # Nothing to send. The write fd was never registered, so
            # unregistering it here would raise KeyError.
            return
        st = time.monotonic()
        self.poll_to_send(True)
        try:
            while time.monotonic() - st < timeout and output_buf:
                for (fd, event) in self.poll.poll(2):
                    if event & error_events:
                        raise PrewarmProcessFailed(f'Failed doing I/O with prewarm process: {event}')
                    if fd == self.write_to_process_fd and event & select.POLLOUT:
                        n = os.write(self.write_to_process_fd, output_buf)
                        output_buf = output_buf[n:]
        finally:
            # Always stop polling for writability, even on failure, so later
            # polls for replies are not woken by POLLOUT on the write fd.
            self.poll_to_send(False)
        if output_buf:
            raise PrewarmProcessFailed('Timed out waiting to write to prewarm process')

    def mark_child_as_ready(self, child_id: int) -> bool:
        """Send 'ready:<child_id>' for a pending child.

        Returns False, without sending anything, if child_id is not pending.
        """
        c = self.children.pop(child_id, None)
        if c is None:
            return False
        self.send_to_prewarm_process(f'ready:{child_id}\n')
        return True
def reload_kitty_config(payload: str) -> None:
    """Apply the config paths and overrides encoded as JSON in payload."""
    spec = json.loads(payload)
    from kittens.tui.utils import set_kitty_opts
    conf_paths, conf_overrides = spec['paths'], spec['overrides']
    set_kitty_opts(paths=conf_paths, overrides=conf_overrides)
def prewarm() -> None:
    """Import the main module of every kitten, plus kitty.complete, ahead of time."""
    from kittens.runner import all_kitten_names
    for kitten_name in all_kitten_names():
        try:
            import_module(f'kittens.{kitten_name}.main')
        except Exception:
            # best effort: a kitten that fails to import is simply skipped
            pass
    import_module('kitty.complete')
class MemoryViewReadWrapperBytes(io.BufferedIOBase):
    """Read-only binary stream that serves bytes out of a memoryview."""

    def __init__(self, mw: memoryview):
        self.mw = mw
        # offset of the next unread byte in self.mw
        self.pos = 0

    def detach(self) -> io.RawIOBase:
        raise io.UnsupportedOperation('detach() not supported')

    def read(self, size: Optional[int] = -1) -> bytes:
        """Return up to size bytes, or everything remaining if size is None or negative."""
        total = len(self.mw)
        start = self.pos
        if size is None or size < 0:
            size = max(0, total - start)
        end = min(total, start + size)
        self.pos = end
        if end <= start:
            return b''
        return bytes(self.mw[start:end])

    def readinto(self, b: 'WriteableBuffer') -> int:
        """Fill b with as many bytes as are available and return the number written."""
        target = b if isinstance(b, memoryview) else memoryview(b)
        target = target.cast('B')
        chunk = self.read(len(target))
        nread = len(chunk)
        target[:nread] = chunk
        return nread

    readinto1 = readinto

    def readall(self) -> bytes:
        return self.read()

    def write(self, b: 'ReadableBuffer') -> int:
        raise io.UnsupportedOperation('readonly stream')

    def readable(self) -> bool:
        return True
class MemoryViewReadWrapper(io.TextIOWrapper):
    """Text stream decoding a memoryview as UTF-8, replacing undecodable bytes."""

    def __init__(self, mw: memoryview):
        raw_stream = MemoryViewReadWrapperBytes(mw)
        super().__init__(cast(IO[bytes], raw_stream), encoding='utf-8', errors='replace')
# Name of the tty attached to stdout, filled in by main(); debug() writes to it
parent_tty_name = ''
# True in the zygote process; set to False in the processes forked from it
is_zygote = True
def debug(*a: Any) -> None:
    """Print the arguments to the parent tty, if its name is known."""
    if not parent_tty_name:
        return
    with open(parent_tty_name, 'w') as tty:
        print(*a, file=tty)
def child_main(cmd: Dict[str, Any], ready_fd: int = -1, prewarm_type: str = 'direct') -> NoReturn:
    """Set up the process from cmd and run the kitty entry point; never returns.

    cmd may carry 'cwd', 'env' and 'argv'. When ready_fd is valid, the entry
    point is not started until a poll event arrives on that fd.
    """
    sys.kitty_run_data['prewarmed'] = prewarm_type  # type: ignore
    new_cwd = cmd.get('cwd')
    if new_cwd:
        try:
            os.chdir(new_cwd)
        except OSError:
            # failing to change directory is not fatal
            pass
    new_env = cmd.get('env')
    if new_env is not None:
        os.environ.clear()
        os.environ.update(new_env)
    new_argv = cmd.get('argv')
    if new_argv:
        sys.argv = list(new_argv)
    if ready_fd > -1:
        # block until there is an event on the ready fd
        waiter = select.poll()
        waiter.register(ready_fd, select.POLLIN)
        tuple(waiter.poll())
        safe_close(ready_fd)
    main_entry_point()
    raise SystemExit(0)
def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tuple[int, int]:
    """Fork a child described by the JSON command stored in shared memory.

    In the parent this returns (child_pid, ready_fd_write) once the child has
    finished its early setup. In the child it runs child_main(), which does
    not return normally.

    shm_address: name of the shared memory segment holding the command and,
        optionally, data to be used as the child's stdin.
    free_non_child_resources: called in the child to release resources that
        belong to the parent.

    Raises OSError if the fork itself fails.
    """
    global is_zygote
    sz = pos = 0
    with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
        data = shm.read_data_with_size()
        cmd = json.loads(data)
        sz = cmd.get('stdin_size', 0)
        if sz:
            # stdin data follows the command, keep the segment around so the
            # child can read it
            pos = shm.tell()
            shm.unlink_on_exit = False

    # r/w: used by the parent to wait for the child to finish its setup
    r, w = safe_pipe()
    # ready pipe: the child waits on the read end before running its main
    ready_fd_read, ready_fd_write = safe_pipe()
    try:
        child_pid = safer_fork()
    except OSError:
        safe_close(r)
        safe_close(w)
        safe_close(ready_fd_read)
        safe_close(ready_fd_write)
        if sz:
            # open the preserved segment once more so that it gets unlinked
            with SharedMemory(shm_address, unlink_on_exit=True):
                pass
        raise
    if child_pid:
        # master process
        safe_close(w)
        safe_close(ready_fd_read)
        poll = select.poll()
        poll.register(r, select.POLLIN)
        # block until there is an event on the pipe from the child
        tuple(poll.poll())
        safe_close(r)
        return child_pid, ready_fd_write
    # child process
    is_zygote = False
    restore_python_signal_handlers()
    safe_close(r)
    safe_close(ready_fd_write)
    free_non_child_resources()
    os.setsid()
    tty_name = cmd.get('tty_name')
    if tty_name:
        sys.__stdout__.flush()
        sys.__stderr__.flush()
        establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
    # closing the write end is what wakes up the parent polling on r
    safe_close(w)
    # shm is the (already exited) context manager from above, its
    # unlink_on_exit flag records whether stdin data is present
    if shm.unlink_on_exit:
        child_main(cmd, ready_fd_read)
    else:
        with SharedMemory(shm_address, unlink_on_exit=True) as shm:
            stdin_data = memoryview(shm.mmap)[pos:pos + sz]
            if stdin_data:
                sys.stdin = MemoryViewReadWrapper(stdin_data)
            try:
                child_main(cmd, ready_fd_read)
            finally:
                stdin_data.release()
                sys.stdin = sys.__stdin__
    return 0, -1 # type: ignore
def verify_socket_creds(conn: socket.socket) -> bool:
    """Return True if the peer of conn has our effective uid and gid."""
    # needed as abstract unix sockets used on Linux have no permissions and
    # older BSDs ignore socket file permissions
    peer_uid, peer_gid = getpeereid(conn.fileno())
    if peer_uid != os.geteuid():
        return False
    return peer_gid == os.getegid()
class SocketChildData:
    """Launch parameters collected for a child requested over the unix socket."""

    def __init__(self) -> None:
        self.cwd = ''
        self.tty_name = ''
        self.argv: List[str] = []
        self.env: Dict[str, str] = {}
Funtion = TypeVar('Funtion', bound=Callable[..., Any])


def eintr_retry(func: Funtion) -> Funtion:
    """Wrap func so that it is called again whenever it raises InterruptedError."""

    def ret(*a: Any, **kw: Any) -> Any:
        while True:
            try:
                return func(*a, **kw)
            except InterruptedError:
                # interrupted by a signal, try again
                continue

    return cast(Funtion, ret)
# Versions of common system call wrappers that are retried when they raise
# InterruptedError
safe_close = eintr_retry(os.close)
safe_open = eintr_retry(os.open)
safe_ioctl = eintr_retry(fcntl.ioctl)
safe_dup2 = eintr_retry(os.dup2)
def establish_controlling_tty(fd_or_tty_name: Union[str, int], *dups: int, closefd: bool = True) -> int:
    """Open the named tty, issue TIOCSCTTY on it and duplicate it onto each fd in dups.

    The tty can be given by name or as an fd that refers to it. Returns -1
    when closefd is True, as the tty fd is closed again, otherwise the open
    tty fd.
    """
    if isinstance(fd_or_tty_name, int):
        tty_name = os.ttyname(fd_or_tty_name)
    else:
        tty_name = fd_or_tty_name
    raw_fd = safe_open(tty_name, os.O_RDWR | os.O_CLOEXEC)
    with open(raw_fd, 'w', closefd=closefd) as tty_file:
        tty_fd = tty_file.fileno()
        safe_ioctl(tty_fd, termios.TIOCSCTTY, 0)
        for target_fd in dups:
            safe_dup2(tty_fd, target_fd)
    if closefd:
        return -1
    return tty_fd
# Signals that are ignored in the socket child supervisor and reset to their
# default dispositions in the child it launches
interactive_and_job_control_signals = (
    signal.SIGINT, signal.SIGQUIT, signal.SIGTSTP, signal.SIGTTIN, signal.SIGTTOU
)
def fork_socket_child(child_data: SocketChildData, tty_fd: int, stdio_fds: Dict[str, int], free_non_child_resources: Callable[[], None]) -> int:
    """Fork the actual child for a launch request received over the socket.

    Returns the child's pid in the calling process. In the child, job control
    and the standard streams are set up and child_main() is run, which does
    not return.

    stdio_fds maps 'stdin', 'stdout' and 'stderr' to the fd to use for that
    stream, with -1 meaning that tty_fd is used instead.
    """
    # see https://www.gnu.org/software/libc/manual/html_node/Launching-Jobs.html
    child_pid = safer_fork()
    if child_pid:
        return child_pid
    # child process
    # move into a new process group and make it the foreground group of the tty
    eintr_retry(os.setpgid)(0, 0)
    eintr_retry(os.tcsetpgrp)(tty_fd, eintr_retry(os.getpgid)(0))
    for x in interactive_and_job_control_signals:
        signal.signal(x, signal.SIG_DFL)
    restore_python_signal_handlers()
    # the std streams fds are closed in free_non_child_resources()
    for which in ('stdin', 'stdout', 'stderr'):
        fd = stdio_fds[which] if stdio_fds[which] > -1 else tty_fd
        safe_dup2(fd, getattr(sys, which).fileno())
    free_non_child_resources()
    child_main({'cwd': child_data.cwd, 'env': child_data.env, 'argv': child_data.argv}, prewarm_type='socket')
def fork_socket_child_supervisor(conn: socket.socket, free_non_child_resources: Callable[[], None]) -> None:
    """Fork a supervisor process that serves a single socket client.

    In the calling process this closes conn and returns. The forked
    supervisor reads a launch message from conn, launches the requested child
    on the requested tty, reports the child's pid and wait statuses back over
    conn, applies window size records sent by the client, and finally exits
    via SystemExit.
    """
    import array
    global is_zygote
    if safer_fork():
        conn.close()
        return
    is_zygote = False
    os.setsid()
    restore_python_signal_handlers()
    free_non_child_resources()
    signal_read_fd = install_signal_handlers(signal.SIGCHLD, signal.SIGUSR1)[0]
    # See https://www.gnu.org/software/libc/manual/html_node/Initializing-the-Shell.html
    for x in interactive_and_job_control_signals:
        signal.signal(x, signal.SIG_IGN)
    poll = select.poll()
    poll.register(signal_read_fd, select.POLLIN)
    from_socket_buf = b''
    to_socket_buf = b''
    keep_going = True
    child_pid = -1
    socket_fd = conn.fileno()
    launch_msg_read = False
    os.set_blocking(socket_fd, False)
    # fds received as ancillary data, in the order they arrived
    received_fds: List[int] = []
    # index into received_fds of the fd to use for each std stream, -1 for none
    stdio_positions = dict.fromkeys(('stdin', 'stdout', 'stderr'), -1)
    stdio_fds = dict.fromkeys(('stdin', 'stdout', 'stderr'), -1)
    # size in bytes of one window size record, can be changed by the launch message
    winsize = 8
    exit_after_write = False
    child_data = SocketChildData()

    def handle_signal(siginfo: SignalInfo) -> None:
        # Reap children and queue the wait status of the launched child for
        # sending to the client
        nonlocal to_socket_buf, exit_after_write, child_pid
        if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED, CLD_STOPPED):
            return
        while True:
            try:
                pid, status = os.waitpid(-1, os.WNOHANG | os.WUNTRACED)
            except ChildProcessError:
                pid = 0
            if not pid:
                break
            if pid != child_pid:
                continue
            to_socket_buf += struct.pack('q', status)
            if not os.WIFSTOPPED(status):
                # the child is gone, exit once its status has been sent
                exit_after_write = True
                child_pid = -1

    def write_to_socket() -> None:
        # Write as much of to_socket_buf as possible. A failed or zero length
        # write stops the main loop.
        # Note that keep_going is listed twice below, which is harmless.
        nonlocal keep_going, to_socket_buf, keep_going
        buf = memoryview(to_socket_buf)
        while buf:
            try:
                n = os.write(socket_fd, buf)
            except OSError:
                n = 0
            if n == 0:
                keep_going = False
                return
            buf = buf[n:]
        to_socket_buf = bytes(buf)
        if exit_after_write and not to_socket_buf:
            keep_going = False

    def read_winsize() -> None:
        # After launch, the client sends fixed size window size records.
        # Only the last complete record received is applied.
        nonlocal from_socket_buf
        msg = conn.recv(io.DEFAULT_BUFFER_SIZE)
        if not msg:
            return
        from_socket_buf += msg
        data = memoryview(from_socket_buf)
        record = memoryview(b'')
        while len(data) >= winsize:
            record, data = data[:winsize], data[winsize:]
        if record:
            try:
                with open(safe_open(os.ctermid(), os.O_RDWR | os.O_CLOEXEC), 'w') as f:
                    safe_ioctl(f.fileno(), termios.TIOCSWINSZ, record)
            except OSError:
                traceback.print_exc()
        # keep any trailing partial record for the next read
        from_socket_buf = bytes(data)

    def read_launch_msg() -> bool:
        # Read and parse NUL terminated 'cmd:payload' records from the
        # client, collecting any fds passed along. Returns True once the
        # 'finish' record has been seen.
        nonlocal keep_going, from_socket_buf, launch_msg_read, winsize
        try:
            msg, ancdata, flags, addr = conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024)
        except OSError as e:
            if e.errno == errno.ENOMEM:
                # macOS does this when no ancillary data is present
                msg, ancdata, flags, addr = conn.recvmsg(io.DEFAULT_BUFFER_SIZE)
            else:
                raise
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
                fds = array.array("i") # Array of ints
                # Append data, ignoring any truncated integers at the end.
                fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
                received_fds.extend(fds)
        if not msg:
            return False
        from_socket_buf += msg
        while (idx := from_socket_buf.find(b'\0')) > -1:
            line = from_socket_buf[:idx].decode('utf-8')
            from_socket_buf = from_socket_buf[idx+1:]
            cmd, _, payload = line.partition(':')
            if cmd == 'finish':
                for x in received_fds:
                    os.set_inheritable(x, True)
                    os.set_blocking(x, True)
                for k, pos in stdio_positions.items():
                    if pos > -1:
                        stdio_fds[k] = received_fds[pos]
                del received_fds[:]
                return True
            elif cmd == 'cwd':
                child_data.cwd = payload
            elif cmd == 'env':
                k, _, v = payload.partition('=')
                child_data.env[k] = v
            elif cmd == 'argv':
                child_data.argv.append(payload)
            elif cmd in stdio_positions:
                stdio_positions[cmd] = int(payload)
            elif cmd == 'tty_name':
                child_data.tty_name = payload
            elif cmd == 'winsize':
                winsize = int(payload)
        return False

    def free_non_child_resources2() -> None:
        # Called in the launched child to close everything that belongs to
        # the supervisor
        for fd in received_fds:
            safe_close(fd)
        for k, v in tuple(stdio_fds.items()):
            if v > -1:
                safe_close(v)
                stdio_fds[k] = -1
        conn.close()

    def launch_child() -> None:
        # Fork the child on the requested tty and queue its pid for sending
        # to the client
        nonlocal to_socket_buf, child_pid
        sys.__stdout__.flush()
        sys.__stderr__.flush()
        tty_fd = establish_controlling_tty(child_data.tty_name, closefd=False)
        child_pid = fork_socket_child(child_data, tty_fd, stdio_fds, free_non_child_resources2)
        if child_pid:
            # this is also done in the child process, but we don't
            # know when, so do it here as well
            eintr_retry(os.setpgid)(child_pid, child_pid)
            eintr_retry(os.tcsetpgrp)(tty_fd, child_pid)
            for fd in stdio_fds.values():
                if fd > -1:
                    safe_close(fd)
            safe_close(tty_fd)
        else:
            raise SystemExit('fork_socket_child() returned in the child process')
        to_socket_buf += struct.pack('q', child_pid)

    def read_from_socket() -> None:
        # Before launch the socket carries the launch message, afterwards
        # window size records
        nonlocal launch_msg_read
        if launch_msg_read:
            read_winsize()
        else:
            if read_launch_msg():
                launch_msg_read = True
                launch_child()

    try:
        while keep_going:
            # only ask for writability while there is something to send
            poll.register(socket_fd, select.POLLIN | (select.POLLOUT if to_socket_buf else 0))
            for fd, event in poll.poll():
                if event & error_events:
                    keep_going = False
                    break
                if fd == socket_fd:
                    if event & select.POLLOUT:
                        write_to_socket()
                    if event & select.POLLIN:
                        read_from_socket()
                elif fd == signal_read_fd and event & select.POLLIN:
                    read_signals(signal_read_fd, handle_signal)
    finally:
        if child_pid: # supervisor process
            with suppress(OSError):
                conn.shutdown(socket.SHUT_RDWR)
        with suppress(OSError):
            conn.close()

    raise SystemExit(0)
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
    """Run the event loop of the prewarm zygote process.

    stdin_fd: newline terminated commands are read from this fd. The commands
        handled are reload_kitty_config, ready, quit, fork and echo.
    stdout_fd: replies to commands are written to this fd.
    notify_child_death_fd: the pid of every forked child that is reaped is
        written to this fd, one per line.
    unix_socket: socket on which launch requests from other processes are
        accepted.
    """
    global parent_tty_name
    with suppress(OSError):
        parent_tty_name = os.ttyname(sys.stdout.fileno())
    os.set_blocking(notify_child_death_fd, False)
    os.set_blocking(stdin_fd, False)
    os.set_blocking(stdout_fd, False)
    signal_read_fd = install_signal_handlers(signal.SIGCHLD, signal.SIGUSR1)[0]
    os.set_blocking(unix_socket.fileno(), False)
    unix_socket.listen(5)
    poll = select.poll()
    poll.register(stdin_fd, select.POLLIN)
    poll.register(signal_read_fd, select.POLLIN)
    poll.register(unix_socket.fileno(), select.POLLIN)
    input_buf = output_buf = child_death_buf = b''
    # child id -> write end of the pipe the child waits on before starting
    child_ready_fds: Dict[int, int] = {}
    # pid -> child id
    child_pid_map: Dict[int, int] = {}
    child_id_counter = count()
    # runpy issues a warning when running modules that have already been
    # imported. Ignore it.
    warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
    prewarm()

    def get_all_non_child_fds() -> Iterator[int]:
        yield notify_child_death_fd
        yield stdin_fd
        yield stdout_fd
        # the signal fds are closed by remove_signal_handlers()
        yield from child_ready_fds.values()

    def free_non_child_resources() -> None:
        # Called in forked children to close the fds and socket that belong
        # to the zygote
        for fd in get_all_non_child_fds():
            if fd > -1:
                safe_close(fd)
        unix_socket.close()

    def check_event(event: int, err_msg: str) -> None:
        # Exit cleanly when the other end hangs up, and with an error on any
        # other poll error event
        if event & select.POLLHUP:
            raise SystemExit(0)
        if event & error_events:
            print_error(err_msg)
            raise SystemExit(1)

    def handle_input(event: int) -> None:
        # Read and execute the commands available on stdin_fd
        nonlocal input_buf, output_buf
        check_event(event, 'Polling of input pipe failed')
        if not (event & select.POLLIN):
            return
        d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE)
        if not d:
            raise SystemExit(0)
        input_buf += d
        while (idx := input_buf.find(b'\n')) > -1:
            line = input_buf[:idx].decode()
            input_buf = input_buf[idx+1:]
            cmd, _, payload = line.partition(':')
            if cmd == 'reload_kitty_config':
                reload_kitty_config(payload)
            elif cmd == 'ready':
                # closing the write end of the ready pipe releases the child
                child_id = int(payload)
                cfd = child_ready_fds.pop(child_id, None)
                if cfd is not None:
                    safe_close(cfd)
            elif cmd == 'quit':
                raise SystemExit(0)
            elif cmd == 'fork':
                try:
                    child_pid, ready_fd_write = fork(payload, free_non_child_resources)
                except Exception as e:
                    es = str(e).replace('\n', ' ')
                    output_buf += f'ERR:{es}\n'.encode()
                else:
                    if is_zygote:
                        child_id = next(child_id_counter)
                        child_pid_map[child_pid] = child_id
                        child_ready_fds[child_id] = ready_fd_write
                        output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
            elif cmd == 'echo':
                output_buf += f'{payload}\n'.encode()

    def handle_output(event: int) -> None:
        # Write pending replies and stop polling stdout_fd once all are sent
        nonlocal output_buf
        check_event(event, 'Polling of output pipe failed')
        if not (event & select.POLLOUT):
            return
        if output_buf:
            n = os.write(stdout_fd, output_buf)
            if not n:
                raise SystemExit(0)
            output_buf = output_buf[n:]
        if not output_buf:
            poll.unregister(stdout_fd)

    def handle_notify_child_death(event: int) -> None:
        # Write pending child death notifications and stop polling the fd
        # once all are sent
        nonlocal child_death_buf
        check_event(event, 'Polling of notify child death pipe failed')
        if not (event & select.POLLOUT):
            return
        if child_death_buf:
            n = os.write(notify_child_death_fd, child_death_buf)
            if not n:
                raise SystemExit(0)
            child_death_buf = child_death_buf[n:]
        if not child_death_buf:
            poll.unregister(notify_child_death_fd)

    def handle_child_death(dead_child_id: int, dead_child_pid: int) -> None:
        # Forget the child's ready fd and queue its pid for notification
        nonlocal child_death_buf
        xfd = child_ready_fds.pop(dead_child_id, None)
        if xfd is not None:
            safe_close(xfd)
        child_death_buf += f'{dead_child_pid}\n'.encode()

    def handle_signals(event: int) -> None:
        check_event(event, 'Polling of signal pipe failed')
        if not event & select.POLLIN:
            return

        def handle_signal(siginfo: SignalInfo) -> None:
            if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED, CLD_STOPPED):
                return
            # reap every child that has changed state, reporting the ones
            # that were created by the fork command
            while True:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG | os.WUNTRACED)
                except ChildProcessError:
                    pid = 0
                if not pid:
                    break
                child_id = child_pid_map.pop(pid, None)
                if child_id is not None:
                    handle_child_death(child_id, pid)

        read_signals(signal_read_fd, handle_signal)

    def handle_socket_client(event: int) -> None:
        # Accept a connection and, if the peer's credentials match ours, hand
        # it over to a newly forked supervisor
        check_event(event, 'UNIX socket fd listener failed')
        conn, addr = unix_socket.accept()
        if not verify_socket_creds(conn):
            print_error('Connection attempted with invalid credentials ignoring')
            conn.close()
            return
        fork_socket_child_supervisor(conn, free_non_child_resources)

    keep_type_checker_happy = True
    try:
        # is_zygote becomes False in forked children, which must not keep
        # running this loop
        while is_zygote and keep_type_checker_happy:
            # the output fds are only polled while there is data to write
            if output_buf:
                poll.register(stdout_fd, select.POLLOUT)
            if child_death_buf:
                poll.register(notify_child_death_fd, select.POLLOUT)
            for (q, event) in poll.poll():
                if q == stdin_fd:
                    handle_input(event)
                elif q == stdout_fd:
                    handle_output(event)
                elif q == signal_read_fd:
                    handle_signals(event)
                elif q == notify_child_death_fd:
                    handle_notify_child_death(event)
                elif q == unix_socket.fileno():
                    handle_socket_client(event)
    except (KeyboardInterrupt, EOFError, BrokenPipeError):
        if is_zygote:
            raise SystemExit(1)
        raise
    except Exception:
        if is_zygote:
            traceback.print_exc()
        raise
    finally:
        if is_zygote:
            restore_python_signal_handlers()
            for fmd in child_ready_fds.values():
                with suppress(OSError):
                    safe_close(fmd)
def get_socket_name(unix_socket: socket.socket) -> str:
    """Return the name the socket is bound to, with a leading NUL byte shown as '@'."""
    raw_name = unix_socket.getsockname()
    name = raw_name.decode('utf-8') if isinstance(raw_name, bytes) else raw_name
    assert isinstance(name, str)
    if not name.startswith('\0'):
        return name
    return '@' + name[1:]
def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_socket: Optional[socket.socket] = None) -> None:
    """Entry point of the prewarm process: prepare the process and run main().

    The three fds are the prewarm process's ends of the command, reply and
    death notification pipes. When no unix_socket is passed in, a new one is
    created and its name is written, as a single line, to stdout_write before
    the main loop starts.
    """
    os.setsid()
    # the pipes to the parent must not leak into processes that are exec'ed
    os.set_inheritable(stdin_read, False)
    os.set_inheritable(stdout_write, False)
    os.set_inheritable(death_notify_write, False)
    running_in_kitty(False)
    if unix_socket is None:
        unix_socket = random_unix_socket()
        os.write(stdout_write, f'{get_socket_name(unix_socket)}\n'.encode('utf-8'))
    if not sys.stdout.line_buffering: # happens if the parent kitty instance has stdout not pointing to a terminal
        sys.stdout.reconfigure(line_buffering=True) # type: ignore
    try:
        main(stdin_read, stdout_write, death_notify_write, unix_socket)
    finally:
        set_options(None)
        # forked children have already closed the socket
        if is_zygote:
            unix_socket.close()
def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[PrewarmProcess]:
    """Start the prewarm process and return a PrewarmProcess for talking to it.

    With use_exec a fresh kitty executable is started that runs exec_main()
    and reports the name of its socket over the reply pipe. Otherwise the
    current process is forked and the child runs exec_main() directly with
    opts as its options.
    """
    stdin_read, stdin_write = safe_pipe()
    stdout_read, stdout_write = safe_pipe()
    death_notify_read, death_notify_write = safe_pipe()
    if use_exec:
        import subprocess
        tp = subprocess.Popen(
            [kitty_exe(), '+runpy', f'from kitty.prewarm import exec_main; exec_main({stdin_read}, {stdout_write}, {death_notify_write})'],
            pass_fds=(stdin_read, stdout_write, death_notify_write))
        child_pid = tp.pid
        tp.returncode = 0 # prevent a warning when the popen object is deleted with the process still running
        # the first line written by the new process is the name of its socket
        os.set_blocking(stdout_read, True)
        with open(stdout_read, 'rb', closefd=False) as f:
            socket_name = f.readline().decode('utf-8').rstrip()
        os.set_blocking(stdout_read, False)
    else:
        unix_socket = random_unix_socket()
        socket_name = get_socket_name(unix_socket)
        child_pid = safer_fork()
    if child_pid:
        # master
        if not use_exec:
            unix_socket.close()
        safe_close(stdin_read)
        safe_close(stdout_write)
        safe_close(death_notify_write)
        p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read, socket_name)
        if use_exec:
            # the exec'ed process did not inherit our options, so send them
            p.reload_kitty_config()
        return p
    # child
    # only reachable without use_exec, so unix_socket is always defined here
    set_use_os_log(False)
    safe_close(stdin_write)
    safe_close(stdout_read)
    safe_close(death_notify_read)
    set_options(opts)
    exec_main(stdin_read, stdout_write, death_notify_write, unix_socket)
    raise SystemExit(0)