Add basic tests for socket prewarm
This commit is contained in:
parent
d1b028c27a
commit
7b7f1ecc54
@ -251,7 +251,7 @@ class Child:
|
||||
env['COLORTERM'] = 'truecolor'
|
||||
env['KITTY_PID'] = getpid()
|
||||
if not self.is_prewarmed:
|
||||
env['KITTY_PREWARM_SOCKET'] = f'{os.geteuid()}:{os.getegid()}:{fast_data_types.get_boss().prewarm.unix_socket_name}'
|
||||
env['KITTY_PREWARM_SOCKET'] = fast_data_types.get_boss().prewarm.socket_env_var()
|
||||
if self.cwd:
|
||||
# needed in case cwd is a symlink, in which case shells
|
||||
# can use it to display the current directory name rather
|
||||
|
||||
@ -15,7 +15,7 @@ from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from itertools import count
|
||||
from typing import (
|
||||
IO, TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, NoReturn, Optional,
|
||||
IO, TYPE_CHECKING, Any, Callable, Dict, Iterator, List, NoReturn, Optional,
|
||||
Tuple, Union, cast
|
||||
)
|
||||
|
||||
@ -85,6 +85,9 @@ class PrewarmProcess:
|
||||
self.poll.register(self.read_from_process_fd, select.POLLIN)
|
||||
self.unix_socket_name = unix_socket_name
|
||||
|
||||
def socket_env_var(self) -> str:
|
||||
return f'{os.geteuid()}:{os.getegid()}:{self.unix_socket_name}'
|
||||
|
||||
def take_from_worker_fd(self, create_file: bool = False) -> int:
|
||||
if create_file:
|
||||
os.set_blocking(self.from_prewarm_death_notify, True)
|
||||
@ -293,7 +296,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||
def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tuple[int, int]:
|
||||
global is_zygote
|
||||
sz = pos = 0
|
||||
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
||||
@ -331,9 +334,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||
remove_signal_handlers()
|
||||
os.close(r)
|
||||
os.close(ready_fd_write)
|
||||
for fd in all_non_child_fds:
|
||||
if fd > -1:
|
||||
os.close(fd)
|
||||
free_non_child_resources()
|
||||
os.setsid()
|
||||
tty_name = cmd.get('tty_name')
|
||||
if tty_name:
|
||||
@ -363,12 +364,11 @@ class SocketClosed(Exception):
|
||||
class SocketChild:
|
||||
|
||||
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
|
||||
self.fd = conn.fileno()
|
||||
poll.register(self.fd, select.POLLIN)
|
||||
self.registered = True
|
||||
self.poll = poll
|
||||
self.addr = addr
|
||||
self.conn = conn
|
||||
self.poll.register(self.conn.fileno(), select.POLLIN)
|
||||
self.input_buf = self.output_buf = b''
|
||||
self.fds: List[int] = []
|
||||
self.child_id = -1
|
||||
@ -377,10 +377,13 @@ class SocketChild:
|
||||
self.argv: List[str] = []
|
||||
self.stdin = self.stdout = self.stderr = -1
|
||||
self.pid = -1
|
||||
self.closed = False
|
||||
|
||||
def unregister_from_poll(self) -> None:
|
||||
if self.registered:
|
||||
self.poll.unregister(self.fd)
|
||||
fd = self.conn.fileno()
|
||||
if fd > -1:
|
||||
self.poll.unregister(self.conn.fileno())
|
||||
self.registered = False
|
||||
|
||||
def read(self) -> bool:
|
||||
@ -426,7 +429,7 @@ class SocketChild:
|
||||
|
||||
return False
|
||||
|
||||
def fork(self, all_non_child_fds: Iterable[int]) -> None:
|
||||
def fork(self, free_non_child_resources: Callable[[], None]) -> None:
|
||||
global is_zygote
|
||||
r, w = safe_pipe()
|
||||
self.pid = os.fork()
|
||||
@ -470,9 +473,7 @@ class SocketChild:
|
||||
if self.stderr > -1:
|
||||
os.dup2(self.stderr, sys.__stderr__.fileno())
|
||||
os.close(w)
|
||||
for fd in all_non_child_fds:
|
||||
if fd > -1:
|
||||
os.close(fd)
|
||||
free_non_child_resources()
|
||||
child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv})
|
||||
raise SystemExit(0)
|
||||
|
||||
@ -499,6 +500,23 @@ class SocketChild:
|
||||
return False
|
||||
return True
|
||||
|
||||
def close(self) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
self.unregister_from_poll()
|
||||
self.closed = True
|
||||
self.conn.close()
|
||||
if self.stdin > -1:
|
||||
os.close(self.stdin)
|
||||
self.stdin = -1
|
||||
if self.stdout > -1:
|
||||
os.close(self.stdout)
|
||||
self.stdout = -1
|
||||
if self.stderr > -1:
|
||||
os.close(self.stderr)
|
||||
self.stderr = -1
|
||||
__del__ = close
|
||||
|
||||
|
||||
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
|
||||
global parent_tty_name
|
||||
@ -525,10 +543,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
prewarm()
|
||||
|
||||
def remove_socket_child(sc: SocketChild) -> None:
|
||||
socket_children.pop(sc.fd, None)
|
||||
sc.unregister_from_poll()
|
||||
socket_children.pop(sc.conn.fileno(), None)
|
||||
socket_pid_map.pop(sc.pid, None)
|
||||
sc.conn.close()
|
||||
sc.close()
|
||||
|
||||
def get_all_non_child_fds() -> Iterator[int]:
|
||||
yield notify_child_death_fd
|
||||
@ -536,14 +553,14 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
yield stdout_fd
|
||||
# the signal fds are closed by remove_signal_handlers()
|
||||
yield from child_ready_fds.values()
|
||||
for sc in socket_children.values():
|
||||
yield sc.fd
|
||||
if sc.stdin > -1:
|
||||
yield sc.stdin
|
||||
if sc.stdout > -1:
|
||||
yield sc.stdout
|
||||
if sc.stderr > -1:
|
||||
yield sc.stderr
|
||||
|
||||
def free_non_child_resources() -> None:
|
||||
for fd in get_all_non_child_fds():
|
||||
if fd > -1:
|
||||
os.close(fd)
|
||||
unix_socket.close()
|
||||
for sc in tuple(socket_children.values()):
|
||||
remove_socket_child(sc)
|
||||
|
||||
def check_event(event: int, err_msg: str) -> None:
|
||||
if event & select.POLLHUP:
|
||||
@ -576,7 +593,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
raise SystemExit(0)
|
||||
elif cmd == 'fork':
|
||||
try:
|
||||
child_pid, ready_fd_write = fork(payload, get_all_non_child_fds())
|
||||
child_pid, ready_fd_write = fork(payload, free_non_child_resources)
|
||||
except Exception as e:
|
||||
es = str(e).replace('\n', ' ')
|
||||
output_buf += f'ERR:{es}\n'.encode()
|
||||
@ -654,7 +671,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
check_event(event, 'UNIX socket fd listener failed')
|
||||
conn, addr = unix_socket.accept()
|
||||
sc = SocketChild(conn, addr, poll)
|
||||
socket_children[sc.fd] = sc
|
||||
socket_children[sc.conn.fileno()] = sc
|
||||
|
||||
def handle_socket_launch(fd: int, event: int) -> None:
|
||||
scq = socket_children.get(q)
|
||||
@ -664,7 +681,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
try:
|
||||
if scq.read():
|
||||
scq.unregister_from_poll()
|
||||
scq.fork(get_all_non_child_fds())
|
||||
scq.fork(free_non_child_resources)
|
||||
socket_pid_map[scq.pid] = scq
|
||||
scq.child_id = next(child_id_counter)
|
||||
except SocketClosed:
|
||||
@ -714,6 +731,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
for fmd in child_ready_fds.values():
|
||||
with suppress(OSError):
|
||||
os.close(fmd)
|
||||
for sc in tuple(socket_children.values()):
|
||||
remove_socket_child(sc)
|
||||
|
||||
|
||||
def get_socket_name(unix_socket: socket.socket) -> str:
|
||||
@ -743,7 +762,8 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_
|
||||
main(stdin_read, stdout_write, death_notify_write, unix_socket)
|
||||
finally:
|
||||
set_options(None)
|
||||
unix_socket.close()
|
||||
if is_zygote:
|
||||
unix_socket.close()
|
||||
|
||||
|
||||
def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[PrewarmProcess]:
|
||||
|
||||
@ -204,8 +204,8 @@ class PTY:
|
||||
self.master_fd, self.slave_fd = openpty()
|
||||
self.is_child = False
|
||||
else:
|
||||
pid, self.master_fd = fork()
|
||||
self.is_child = pid == CHILD
|
||||
self.child_pid, self.master_fd = fork()
|
||||
self.is_child = self.child_pid == CHILD
|
||||
if self.is_child:
|
||||
while read_screen_size().width != columns * cell_width:
|
||||
time.sleep(0.01)
|
||||
|
||||
@ -4,13 +4,15 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import suppress
|
||||
|
||||
from kitty.constants import kitty_exe
|
||||
from kitty.constants import kitty_exe, read_kitty_resource
|
||||
from kitty.fast_data_types import (
|
||||
CLD_EXITED, CLD_KILLED, get_options, has_sigqueue, install_signal_handlers,
|
||||
read_signals, remove_signal_handlers, sigqueue
|
||||
@ -19,10 +21,55 @@ from kitty.fast_data_types import (
|
||||
from . import BaseTest
|
||||
|
||||
|
||||
def socket_child_main(exit_code=0):
|
||||
import json
|
||||
import os
|
||||
from kitty.fast_data_types import get_options
|
||||
from kitty.utils import read_screen_size
|
||||
output = {
|
||||
'test_env': os.environ.get('TEST_ENV_PASS', ''),
|
||||
'cwd': os.getcwd(),
|
||||
'font_family': get_options().font_family,
|
||||
'cols': read_screen_size().cols,
|
||||
|
||||
'done': 'hello',
|
||||
}
|
||||
print(json.dumps(output, indent=2))
|
||||
raise SystemExit(exit_code)
|
||||
|
||||
# END_socket_child_main
|
||||
|
||||
|
||||
class Prewarm(BaseTest):
|
||||
|
||||
maxDiff = None
|
||||
|
||||
def test_socket_prewarming(self):
|
||||
from kitty.prewarm import fork_prewarm_process
|
||||
exit_code = 17
|
||||
src = re.search(
|
||||
r'^(def socket_child_main.+?)^# END_socket_child_main', read_kitty_resource('prewarm.py', 'kitty_tests').decode(),
|
||||
flags=re.M | re.DOTALL).group(1) + '\n\n'
|
||||
|
||||
cwd = tempfile.gettempdir()
|
||||
opts = self.set_options()
|
||||
opts.config_overrides = 'font_family prewarm',
|
||||
p = fork_prewarm_process(opts, use_exec=True)
|
||||
if p is None:
|
||||
return
|
||||
env = {'TEST_ENV_PASS': 'xyz', 'KITTY_PREWARM_SOCKET': p.socket_env_var()}
|
||||
cols = 117
|
||||
pty = self.create_pty(argv=[kitty_exe(), '+runpy', src + f'socket_child_main({exit_code})'], cols=cols, env=env, cwd=cwd)
|
||||
status = os.waitpid(pty.child_pid, 0)[1]
|
||||
with suppress(AttributeError):
|
||||
self.assertEqual(os.waitstatus_to_exitcode(status), exit_code)
|
||||
pty.wait_till(lambda: 'hello' in pty.screen_contents())
|
||||
output = json.loads(pty.screen_contents().strip())
|
||||
self.assertEqual(output['test_env'], env['TEST_ENV_PASS'])
|
||||
self.assertEqual(output['cwd'], cwd)
|
||||
self.assertEqual(output['font_family'], 'prewarm')
|
||||
self.assertEqual(output['cols'], cols)
|
||||
|
||||
def test_prewarming(self):
|
||||
from kitty.prewarm import fork_prewarm_process
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user