Start work on prewarming
The prewarm process and its controller are implemented with some basic tests.
This commit is contained in:
parent
dec62b1929
commit
98f46f8bd7
0
kittens/prewarm/__init__.py
Normal file
0
kittens/prewarm/__init__.py
Normal file
454
kittens/prewarm/main.py
Normal file
454
kittens/prewarm/main.py
Normal file
@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from typing import (
|
||||
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast
|
||||
)
|
||||
|
||||
from kitty.child import remove_cloexec
|
||||
from kitty.constants import kitty_exe
|
||||
from kitty.entry_points import main as main_entry_point
|
||||
from kitty.fast_data_types import establish_controlling_tty, safe_pipe
|
||||
from kitty.shm import SharedMemory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import ReadableBuffer, WriteableBuffer
|
||||
|
||||
|
||||
hangup_events = select.POLLHUP
|
||||
error_events = select.POLLERR | select.POLLNVAL
|
||||
basic_events = hangup_events | error_events
|
||||
|
||||
|
||||
class PrewarmProcessFailed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Child:
|
||||
child_id: int
|
||||
child_process_pid: int
|
||||
|
||||
|
||||
class PrewarmProcess:
|
||||
|
||||
def __init__(self, create_file_to_read_from_worker: bool = False) -> None:
|
||||
self.from_worker_fd, self.in_worker_fd = safe_pipe()
|
||||
self.children: Dict[int, Child] = {}
|
||||
if create_file_to_read_from_worker:
|
||||
os.set_blocking(self.from_worker_fd, True)
|
||||
self.from_worker = open(self.from_worker_fd, mode='r', closefd=True)
|
||||
self.from_worker_fd = -1
|
||||
|
||||
def take_from_worker_fd(self) -> int:
|
||||
ans, self.from_worker_fd = self.from_worker_fd, -1
|
||||
return ans
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.from_worker_fd > -1:
|
||||
os.close(self.from_worker_fd)
|
||||
self.from_worker_fd = -1
|
||||
if hasattr(self, 'from_worker'):
|
||||
self.from_worker.close()
|
||||
del self.from_worker
|
||||
if self.worker_started:
|
||||
import subprocess
|
||||
self.process.stdin and self.process.stdin.close()
|
||||
self.process.stdout and self.process.stdout.close()
|
||||
try:
|
||||
self.process.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
del self.process
|
||||
|
||||
@property
|
||||
def worker_started(self) -> bool:
|
||||
return self.in_worker_fd == -1
|
||||
|
||||
def ensure_worker(self) -> None:
|
||||
if not self.worker_started:
|
||||
import subprocess
|
||||
self.process = subprocess.Popen(
|
||||
[kitty_exe(), '+kitten', 'prewarm', str(self.in_worker_fd)], stdin=subprocess.PIPE, stdout=subprocess.PIPE, pass_fds=(self.in_worker_fd,))
|
||||
os.close(self.in_worker_fd)
|
||||
self.in_worker_fd = -1
|
||||
assert self.process.stdin is not None and self.process.stdout is not None
|
||||
self.write_to_process_fd = self.process.stdin.fileno()
|
||||
self.read_from_process_fd = self.process.stdout.fileno()
|
||||
os.set_blocking(self.write_to_process_fd, False)
|
||||
os.set_blocking(self.read_from_process_fd, False)
|
||||
self.poll = select.poll()
|
||||
self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events)
|
||||
|
||||
def poll_to_send(self, yes: bool = True) -> None:
|
||||
if yes:
|
||||
self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events)
|
||||
else:
|
||||
self.poll.unregister(self.write_to_process_fd)
|
||||
|
||||
def reload_kitty_config(self) -> None:
|
||||
if self.worker_started:
|
||||
self.send_to_prewarm_process('reload_kitty_config:\n')
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tty_fd: int,
|
||||
argv: List[str],
|
||||
cwd: str = '',
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
stdin_data: Optional[Union[str, bytes]] = None
|
||||
) -> Child:
|
||||
self.ensure_worker()
|
||||
tty_name = os.ttyname(tty_fd)
|
||||
if isinstance(stdin_data, str):
|
||||
stdin_data = stdin_data.encode()
|
||||
if env is None:
|
||||
env = dict(os.environ)
|
||||
cmd: Dict[str, Union[int, List[str], str, Dict[str, str]]] = {
|
||||
'tty_name': tty_name, 'cwd': cwd or os.getcwd(), 'argv': argv, 'env': env,
|
||||
}
|
||||
total_size = 0
|
||||
if stdin_data is not None:
|
||||
cmd['stdin_size'] = len(stdin_data)
|
||||
total_size += len(stdin_data)
|
||||
data = json.dumps(cmd).encode()
|
||||
total_size += len(data) + SharedMemory.num_bytes_for_size
|
||||
with SharedMemory(size=total_size, unlink_on_exit=True) as shm:
|
||||
shm.write_data_with_size(data)
|
||||
if stdin_data:
|
||||
shm.write(stdin_data)
|
||||
shm.flush()
|
||||
self.send_to_prewarm_process(f'fork:{shm.name}\n')
|
||||
input_buf = b''
|
||||
st = time.monotonic()
|
||||
while time.monotonic() - st < 2:
|
||||
for (fd, event) in self.poll.poll(0.2):
|
||||
if event & basic_events:
|
||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
||||
if fd == self.read_from_process_fd and event & select.POLLIN:
|
||||
d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
|
||||
input_buf += d
|
||||
while (idx := input_buf.find(b'\n')) > -1:
|
||||
line = input_buf[:idx].decode()
|
||||
input_buf = input_buf[idx+1:]
|
||||
if line.startswith('CHILD:'):
|
||||
_, cid, pid = line.split(':')
|
||||
child = self.add_child(int(cid), int(pid))
|
||||
shm.unlink_on_exit = False
|
||||
return child
|
||||
if line.startswith('ERR:'):
|
||||
raise PrewarmProcessFailed(line.split(':', 1)[-1])
|
||||
raise PrewarmProcessFailed('Timed out waiting for I/O with prewarm process')
|
||||
|
||||
def add_child(self, child_id: int, pid: int) -> Child:
|
||||
self.children[child_id] = c = Child(child_id, pid)
|
||||
return c
|
||||
|
||||
def send_to_prewarm_process(self, output_buf: Union[str, bytes] = b'', timeout: float = 2) -> None:
|
||||
if isinstance(output_buf, str):
|
||||
output_buf = output_buf.encode()
|
||||
st = time.monotonic()
|
||||
while time.monotonic() - st < timeout and output_buf:
|
||||
self.poll_to_send(bool(output_buf))
|
||||
for (fd, event) in self.poll.poll(0.2):
|
||||
if event & basic_events:
|
||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
||||
if fd == self.write_to_process_fd and event & select.POLLOUT:
|
||||
n = os.write(self.write_to_process_fd, output_buf)
|
||||
output_buf = output_buf[n:]
|
||||
self.poll_to_send(False)
|
||||
if output_buf:
|
||||
raise PrewarmProcessFailed('Timed out waiting to write to prewarm process')
|
||||
|
||||
def mark_child_as_ready(self, child_id: int) -> bool:
|
||||
c = self.children.pop(child_id, None)
|
||||
if c is None:
|
||||
return False
|
||||
self.send_to_prewarm_process(f'ready:{child_id}\n')
|
||||
return True
|
||||
|
||||
|
||||
def reload_kitty_config() -> None:
|
||||
from kittens.tui.utils import kitty_opts
|
||||
kitty_opts.clear_cached()
|
||||
kitty_opts()
|
||||
|
||||
|
||||
def prewarm() -> None:
|
||||
reload_kitty_config()
|
||||
for kitten in ('hints', 'ssh', 'unicode_input', 'ask', 'show_error'):
|
||||
import_module(f'kittens.{kitten}.main')
|
||||
|
||||
|
||||
class MemoryViewReadWrapperBytes(io.BufferedIOBase):
|
||||
|
||||
def __init__(self, mw: memoryview):
|
||||
self.mw = mw
|
||||
self.pos = 0
|
||||
|
||||
def detach(self) -> io.RawIOBase:
|
||||
raise io.UnsupportedOperation('detach() not supported')
|
||||
|
||||
def read(self, size: Optional[int] = -1) -> bytes:
|
||||
if size is None or size < 0:
|
||||
size = max(0, len(self.mw) - self.pos)
|
||||
oldpos = self.pos
|
||||
self.pos = min(len(self.mw), self.pos + size)
|
||||
if self.pos <= oldpos:
|
||||
return b''
|
||||
return bytes(self.mw[oldpos:self.pos])
|
||||
|
||||
def readinto(self, b: 'WriteableBuffer') -> int:
|
||||
if not isinstance(b, memoryview):
|
||||
b = memoryview(b)
|
||||
b = b.cast('B')
|
||||
data = self.read(len(b))
|
||||
n = len(data)
|
||||
b[:n] = data
|
||||
return n
|
||||
readinto1 = readinto
|
||||
|
||||
def readall(self) -> bytes:
|
||||
return self.read()
|
||||
|
||||
def write(self, b: 'ReadableBuffer') -> int:
|
||||
raise io.UnsupportedOperation('readonly stream')
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class MemoryViewReadWrapper(io.TextIOWrapper):
|
||||
|
||||
def __init__(self, mw: memoryview):
|
||||
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
|
||||
|
||||
|
||||
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||
cwd = cmd.get('cwd')
|
||||
if cwd:
|
||||
try:
|
||||
os.chdir(cwd)
|
||||
except OSError:
|
||||
with suppress(OSError):
|
||||
os.chdir('/')
|
||||
os.setsid()
|
||||
env = cmd.get('env')
|
||||
if env is not None:
|
||||
os.environ.clear()
|
||||
os.environ.update(env)
|
||||
argv = cmd.get('argv')
|
||||
if argv:
|
||||
sys.argv = list(argv)
|
||||
poll = select.poll()
|
||||
poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP)
|
||||
poll.poll()
|
||||
os.close(ready_fd)
|
||||
tty_name = cmd.get('tty_name')
|
||||
if tty_name:
|
||||
sys.__stdout__.flush()
|
||||
sys.__stderr__.flush()
|
||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||
main_entry_point()
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def fork(shm_address: str, ready_fd: int) -> int:
|
||||
sz = pos = 0
|
||||
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
||||
data = shm.read_data_with_size()
|
||||
cmd = json.loads(data)
|
||||
sz = cmd.get('stdin_size', 0)
|
||||
if sz:
|
||||
pos = shm.tell()
|
||||
shm.unlink_on_exit = False
|
||||
|
||||
try:
|
||||
child_pid = os.fork()
|
||||
except OSError:
|
||||
if sz:
|
||||
with SharedMemory(shm_address, unlink_on_exit=True):
|
||||
pass
|
||||
if child_pid:
|
||||
# master process
|
||||
return child_pid
|
||||
# child process
|
||||
if shm.unlink_on_exit:
|
||||
child_main(cmd, ready_fd)
|
||||
else:
|
||||
with SharedMemory(shm_address, unlink_on_exit=True) as shm:
|
||||
stdin_data = memoryview(shm.mmap)[pos:pos + sz]
|
||||
if stdin_data:
|
||||
sys.stdin = MemoryViewReadWrapper(stdin_data)
|
||||
try:
|
||||
child_main(cmd, ready_fd)
|
||||
finally:
|
||||
stdin_data.release()
|
||||
sys.stdin = sys.__stdin__
|
||||
|
||||
|
||||
def waitstatus_to_exit_code(status: int) -> int:
|
||||
with suppress(ValueError, AttributeError):
|
||||
return os.waitstatus_to_exitcode(status)
|
||||
return 0
|
||||
|
||||
|
||||
def main(args: List[str] = sys.argv) -> None:
|
||||
read_signal_fd, write_signal_fd = safe_pipe()
|
||||
notify_child_death_fd = int(sys.argv[-1])
|
||||
os.set_blocking(notify_child_death_fd, False)
|
||||
signal.set_wakeup_fd(write_signal_fd)
|
||||
signal.signal(signal.SIGCHLD, lambda *a: None)
|
||||
signal.siginterrupt(signal.SIGCHLD, False)
|
||||
prewarm()
|
||||
stdin_fd = sys.__stdin__.fileno()
|
||||
os.set_blocking(stdin_fd, False)
|
||||
stdout_fd = sys.__stdout__.fileno()
|
||||
os.set_blocking(stdout_fd, False)
|
||||
poll = select.poll()
|
||||
poll.register(stdin_fd, select.POLLIN | basic_events)
|
||||
poll.register(read_signal_fd, select.POLLIN | basic_events)
|
||||
input_buf = output_buf = child_death_buf = b''
|
||||
child_ready_fds: Dict[int, int] = {}
|
||||
child_id_map: Dict[int, int] = {}
|
||||
self_pid = os.getpid()
|
||||
|
||||
def check_event(event: int, err_msg: str) -> None:
|
||||
if event & hangup_events:
|
||||
raise SystemExit(0)
|
||||
if event & error_events:
|
||||
raise SystemExit(err_msg)
|
||||
|
||||
def handle_input(event: int) -> None:
|
||||
nonlocal input_buf, output_buf
|
||||
check_event(event, 'Polling of STDIN failed')
|
||||
if not (event & select.POLLIN):
|
||||
return
|
||||
d = os.read(stdin_fd, io.DEFAULT_BUFFER_SIZE)
|
||||
if not d:
|
||||
raise SystemExit(0)
|
||||
input_buf += d
|
||||
while (idx := input_buf.find(b'\n')) > -1:
|
||||
line = input_buf[:idx].decode()
|
||||
input_buf = input_buf[idx+1:]
|
||||
cmd, _, payload = line.partition(':')
|
||||
if cmd == 'reload_kitty_config':
|
||||
reload_kitty_config()
|
||||
elif cmd == 'ready':
|
||||
child_id = int(payload)
|
||||
cfd = child_ready_fds.pop(child_id)
|
||||
if cfd is not None:
|
||||
os.write(cfd, b'1')
|
||||
os.close(cfd)
|
||||
elif cmd == 'fork':
|
||||
read_fd, write_fd = safe_pipe(False)
|
||||
remove_cloexec(read_fd)
|
||||
try:
|
||||
child_pid = fork(payload, read_fd)
|
||||
except Exception as e:
|
||||
es = str(e).replace('\n', ' ')
|
||||
output_buf += f'ERR:{es}\n'.encode()
|
||||
else:
|
||||
child_id = len(child_id_map) + 1
|
||||
child_id_map[child_id] = child_pid
|
||||
child_ready_fds[child_id] = write_fd
|
||||
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
||||
finally:
|
||||
if os.getpid() == self_pid:
|
||||
os.close(read_fd)
|
||||
elif cmd == 'echo':
|
||||
output_buf += f'{payload}\n'.encode()
|
||||
|
||||
def handle_output(event: int) -> None:
|
||||
nonlocal output_buf
|
||||
check_event(event, 'Polling of STDOUT failed')
|
||||
if not (event & select.POLLOUT):
|
||||
return
|
||||
if output_buf:
|
||||
n = os.write(stdout_fd, output_buf)
|
||||
if not n:
|
||||
raise SystemExit(0)
|
||||
output_buf = output_buf[n:]
|
||||
if not output_buf:
|
||||
poll.unregister(stdout_fd)
|
||||
|
||||
def handle_notify_child_death(event: int) -> None:
|
||||
nonlocal child_death_buf
|
||||
check_event(event, 'Polling of notify child death fd failed')
|
||||
if not (event & select.POLLOUT):
|
||||
return
|
||||
if child_death_buf:
|
||||
n = os.write(notify_child_death_fd, child_death_buf)
|
||||
if not n:
|
||||
raise SystemExit(0)
|
||||
child_death_buf = child_death_buf[n:]
|
||||
if not child_death_buf:
|
||||
poll.unregister(notify_child_death_fd)
|
||||
|
||||
def handle_signal(event: int) -> None:
|
||||
nonlocal child_death_buf
|
||||
check_event(event, 'Polling of signal fd failed')
|
||||
if not (event & select.POLLIN):
|
||||
return
|
||||
d = os.read(read_signal_fd, io.DEFAULT_BUFFER_SIZE)
|
||||
if not d:
|
||||
raise SystemExit(0)
|
||||
signals = set(bytearray(d))
|
||||
if signal.SIGCHLD in signals:
|
||||
while True:
|
||||
try:
|
||||
pid, exit_status = os.waitpid(-1, os.WNOHANG)
|
||||
except ChildProcessError:
|
||||
break
|
||||
matched_child_id = -1
|
||||
for child_id, child_pid in child_id_map.items():
|
||||
if child_pid == pid:
|
||||
matched_child_id = child_id
|
||||
break
|
||||
if matched_child_id > -1:
|
||||
del child_id_map[matched_child_id]
|
||||
child_ready_fds.pop(matched_child_id, None)
|
||||
child_death_buf += f'{pid}\n'.encode()
|
||||
try:
|
||||
while True:
|
||||
if output_buf:
|
||||
poll.register(stdout_fd, select.POLLOUT | basic_events)
|
||||
if child_death_buf:
|
||||
poll.register(notify_child_death_fd, select.POLLOUT | basic_events)
|
||||
for (q, event) in poll.poll():
|
||||
if q == stdin_fd:
|
||||
handle_input(event)
|
||||
elif q == stdout_fd:
|
||||
handle_output(event)
|
||||
elif q == read_signal_fd:
|
||||
handle_signal(event)
|
||||
elif q == notify_child_death_fd:
|
||||
handle_notify_child_death(event)
|
||||
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
||||
if os.getpid() == self_pid:
|
||||
raise SystemExit(1)
|
||||
raise
|
||||
except Exception:
|
||||
if os.getpid() == self_pid:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
if os.getpid() == self_pid:
|
||||
for fmd in child_ready_fds.values():
|
||||
with suppress(OSError):
|
||||
os.close(fmd)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -177,7 +177,7 @@ class BaseTest(TestCase):
|
||||
s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c)
|
||||
return s
|
||||
|
||||
def create_pty(self, argv, cols=80, lines=100, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None):
|
||||
def create_pty(self, argv=None, cols=80, lines=100, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None):
|
||||
self.set_options(options)
|
||||
return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env)
|
||||
|
||||
@ -228,9 +228,11 @@ class PTY:
|
||||
|
||||
def __del__(self):
|
||||
if not self.is_child:
|
||||
fd = self.master_fd
|
||||
os.close(self.master_fd)
|
||||
if hasattr(self, 'slave_fd'):
|
||||
os.close(self.slave_fd)
|
||||
del self.slave_fd
|
||||
del self.master_fd
|
||||
os.close(fd)
|
||||
|
||||
def write_to_child(self, data):
|
||||
if isinstance(data, str):
|
||||
|
||||
48
kitty_tests/prewarm.py
Normal file
48
kitty_tests/prewarm.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from kitty.constants import kitty_exe
|
||||
|
||||
from . import BaseTest
|
||||
|
||||
|
||||
class Prewarm(BaseTest):
|
||||
|
||||
maxDiff = None
|
||||
|
||||
def test_prewarming(self):
|
||||
from kittens.prewarm.main import PrewarmProcess
|
||||
|
||||
p = PrewarmProcess(create_file_to_read_from_worker=True)
|
||||
cwd = tempfile.gettempdir()
|
||||
env = {'TEST_ENV_PASS': 'xyz'}
|
||||
cols = 117
|
||||
stdin_data = 'from_stdin'
|
||||
pty = self.create_pty(cols=cols)
|
||||
ttyname = os.ttyname(pty.slave_fd)
|
||||
child = p(pty.slave_fd, [kitty_exe(), '+runpy', """import os, json; from kitty.utils import *; print(json.dumps({
|
||||
'cterm': os.ctermid(),
|
||||
'ttyname': os.ttyname(sys.stdout.fileno()),
|
||||
'cols': read_screen_size().cols,
|
||||
'cwd': os.getcwd(),
|
||||
'env': os.environ.get('TEST_ENV_PASS'),
|
||||
'pid': os.getpid(),
|
||||
'stdin': sys.stdin.read(),
|
||||
|
||||
'done': 'hello',
|
||||
}, indent=2))"""], cwd=cwd, env=env, stdin_data=stdin_data)
|
||||
self.assertFalse(pty.screen_contents().strip())
|
||||
p.mark_child_as_ready(child.child_id)
|
||||
pty.wait_till(lambda: 'hello' in pty.screen_contents())
|
||||
data = json.loads(pty.screen_contents())
|
||||
self.ae(data['cols'], cols)
|
||||
self.assertTrue(data['cterm'])
|
||||
self.ae(data['ttyname'], ttyname)
|
||||
self.ae(data['cwd'], cwd)
|
||||
self.ae(data['env'], env['TEST_ENV_PASS'])
|
||||
self.ae(int(p.from_worker.readline()), data['pid'])
|
||||
Loading…
x
Reference in New Issue
Block a user