Add basic tests for socket prewarm

This commit is contained in:
Kovid Goyal 2022-07-05 12:49:22 +05:30
parent d1b028c27a
commit 7b7f1ecc54
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 98 additions and 31 deletions

View File

@ -251,7 +251,7 @@ class Child:
env['COLORTERM'] = 'truecolor'
env['KITTY_PID'] = getpid()
if not self.is_prewarmed:
env['KITTY_PREWARM_SOCKET'] = f'{os.geteuid()}:{os.getegid()}:{fast_data_types.get_boss().prewarm.unix_socket_name}'
env['KITTY_PREWARM_SOCKET'] = fast_data_types.get_boss().prewarm.socket_env_var()
if self.cwd:
# needed in case cwd is a symlink, in which case shells
# can use it to display the current directory name rather

View File

@ -15,7 +15,7 @@ from dataclasses import dataclass
from importlib import import_module
from itertools import count
from typing import (
IO, TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, NoReturn, Optional,
IO, TYPE_CHECKING, Any, Callable, Dict, Iterator, List, NoReturn, Optional,
Tuple, Union, cast
)
@ -85,6 +85,9 @@ class PrewarmProcess:
self.poll.register(self.read_from_process_fd, select.POLLIN)
self.unix_socket_name = unix_socket_name
def socket_env_var(self) -> str:
return f'{os.geteuid()}:{os.getegid()}:{self.unix_socket_name}'
def take_from_worker_fd(self, create_file: bool = False) -> int:
if create_file:
os.set_blocking(self.from_prewarm_death_notify, True)
@ -293,7 +296,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
raise SystemExit(0)
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tuple[int, int]:
global is_zygote
sz = pos = 0
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
@ -331,9 +334,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
remove_signal_handlers()
os.close(r)
os.close(ready_fd_write)
for fd in all_non_child_fds:
if fd > -1:
os.close(fd)
free_non_child_resources()
os.setsid()
tty_name = cmd.get('tty_name')
if tty_name:
@ -363,12 +364,11 @@ class SocketClosed(Exception):
class SocketChild:
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
self.fd = conn.fileno()
poll.register(self.fd, select.POLLIN)
self.registered = True
self.poll = poll
self.addr = addr
self.conn = conn
self.poll.register(self.conn.fileno(), select.POLLIN)
self.input_buf = self.output_buf = b''
self.fds: List[int] = []
self.child_id = -1
@ -377,10 +377,13 @@ class SocketChild:
self.argv: List[str] = []
self.stdin = self.stdout = self.stderr = -1
self.pid = -1
self.closed = False
def unregister_from_poll(self) -> None:
if self.registered:
self.poll.unregister(self.fd)
fd = self.conn.fileno()
if fd > -1:
self.poll.unregister(self.conn.fileno())
self.registered = False
def read(self) -> bool:
@ -426,7 +429,7 @@ class SocketChild:
return False
def fork(self, all_non_child_fds: Iterable[int]) -> None:
def fork(self, free_non_child_resources: Callable[[], None]) -> None:
global is_zygote
r, w = safe_pipe()
self.pid = os.fork()
@ -470,9 +473,7 @@ class SocketChild:
if self.stderr > -1:
os.dup2(self.stderr, sys.__stderr__.fileno())
os.close(w)
for fd in all_non_child_fds:
if fd > -1:
os.close(fd)
free_non_child_resources()
child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv})
raise SystemExit(0)
@ -499,6 +500,23 @@ class SocketChild:
return False
return True
def close(self) -> None:
if self.closed:
return
self.unregister_from_poll()
self.closed = True
self.conn.close()
if self.stdin > -1:
os.close(self.stdin)
self.stdin = -1
if self.stdout > -1:
os.close(self.stdout)
self.stdout = -1
if self.stderr > -1:
os.close(self.stderr)
self.stderr = -1
__del__ = close
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
global parent_tty_name
@ -525,10 +543,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
prewarm()
def remove_socket_child(sc: SocketChild) -> None:
socket_children.pop(sc.fd, None)
sc.unregister_from_poll()
socket_children.pop(sc.conn.fileno(), None)
socket_pid_map.pop(sc.pid, None)
sc.conn.close()
sc.close()
def get_all_non_child_fds() -> Iterator[int]:
yield notify_child_death_fd
@ -536,14 +553,14 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
yield stdout_fd
# the signal fds are closed by remove_signal_handlers()
yield from child_ready_fds.values()
for sc in socket_children.values():
yield sc.fd
if sc.stdin > -1:
yield sc.stdin
if sc.stdout > -1:
yield sc.stdout
if sc.stderr > -1:
yield sc.stderr
def free_non_child_resources() -> None:
for fd in get_all_non_child_fds():
if fd > -1:
os.close(fd)
unix_socket.close()
for sc in tuple(socket_children.values()):
remove_socket_child(sc)
def check_event(event: int, err_msg: str) -> None:
if event & select.POLLHUP:
@ -576,7 +593,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
raise SystemExit(0)
elif cmd == 'fork':
try:
child_pid, ready_fd_write = fork(payload, get_all_non_child_fds())
child_pid, ready_fd_write = fork(payload, free_non_child_resources)
except Exception as e:
es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode()
@ -654,7 +671,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
check_event(event, 'UNIX socket fd listener failed')
conn, addr = unix_socket.accept()
sc = SocketChild(conn, addr, poll)
socket_children[sc.fd] = sc
socket_children[sc.conn.fileno()] = sc
def handle_socket_launch(fd: int, event: int) -> None:
scq = socket_children.get(q)
@ -664,7 +681,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
try:
if scq.read():
scq.unregister_from_poll()
scq.fork(get_all_non_child_fds())
scq.fork(free_non_child_resources)
socket_pid_map[scq.pid] = scq
scq.child_id = next(child_id_counter)
except SocketClosed:
@ -714,6 +731,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
for fmd in child_ready_fds.values():
with suppress(OSError):
os.close(fmd)
for sc in tuple(socket_children.values()):
remove_socket_child(sc)
def get_socket_name(unix_socket: socket.socket) -> str:
@ -743,7 +762,8 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_
main(stdin_read, stdout_write, death_notify_write, unix_socket)
finally:
set_options(None)
unix_socket.close()
if is_zygote:
unix_socket.close()
def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[PrewarmProcess]:

