From 98f46f8bd7d557a77ec19f573ca854d8e304aeca Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 6 Jun 2022 20:39:36 +0530 Subject: [PATCH] Start work on prewarming The prewarm process and its controller are implemented with some basic tests. --- kittens/prewarm/__init__.py | 0 kittens/prewarm/main.py | 454 ++++++++++++++++++++++++++++++++++++ kitty_tests/__init__.py | 8 +- kitty_tests/prewarm.py | 48 ++++ 4 files changed, 507 insertions(+), 3 deletions(-) create mode 100644 kittens/prewarm/__init__.py create mode 100644 kittens/prewarm/main.py create mode 100644 kitty_tests/prewarm.py diff --git a/kittens/prewarm/__init__.py b/kittens/prewarm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kittens/prewarm/main.py b/kittens/prewarm/main.py new file mode 100644 index 000000000..9bc63b934 --- /dev/null +++ b/kittens/prewarm/main.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python +# License: GPLv3 Copyright: 2022, Kovid Goyal + +import io +import json +import os +import select +import signal +import sys +import time +from contextlib import suppress +from dataclasses import dataclass +from importlib import import_module +from typing import ( + IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast +) + +from kitty.child import remove_cloexec +from kitty.constants import kitty_exe +from kitty.entry_points import main as main_entry_point +from kitty.fast_data_types import establish_controlling_tty, safe_pipe +from kitty.shm import SharedMemory + +if TYPE_CHECKING: + from _typeshed import ReadableBuffer, WriteableBuffer + + +hangup_events = select.POLLHUP +error_events = select.POLLERR | select.POLLNVAL +basic_events = hangup_events | error_events + + +class PrewarmProcessFailed(Exception): + pass + + +@dataclass +class Child: + child_id: int + child_process_pid: int + + +class PrewarmProcess: + + def __init__(self, create_file_to_read_from_worker: bool = False) -> None: + self.from_worker_fd, self.in_worker_fd = safe_pipe() + self.children: Dict[int, Child] = {} + if create_file_to_read_from_worker: + os.set_blocking(self.from_worker_fd, True) + self.from_worker = open(self.from_worker_fd, mode='r', closefd=True) + self.from_worker_fd = -1 + + def take_from_worker_fd(self) -> int: + ans, self.from_worker_fd = self.from_worker_fd, -1 + return ans + + def __del__(self) -> None: + if self.from_worker_fd > -1: + os.close(self.from_worker_fd) + self.from_worker_fd = -1 + if hasattr(self, 'from_worker'): + self.from_worker.close() + del self.from_worker + if self.worker_started: + import subprocess + self.process.stdin and self.process.stdin.close() + self.process.stdout and self.process.stdout.close() + try: + self.process.wait(timeout=1.0) + except subprocess.TimeoutExpired: + self.process.kill() + del self.process + + @property + def worker_started(self) -> bool: + return self.in_worker_fd == -1 + + def ensure_worker(self) -> None: + if not self.worker_started: + import subprocess + self.process = subprocess.Popen( + [kitty_exe(), '+kitten', 'prewarm', str(self.in_worker_fd)], stdin=subprocess.PIPE, stdout=subprocess.PIPE, pass_fds=(self.in_worker_fd,)) + os.close(self.in_worker_fd) + self.in_worker_fd = -1 + assert self.process.stdin is not None and self.process.stdout is not None + self.write_to_process_fd = self.process.stdin.fileno() + self.read_from_process_fd = self.process.stdout.fileno() + os.set_blocking(self.write_to_process_fd, False) + os.set_blocking(self.read_from_process_fd, False) + self.poll = select.poll() + self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events) + + def poll_to_send(self, yes: bool = True) -> None: + if yes: + self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events) + else: + self.poll.unregister(self.write_to_process_fd) + + def reload_kitty_config(self) -> None: + if self.worker_started: + self.send_to_prewarm_process('reload_kitty_config:\n') + + def __call__( + self, + tty_fd: int, + argv: List[str], + cwd: str = '', + env: Optional[Dict[str, str]] = None, + stdin_data: Optional[Union[str, bytes]] = None + ) -> Child: + self.ensure_worker() + tty_name = os.ttyname(tty_fd) + if isinstance(stdin_data, str): + stdin_data = stdin_data.encode() + if env is None: + env = dict(os.environ) + cmd: Dict[str, Union[int, List[str], str, Dict[str, str]]] = { + 'tty_name': tty_name, 'cwd': cwd or os.getcwd(), 'argv': argv, 'env': env, + } + total_size = 0 + if stdin_data is not None: + cmd['stdin_size'] = len(stdin_data) + total_size += len(stdin_data) + data = json.dumps(cmd).encode() + total_size += len(data) + SharedMemory.num_bytes_for_size + with SharedMemory(size=total_size, unlink_on_exit=True) as shm: + shm.write_data_with_size(data) + if stdin_data: + shm.write(stdin_data) + shm.flush() + self.send_to_prewarm_process(f'fork:{shm.name}\n') + input_buf = b'' + st = time.monotonic() + while time.monotonic() - st < 2: + for (fd, event) in self.poll.poll(0.2): + if event & basic_events: + raise PrewarmProcessFailed('Failed doing I/O with prewarm process') + if fd == self.read_from_process_fd and event & select.POLLIN: + d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE) + input_buf += d + while (idx := input_buf.find(b'\n')) > -1: + line = input_buf[:idx].decode() + input_buf = input_buf[idx+1:] + if line.startswith('CHILD:'): + _, cid, pid = line.split(':') + child = self.add_child(int(cid), int(pid)) + shm.unlink_on_exit = False + return child + if line.startswith('ERR:'): + raise PrewarmProcessFailed(line.split(':', 1)[-1]) + raise PrewarmProcessFailed('Timed out waiting for I/O with prewarm process') + + def add_child(self, child_id: int, pid: int) -> Child: + self.children[child_id] = c = Child(child_id, pid) + return c + + def send_to_prewarm_process(self, output_buf: Union[str, bytes] = b'', timeout: float = 2) -> None: + if isinstance(output_buf, str): + output_buf = output_buf.encode() + st = time.monotonic() + while time.monotonic() - st < timeout and output_buf: + self.poll_to_send(bool(output_buf)) + for (fd, event) in self.poll.poll(0.2): + if event & basic_events: + raise PrewarmProcessFailed('Failed doing I/O with prewarm process') + if fd == self.write_to_process_fd and event & select.POLLOUT: + n = os.write(self.write_to_process_fd, output_buf) + output_buf = output_buf[n:] + self.poll_to_send(False) + if output_buf: + raise PrewarmProcessFailed('Timed out waiting to write to prewarm process') + + def mark_child_as_ready(self, child_id: int) -> bool: + c = self.children.pop(child_id, None) + if c is None: + return False + self.send_to_prewarm_process(f'ready:{child_id}\n') + return True + + +def reload_kitty_config() -> None: + from kittens.tui.utils import kitty_opts + kitty_opts.clear_cached() + kitty_opts() + + +def prewarm() -> None: + reload_kitty_config() + for kitten in ('hints', 'ssh', 'unicode_input', 'ask', 'show_error'): + import_module(f'kittens.{kitten}.main') + + +class MemoryViewReadWrapperBytes(io.BufferedIOBase): + + def __init__(self, mw: memoryview): + self.mw = mw + self.pos = 0 + + def detach(self) -> io.RawIOBase: + raise io.UnsupportedOperation('detach() not supported') + + def read(self, size: Optional[int] = -1) -> bytes: + if size is None or size < 0: + size = max(0, len(self.mw) - self.pos) + oldpos = self.pos + self.pos = min(len(self.mw), self.pos + size) + if self.pos <= oldpos: + return b'' + return bytes(self.mw[oldpos:self.pos]) + + def readinto(self, b: 'WriteableBuffer') -> int: + if not isinstance(b, memoryview): + b = memoryview(b) + b = b.cast('B') + data = self.read(len(b)) + n = len(data) + b[:n] = data + return n + readinto1 = readinto + + def readall(self) -> bytes: + return self.read() + + def write(self, b: 'ReadableBuffer') -> int: + raise io.UnsupportedOperation('readonly stream') + + def readable(self) -> bool: + return True + + +class MemoryViewReadWrapper(io.TextIOWrapper): + + def __init__(self, mw: memoryview): + super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace') + + +def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: + cwd = cmd.get('cwd') + if cwd: + try: + os.chdir(cwd) + except OSError: + with suppress(OSError): + os.chdir('/') + os.setsid() + env = cmd.get('env') + if env is not None: + os.environ.clear() + os.environ.update(env) + argv = cmd.get('argv') + if argv: + sys.argv = list(argv) + poll = select.poll() + poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP) + poll.poll() + os.close(ready_fd) + tty_name = cmd.get('tty_name') + if tty_name: + sys.__stdout__.flush() + sys.__stderr__.flush() + establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) + main_entry_point() + raise SystemExit(0) + + +def fork(shm_address: str, ready_fd: int) -> int: + sz = pos = 0 + with SharedMemory(name=shm_address, unlink_on_exit=True) as shm: + data = shm.read_data_with_size() + cmd = json.loads(data) + sz = cmd.get('stdin_size', 0) + if sz: + pos = shm.tell() + shm.unlink_on_exit = False + + try: + child_pid = os.fork() + except OSError: + if sz: + with SharedMemory(shm_address, unlink_on_exit=True): + pass + if child_pid: + # master process + return child_pid + # child process + if shm.unlink_on_exit: + child_main(cmd, ready_fd) + else: + with SharedMemory(shm_address, unlink_on_exit=True) as shm: + stdin_data = memoryview(shm.mmap)[pos:pos + sz] + if stdin_data: + sys.stdin = MemoryViewReadWrapper(stdin_data) + try: + child_main(cmd, ready_fd) + finally: + stdin_data.release() + sys.stdin = sys.__stdin__ + + +def waitstatus_to_exit_code(status: int) -> int: + with suppress(ValueError, AttributeError): + return os.waitstatus_to_exitcode(status) + return 0 + + +def main(args: List[str] = sys.argv) -> None: + read_signal_fd, write_signal_fd = safe_pipe() + notify_child_death_fd = int(sys.argv[-1]) + os.set_blocking(notify_child_death_fd, False) + signal.set_wakeup_fd(write_signal_fd) + signal.signal(signal.SIGCHLD, lambda *a: None) + signal.siginterrupt(signal.SIGCHLD, False) + prewarm() + stdin_fd = sys.__stdin__.fileno() + os.set_blocking(stdin_fd, False) + stdout_fd = sys.__stdout__.fileno() + os.set_blocking(stdout_fd, False) + poll = select.poll() + poll.register(stdin_fd, select.POLLIN | basic_events) + poll.register(read_signal_fd, select.POLLIN | basic_events) + input_buf = output_buf = child_death_buf = b'' + child_ready_fds: Dict[int, int] = {} + child_id_map: Dict[int, int] = {} + self_pid = os.getpid() + + def check_event(event: int, err_msg: str) -> None: + if event & hangup_events: + raise SystemExit(0) + if event & error_events: + raise SystemExit(err_msg) + + def handle_input(event: int) -> None: + nonlocal input_buf, output_buf + check_event(event, 'Polling of STDIN failed') + if not (event & select.POLLIN): + return + d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE) + if not d: + raise SystemExit(0) + input_buf += d + while (idx := input_buf.find(b'\n')) > -1: + line = input_buf[:idx].decode() + input_buf = input_buf[idx+1:] + cmd, _, payload = line.partition(':') + if cmd == 'reload_kitty_config': + reload_kitty_config() + elif cmd == 'ready': + child_id = int(payload) + cfd = child_ready_fds.pop(child_id) + if cfd is not None: + os.write(cfd, b'1') + os.close(cfd) + elif cmd == 'fork': + read_fd, write_fd = safe_pipe(False) + remove_cloexec(read_fd) + try: + child_pid = fork(payload, read_fd) + except Exception as e: + es = str(e).replace('\n', ' ') + output_buf += f'ERR:{es}\n'.encode() + else: + child_id = len(child_id_map) + 1 + child_id_map[child_id] = child_pid + child_ready_fds[child_id] = write_fd + output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode() + finally: + if os.getpid() == self_pid: + os.close(read_fd) + elif cmd == 'echo': + output_buf += f'{payload}\n'.encode() + + def handle_output(event: int) -> None: + nonlocal output_buf + check_event(event, 'Polling of STDOUT failed') + if not (event & select.POLLOUT): + return + if output_buf: + n = os.write(stdout_fd, output_buf) + if not n: + raise SystemExit(0) + output_buf = output_buf[n:] + if not output_buf: + poll.unregister(stdout_fd) + + def handle_notify_child_death(event: int) -> None: + nonlocal child_death_buf + check_event(event, 'Polling of notify child death fd failed') + if not (event & select.POLLOUT): + return + if child_death_buf: + n = os.write(notify_child_death_fd, child_death_buf) + if not n: + raise SystemExit(0) + child_death_buf = child_death_buf[n:] + if not child_death_buf: + poll.unregister(notify_child_death_fd) + + def handle_signal(event: int) -> None: + nonlocal child_death_buf + check_event(event, 'Polling of signal fd failed') + if not (event & select.POLLIN): + return + d = os.read(read_signal_fd, io.DEFAULT_BUFFER_SIZE) + if not d: + raise SystemExit(0) + signals = set(bytearray(d)) + if signal.SIGCHLD in signals: + while True: + try: + pid, exit_status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + break + matched_child_id = -1 + for child_id, child_pid in child_id_map.items(): + if child_pid == pid: + matched_child_id = child_id + break + if matched_child_id > -1: + del child_id_map[matched_child_id] + child_ready_fds.pop(matched_child_id, None) + child_death_buf += f'{pid}\n'.encode() + try: + while True: + if output_buf: + poll.register(stdout_fd, select.POLLOUT | basic_events) + if child_death_buf: + poll.register(notify_child_death_fd, select.POLLOUT | basic_events) + for (q, event) in poll.poll(): + if q == stdin_fd: + handle_input(event) + elif q == stdout_fd: + handle_output(event) + elif q == read_signal_fd: + handle_signal(event) + elif q == notify_child_death_fd: + handle_notify_child_death(event) + except (KeyboardInterrupt, EOFError, BrokenPipeError): + if os.getpid() == self_pid: + raise SystemExit(1) + raise + except Exception: + if os.getpid() == self_pid: + import traceback + traceback.print_exc() + raise + finally: + if os.getpid() == self_pid: + for fmd in child_ready_fds.values(): + with suppress(OSError): + os.close(fmd) + + +if __name__ == '__main__': + main() diff --git a/kitty_tests/__init__.py b/kitty_tests/__init__.py index 25728b31b..1d7c70231 100644 --- a/kitty_tests/__init__.py +++ b/kitty_tests/__init__.py @@ -177,7 +177,7 @@ class BaseTest(TestCase): s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c) return s - def create_pty(self, argv, cols=80, lines=100, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None): + def create_pty(self, argv=None, cols=80, lines=100, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None): self.set_options(options) return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env) @@ -228,9 +228,11 @@ class PTY: def __del__(self): if not self.is_child: - fd = self.master_fd + os.close(self.master_fd) + if hasattr(self, 'slave_fd'): + os.close(self.slave_fd) + del self.slave_fd del self.master_fd - os.close(fd) def write_to_child(self, data): if isinstance(data, str): diff --git a/kitty_tests/prewarm.py b/kitty_tests/prewarm.py new file mode 100644 index 000000000..2a62f5895 --- /dev/null +++ b/kitty_tests/prewarm.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# License: GPLv3 Copyright: 2022, Kovid Goyal + + +import json +import os +import tempfile + +from kitty.constants import kitty_exe + +from . import BaseTest + + +class Prewarm(BaseTest): + + maxDiff = None + + def test_prewarming(self): + from kittens.prewarm.main import PrewarmProcess + + p = PrewarmProcess(create_file_to_read_from_worker=True) + cwd = tempfile.gettempdir() + env = {'TEST_ENV_PASS': 'xyz'} + cols = 117 + stdin_data = 'from_stdin' + pty = self.create_pty(cols=cols) + ttyname = os.ttyname(pty.slave_fd) + child = p(pty.slave_fd, [kitty_exe(), '+runpy', """import os, json; from kitty.utils import *; print(json.dumps({ + 'cterm': os.ctermid(), + 'ttyname': os.ttyname(sys.stdout.fileno()), + 'cols': read_screen_size().cols, + 'cwd': os.getcwd(), + 'env': os.environ.get('TEST_ENV_PASS'), + 'pid': os.getpid(), + 'stdin': sys.stdin.read(), + + 'done': 'hello', + }, indent=2))"""], cwd=cwd, env=env, stdin_data=stdin_data) + self.assertFalse(pty.screen_contents().strip()) + p.mark_child_as_ready(child.child_id) + pty.wait_till(lambda: 'hello' in pty.screen_contents()) + data = json.loads(pty.screen_contents()) + self.ae(data['cols'], cols) + self.assertTrue(data['cterm']) + self.ae(data['ttyname'], ttyname) + self.ae(data['cwd'], cwd) + self.ae(data['env'], env['TEST_ENV_PASS']) + self.ae(int(p.from_worker.readline()), data['pid'])