From 7b7f1ecc54f4f97a0d281aed40a3ebbff53ecac9 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 5 Jul 2022 12:49:22 +0530 Subject: [PATCH] Add basic tests for socket prewarm --- kitty/child.py | 2 +- kitty/prewarm.py | 74 ++++++++++++++++++++++++++--------------- kitty_tests/__init__.py | 4 +-- kitty_tests/prewarm.py | 49 ++++++++++++++++++++++++++- 4 files changed, 98 insertions(+), 31 deletions(-) diff --git a/kitty/child.py b/kitty/child.py index d8c0162a7..dbfc5f221 100644 --- a/kitty/child.py +++ b/kitty/child.py @@ -251,7 +251,7 @@ class Child: env['COLORTERM'] = 'truecolor' env['KITTY_PID'] = getpid() if not self.is_prewarmed: - env['KITTY_PREWARM_SOCKET'] = f'{os.geteuid()}:{os.getegid()}:{fast_data_types.get_boss().prewarm.unix_socket_name}' + env['KITTY_PREWARM_SOCKET'] = fast_data_types.get_boss().prewarm.socket_env_var() if self.cwd: # needed in case cwd is a symlink, in which case shells # can use it to display the current directory name rather diff --git a/kitty/prewarm.py b/kitty/prewarm.py index db98f3468..fdaa11388 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from importlib import import_module from itertools import count from typing import ( - IO, TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, NoReturn, Optional, + IO, TYPE_CHECKING, Any, Callable, Dict, Iterator, List, NoReturn, Optional, Tuple, Union, cast ) @@ -85,6 +85,9 @@ class PrewarmProcess: self.poll.register(self.read_from_process_fd, select.POLLIN) self.unix_socket_name = unix_socket_name + def socket_env_var(self) -> str: + return f'{os.geteuid()}:{os.getegid()}:{self.unix_socket_name}' + def take_from_worker_fd(self, create_file: bool = False) -> int: if create_file: os.set_blocking(self.from_prewarm_death_notify, True) @@ -293,7 +296,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn: raise SystemExit(0) -def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: +def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tuple[int, int]: global is_zygote sz = pos = 0 with SharedMemory(name=shm_address, unlink_on_exit=True) as shm: @@ -331,9 +334,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: remove_signal_handlers() os.close(r) os.close(ready_fd_write) - for fd in all_non_child_fds: - if fd > -1: - os.close(fd) + free_non_child_resources() os.setsid() tty_name = cmd.get('tty_name') if tty_name: @@ -363,12 +364,11 @@ class SocketClosed(Exception): class SocketChild: def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll): - self.fd = conn.fileno() - poll.register(self.fd, select.POLLIN) self.registered = True self.poll = poll self.addr = addr self.conn = conn + self.poll.register(self.conn.fileno(), select.POLLIN) self.input_buf = self.output_buf = b'' self.fds: List[int] = [] self.child_id = -1 @@ -377,10 +377,13 @@ class SocketChild: self.argv: List[str] = [] self.stdin = self.stdout = self.stderr = -1 self.pid = -1 + self.closed = False def unregister_from_poll(self) -> None: if self.registered: - self.poll.unregister(self.fd) + fd = self.conn.fileno() + if fd > -1: + self.poll.unregister(self.conn.fileno()) self.registered = False def read(self) -> bool: @@ -426,7 +429,7 @@ class SocketChild: return False - def fork(self, all_non_child_fds: Iterable[int]) -> None: + def fork(self, free_non_child_resources: Callable[[], None]) -> None: global is_zygote r, w = safe_pipe() self.pid = os.fork() @@ -470,9 +473,7 @@ class SocketChild: if self.stderr > -1: os.dup2(self.stderr, sys.__stderr__.fileno()) os.close(w) - for fd in all_non_child_fds: - if fd > -1: - os.close(fd) + free_non_child_resources() child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv}) raise SystemExit(0) @@ -499,6 +500,23 @@ class SocketChild: return False return True + def close(self) -> None: + if self.closed: + return + self.unregister_from_poll() + self.closed = True + self.conn.close() + if self.stdin > -1: + os.close(self.stdin) + self.stdin = -1 + if self.stdout > -1: + os.close(self.stdout) + self.stdout = -1 + if self.stderr > -1: + os.close(self.stderr) + self.stderr = -1 + __del__ = close + def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None: global parent_tty_name @@ -525,10 +543,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: prewarm() def remove_socket_child(sc: SocketChild) -> None: - socket_children.pop(sc.fd, None) - sc.unregister_from_poll() + socket_children.pop(sc.conn.fileno(), None) socket_pid_map.pop(sc.pid, None) - sc.conn.close() + sc.close() def get_all_non_child_fds() -> Iterator[int]: yield notify_child_death_fd @@ -536,14 +553,14 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: yield stdout_fd # the signal fds are closed by remove_signal_handlers() yield from child_ready_fds.values() - for sc in socket_children.values(): - yield sc.fd - if sc.stdin > -1: - yield sc.stdin - if sc.stdout > -1: - yield sc.stdout - if sc.stderr > -1: - yield sc.stderr + + def free_non_child_resources() -> None: + for fd in get_all_non_child_fds(): + if fd > -1: + os.close(fd) + unix_socket.close() + for sc in tuple(socket_children.values()): + remove_socket_child(sc) def check_event(event: int, err_msg: str) -> None: if event & select.POLLHUP: @@ -576,7 +593,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: raise SystemExit(0) elif cmd == 'fork': try: - child_pid, ready_fd_write = fork(payload, get_all_non_child_fds()) + child_pid, ready_fd_write = fork(payload, free_non_child_resources) except Exception as e: es = str(e).replace('\n', ' ') output_buf += f'ERR:{es}\n'.encode() @@ -654,7 +671,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: check_event(event, 'UNIX socket fd listener failed') conn, addr = unix_socket.accept() sc = SocketChild(conn, addr, poll) - socket_children[sc.fd] = sc + socket_children[sc.conn.fileno()] = sc def handle_socket_launch(fd: int, event: int) -> None: scq = socket_children.get(q) @@ -664,7 +681,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: try: if scq.read(): scq.unregister_from_poll() - scq.fork(get_all_non_child_fds()) + scq.fork(free_non_child_resources) socket_pid_map[scq.pid] = scq scq.child_id = next(child_id_counter) except SocketClosed: @@ -714,6 +731,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: for fmd in child_ready_fds.values(): with suppress(OSError): os.close(fmd) + for sc in tuple(socket_children.values()): + remove_socket_child(sc) def get_socket_name(unix_socket: socket.socket) -> str: @@ -743,7 +762,8 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_ main(stdin_read, stdout_write, death_notify_write, unix_socket) finally: set_options(None) - unix_socket.close() + if is_zygote: + unix_socket.close() def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[PrewarmProcess]: diff --git a/kitty_tests/__init__.py b/kitty_tests/__init__.py index 4096b7ddf..294a7b5dc 100644 --- a/kitty_tests/__init__.py +++ b/kitty_tests/__init__.py @@ -204,8 +204,8 @@ class PTY: self.master_fd, self.slave_fd = openpty() self.is_child = False else: - pid, self.master_fd = fork() - self.is_child = pid == CHILD + self.child_pid, self.master_fd = fork() + self.is_child = self.child_pid == CHILD if self.is_child: while read_screen_size().width != columns * cell_width: time.sleep(0.01) diff --git a/kitty_tests/prewarm.py b/kitty_tests/prewarm.py index 193a14d87..75fcb7c8e 100644 --- a/kitty_tests/prewarm.py +++ b/kitty_tests/prewarm.py @@ -4,13 +4,15 @@ import json import os +import re import select import signal import subprocess import tempfile import time +from contextlib import suppress -from kitty.constants import kitty_exe +from kitty.constants import kitty_exe, read_kitty_resource from kitty.fast_data_types import ( CLD_EXITED, CLD_KILLED, get_options, has_sigqueue, install_signal_handlers, read_signals, remove_signal_handlers, sigqueue @@ -19,10 +21,55 @@ from kitty.fast_data_types import ( from . import BaseTest +def socket_child_main(exit_code=0): + import json + import os + from kitty.fast_data_types import get_options + from kitty.utils import read_screen_size + output = { + 'test_env': os.environ.get('TEST_ENV_PASS', ''), + 'cwd': os.getcwd(), + 'font_family': get_options().font_family, + 'cols': read_screen_size().cols, + + 'done': 'hello', + } + print(json.dumps(output, indent=2)) + raise SystemExit(exit_code) + +# END_socket_child_main + + class Prewarm(BaseTest): maxDiff = None + def test_socket_prewarming(self): + from kitty.prewarm import fork_prewarm_process + exit_code = 17 + src = re.search( + r'^(def socket_child_main.+?)^# END_socket_child_main', read_kitty_resource('prewarm.py', 'kitty_tests').decode(), + flags=re.M | re.DOTALL).group(1) + '\n\n' + + cwd = tempfile.gettempdir() + opts = self.set_options() + opts.config_overrides = 'font_family prewarm', + p = fork_prewarm_process(opts, use_exec=True) + if p is None: + return + env = {'TEST_ENV_PASS': 'xyz', 'KITTY_PREWARM_SOCKET': p.socket_env_var()} + cols = 117 + pty = self.create_pty(argv=[kitty_exe(), '+runpy', src + f'socket_child_main({exit_code})'], cols=cols, env=env, cwd=cwd) + status = os.waitpid(pty.child_pid, 0)[1] + with suppress(AttributeError): + self.assertEqual(os.waitstatus_to_exitcode(status), exit_code) + pty.wait_till(lambda: 'hello' in pty.screen_contents()) + output = json.loads(pty.screen_contents().strip()) + self.assertEqual(output['test_env'], env['TEST_ENV_PASS']) + self.assertEqual(output['cwd'], cwd) + self.assertEqual(output['font_family'], 'prewarm') + self.assertEqual(output['cols'], cols) + def test_prewarming(self): from kitty.prewarm import fork_prewarm_process