From b222ab1bf6c6d3d592ab0fa140fe4e3ecff67c9c Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 13 Jun 2022 21:34:49 +0530 Subject: [PATCH] Start work on socket based prewarm --- kitty/child.c | 10 +- kitty/fast_data_types.pyi | 2 +- kitty/prewarm.py | 203 +++++++++++++++++++++++++++++++++++--- 3 files changed, 195 insertions(+), 20 deletions(-) diff --git a/kitty/child.c b/kitty/child.c index 5d6c71349..0169b0038 100644 --- a/kitty/child.c +++ b/kitty/child.c @@ -186,17 +186,17 @@ spawn(PyObject *self UNUSED, PyObject *args) { static PyObject* establish_controlling_tty(PyObject *self UNUSED, PyObject *args) { const char *ttyname; - int stdin_fd, stdout_fd, stderr_fd; - if (!PyArg_ParseTuple(args, "siii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL; + int stdin_fd = -1, stdout_fd = -1, stderr_fd = -1; + if (!PyArg_ParseTuple(args, "s|iii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL; int tfd = safe_open(ttyname, O_RDWR, 0); if (tfd == -1) return PyErr_SetFromErrnoWithFilename(PyExc_OSError, ttyname); #ifdef TIOCSCTTY // On BSD open() does not establish the controlling terminal if (ioctl(tfd, TIOCSCTTY, 0) == -1) return PyErr_SetFromErrno(PyExc_OSError); #endif - if (dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); - if (dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); - if (dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); + if (stdin_fd > -1 && dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); + if (stdout_fd > -1 && dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); + if (stderr_fd > -1 && dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); safe_close(tfd, __FILE__, __LINE__); Py_RETURN_NONE; } diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index 7646e9321..863f4fa4b 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None: pass -def establish_controlling_tty(ttyname: str, stdin: int, stdout: int, stderr: int) -> None: +def establish_controlling_tty(ttyname: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None: pass diff --git a/kitty/prewarm.py b/kitty/prewarm.py index c796ba1ca..1e86433a7 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -6,6 +6,7 @@ import json import os import select import signal +import socket import sys import time import warnings @@ -27,7 +28,7 @@ from kitty.fast_data_types import ( from kitty.options.types import Options from kitty.shm import SharedMemory from kitty.types import SignalInfo -from kitty.utils import log_error +from kitty.utils import log_error, random_unix_socket if TYPE_CHECKING: from _typeshed import ReadableBuffer, WriteableBuffer @@ -69,6 +70,7 @@ class PrewarmProcess: to_prewarm_stdin: int, from_prewarm_stdout: int, from_prewarm_death_notify: int, + unix_socket_name: str, ) -> None: self.children: Dict[int, Child] = {} self.worker_pid = prewarm_process_pid @@ -77,6 +79,7 @@ class PrewarmProcess: self.read_from_process_fd = from_prewarm_stdout self.poll = select.poll() self.poll.register(self.read_from_process_fd, select.POLLIN) + self.unix_socket_name = unix_socket_name def take_from_worker_fd(self, create_file: bool = False) -> int: if create_file: @@ -256,7 +259,7 @@ class MemoryViewReadWrapper(io.TextIOWrapper): super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace') -def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: +def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn: cwd = cmd.get('cwd') if cwd: with suppress(OSError): @@ -268,10 +271,11 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: argv = cmd.get('argv') if argv: sys.argv = list(argv) - poll = select.poll() - poll.register(ready_fd, select.POLLIN) - tuple(poll.poll()) - os.close(ready_fd) + if ready_fd > -1: + poll = select.poll() + poll.register(ready_fd, select.POLLIN) + tuple(poll.poll()) + os.close(ready_fd) main_entry_point() raise SystemExit(0) @@ -291,9 +295,14 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: try: child_pid = os.fork() except OSError: + os.close(r) + os.close(w) + os.close(ready_fd_read) + os.close(ready_fd_write) if sz: with SharedMemory(shm_address, unlink_on_exit=True): pass + raise if child_pid: # master process os.close(w) @@ -308,7 +317,8 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: os.close(r) os.close(ready_fd_write) for fd in all_non_child_fds: - os.close(fd) + if fd > -1: + os.close(fd) os.setsid() tty_name = cmd.get('tty_name') if tty_name: @@ -331,17 +341,123 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: sys.stdin = sys.__stdin__ -def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: +class SocketClosed(Exception): + pass + + +class SocketChild: + + def __init__(self, conn: socket.socket, addr: bytes): + self.fd = conn.fileno() + self.addr = addr + self.conn = conn + self.input_buf = self.output_buf = b'' + self.fds: List[int] = [] + self.child_id = -1 + self.cwd = self.tty_name = '' + self.env: Dict[str, str] = {} + self.argv: List[str] = [] + self.stdin = self.stdout = self.stderr = -1 + self.pid = -1 + + def read(self) -> bool: + import array + fds = array.array("i") # Array of ints + maxfds = 3 + msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize)) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + # Append data, ignoring any truncated integers at the end. + fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + self.fds += list(fds) + if not msg: + raise SocketClosed('socket unexpectedly closed') + self.input_buf += msg + while (idx := self.input_buf.find(b'\0')) > -1: + line = self.input_buf[:idx].decode('utf-8') + self.input_buf = self.input_buf[idx+1:] + cmd, _, payload = line.partition(':') + if cmd == 'finish': + if self.stdin > -1: + self.stdin = self.fds[self.stdin] + if self.stdout > -1: + self.stdout = self.fds[self.stdout] + if self.stderr > -1: + self.stderr = self.fds[self.stderr] + return True + elif cmd == 'cwd': + self.cwd = payload + elif cmd == 'tty_name': + self.tty_name = payload + elif cmd == 'env': + k, _, v = payload.partition('=') + self.env[k] = v + elif cmd == 'argv': + self.argv.append(payload) + elif cmd == 'stdin': + self.stdin = int(payload) + elif cmd == 'stdout': + self.stdout = int(payload) + elif cmd == 'stderr': + self.stderr = int(payload) + + return False + + def fork(self, all_non_child_fds: Iterable[int]) -> None: + self.pid = os.fork() + if self.pid > 0: + # master process + if self.stdin > -1: + os.close(self.stdin) + self.stdin = -1 + if self.stdout > -1: + os.close(self.stdout) + self.stdout = -1 + if self.stderr > -1: + os.close(self.stderr) + self.stderr = -1 + return + # child process + os.setsid() + remove_signal_handlers() + if self.tty_name: + sys.__stdout__.flush() + sys.__stderr__.flush() + establish_controlling_tty( + self.tty_name, + sys.__stdin__.fileno() if self.stdin == -1 else -1, + sys.__stdout__.fileno() if self.stdout == -1 else -1, + sys.__stderr__.fileno() if self.stderr == -1 else -1) + # the std streams fds are in all_non_child_fds already + # so they will be closed there + if self.stdin > -1: + os.dup2(self.stdin, sys.__stdin__.fileno()) + if self.stdout > -1: + os.dup2(self.stdout, sys.__stdout__.fileno()) + if self.stderr > -1: + os.dup2(self.stderr, sys.__stderr__.fileno()) + for fd in all_non_child_fds: + if fd > -1: + os.close(fd) + child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv}) + raise SystemExit(0) + + +def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None: os.set_blocking(notify_child_death_fd, False) os.set_blocking(stdin_fd, False) os.set_blocking(stdout_fd, False) signal_read_fd = install_signal_handlers(signal.SIGCHLD, signal.SIGUSR1)[0] + os.set_blocking(unix_socket.fileno(), False) + unix_socket.listen(5) poll = select.poll() poll.register(stdin_fd, select.POLLIN) poll.register(signal_read_fd, select.POLLIN) + poll.register(unix_socket.fileno(), select.POLLIN) input_buf = output_buf = child_death_buf = b'' child_ready_fds: Dict[int, int] = {} child_pid_map: Dict[int, int] = {} + socket_children: Dict[int, SocketChild] = {} child_id_counter = count() self_pid = os.getpid() # runpy issues a warning when running modules that have already been @@ -355,6 +471,14 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: yield stdout_fd # the signal fds are closed by remove_signal_handlers() yield from child_ready_fds.values() + for sc in socket_children.values(): + yield sc.fd + if sc.stdin > -1: + yield sc.stdin + if sc.stdout > -1: + yield sc.stdout + if sc.stderr > -1: + yield sc.stderr def check_event(event: int, err_msg: str) -> None: if event & select.POLLHUP: @@ -448,11 +572,25 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: if not pid: break child_id = child_pid_map.pop(pid, None) - if child_id is not None: + if child_id is None: + for sc in socket_children.values(): + if sc.pid == pid: + sc.conn.sendall(f'{status}'.encode('ascii')) + sc.conn.shutdown(socket.SHUT_RDWR) + sc.conn.close() + break + else: handle_child_death(child_id, pid) read_signals(signal_read_fd, handle_signal) + def handle_socket_client(event: int) -> None: + check_event(event, 'UNIX socket fd listener failed') + conn, addr = unix_socket.accept() + sc = SocketChild(conn, addr) + socket_children[sc.fd] = sc + poll.register(sc.fd, select.POLLIN) + try: while True: if output_buf: @@ -468,6 +606,24 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: handle_signals(event) elif q == notify_child_death_fd: handle_notify_child_death(event) + elif q == unix_socket.fileno(): + handle_socket_client(event) + else: + scq = socket_children.get(q) + if scq is not None: + if event & select.POLLIN: + try: + if scq.read(): + poll.unregister(scq.fd) + scq.conn.shutdown(socket.SHUT_RD) + scq.fork(get_all_non_child_fds()) + scq.child_id = child_pid_map[scq.pid] = next(child_id_counter) + except SocketClosed: + socket_children.pop(q) + continue + if event & error_events: + socket_children.pop(q) + continue except (KeyboardInterrupt, EOFError, BrokenPipeError): if os.getpid() == self_pid: raise SystemExit(1) @@ -485,7 +641,17 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None: os.close(fmd) -def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int) -> None: +def get_socket_name(unix_socket: socket.socket) -> str: + sname = unix_socket.getsockname() + if isinstance(sname, bytes): + sname = sname.decode('utf-8') + assert isinstance(sname, str) + if sname.startswith('\0'): + sname = '@' + sname[1:] + return sname + + +def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_socket: Optional[socket.socket] = None) -> None: os.setsid() # SIGUSR1 is used for reloading kitty config, we rely on the parent process # to inform us of that @@ -495,9 +661,11 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int) -> No os.set_inheritable(stdout_write, False) os.set_inheritable(death_notify_write, False) running_in_kitty(False) - + if unix_socket is None: + unix_socket = random_unix_socket() + os.write(stdout_write, f'{get_socket_name(unix_socket)}\n'.encode('utf-8')) try: - main(stdin_read, stdout_write, death_notify_write) + main(stdin_read, stdout_write, death_notify_write, unix_socket) finally: set_options(None) @@ -513,14 +681,21 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew pass_fds=(stdin_read, stdout_write, death_notify_write)) child_pid = tp.pid tp.returncode = 0 # prevent a warning when the popen object is deleted with the process still running + os.set_blocking(stdout_read, True) + with open(stdout_read, 'rb', closefd=False) as f: + socket_name = f.readline().decode('utf-8').rstrip() + os.set_blocking(stdout_read, False) else: + unix_socket = random_unix_socket() + socket_name = get_socket_name(unix_socket) child_pid = os.fork() if child_pid: # master + unix_socket.close() os.close(stdin_read) os.close(stdout_write) os.close(death_notify_write) - p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read) + p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read, socket_name) if use_exec: p.reload_kitty_config() return p @@ -529,5 +704,5 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew os.close(stdout_read) os.close(death_notify_read) set_options(opts) - exec_main(stdin_read, stdout_write, death_notify_write) + exec_main(stdin_read, stdout_write, death_notify_write, unix_socket) raise SystemExit(0)