View File

@ -204,8 +204,8 @@ class PTY:
self.master_fd, self.slave_fd = openpty()
self.is_child = False
else:
pid, self.master_fd = fork()
self.is_child = pid == CHILD
self.child_pid, self.master_fd = fork()
self.is_child = self.child_pid == CHILD
if self.is_child:
while read_screen_size().width != columns * cell_width:
time.sleep(0.01)

View File

@ -4,13 +4,15 @@
import json
import os
import re
import select
import signal
import subprocess
import tempfile
import time
from contextlib import suppress
from kitty.constants import kitty_exe
from kitty.constants import kitty_exe, read_kitty_resource
from kitty.fast_data_types import (
CLD_EXITED, CLD_KILLED, get_options, has_sigqueue, install_signal_handlers,
read_signals, remove_signal_handlers, sigqueue
@ -19,10 +21,55 @@ from kitty.fast_data_types import (
from . import BaseTest
def socket_child_main(exit_code=0):
import json
import os
from kitty.fast_data_types import get_options
from kitty.utils import read_screen_size
output = {
'test_env': os.environ.get('TEST_ENV_PASS', ''),
'cwd': os.getcwd(),
'font_family': get_options().font_family,
'cols': read_screen_size().cols,
'done': 'hello',
}
print(json.dumps(output, indent=2))
raise SystemExit(exit_code)
# END_socket_child_main
class Prewarm(BaseTest):
maxDiff = None
def test_socket_prewarming(self):
from kitty.prewarm import fork_prewarm_process
exit_code = 17
src = re.search(
r'^(def socket_child_main.+?)^# END_socket_child_main', read_kitty_resource('prewarm.py', 'kitty_tests').decode(),
flags=re.M | re.DOTALL).group(1) + '\n\n'
cwd = tempfile.gettempdir()
opts = self.set_options()
opts.config_overrides = 'font_family prewarm',
p = fork_prewarm_process(opts, use_exec=True)
if p is None:
return
env = {'TEST_ENV_PASS': 'xyz', 'KITTY_PREWARM_SOCKET': p.socket_env_var()}
cols = 117
pty = self.create_pty(argv=[kitty_exe(), '+runpy', src + f'socket_child_main({exit_code})'], cols=cols, env=env, cwd=cwd)
status = os.waitpid(pty.child_pid, 0)[1]
with suppress(AttributeError):
self.assertEqual(os.waitstatus_to_exitcode(status), exit_code)
pty.wait_till(lambda: 'hello' in pty.screen_contents())
output = json.loads(pty.screen_contents().strip())
self.assertEqual(output['test_env'], env['TEST_ENV_PASS'])
self.assertEqual(output['cwd'], cwd)
self.assertEqual(output['font_family'], 'prewarm')
self.assertEqual(output['cols'], cols)
def test_prewarming(self):
from kitty.prewarm import fork_prewarm_process