Get prewarm working

Needed to wait in the control process for child to set its controlling
terminal so that closing the slave fd in kitty is safe.
This commit is contained in:
Kovid Goyal 2022-06-08 11:33:18 +05:30
parent 0c870c5fcd
commit 116128ebb5
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 83 additions and 50 deletions

View File

@ -194,11 +194,24 @@ class ProcessDesc(TypedDict):
cmdline: Optional[Sequence[str]] cmdline: Optional[Sequence[str]]
def is_prewarmable(argv: Sequence[str]) -> bool:
if len(argv) < 3 or os.path.basename(argv[0]) != 'kitty':
return False
if argv[1][:1] not in '@+':
return False
if argv[1][0] == '@':
return True
if argv[1] == '+':
return argv[2] != 'open'
return argv[1] != '+open'
class Child: class Child:
child_fd: Optional[int] = None child_fd: Optional[int] = None
pid: Optional[int] = None pid: Optional[int] = None
forked = False forked = False
is_prewarmed = False
def __init__( def __init__(
self, self,
@ -262,6 +275,8 @@ class Child:
self.forked = True self.forked = True
master, slave = openpty() master, slave = openpty()
stdin, self.stdin = self.stdin, None stdin, self.stdin = self.stdin, None
self.is_prewarmed = is_prewarmable(self.argv)
if not self.is_prewarmed:
ready_read_fd, ready_write_fd = os.pipe() ready_read_fd, ready_write_fd = os.pipe()
remove_cloexec(ready_read_fd) remove_cloexec(ready_read_fd)
if stdin is not None: if stdin is not None:
@ -290,12 +305,18 @@ class Child:
argv[0] = (f'-{exe.split("/")[-1]}') argv[0] = (f'-{exe.split("/")[-1]}')
self.final_exe = which(exe) or exe self.final_exe = which(exe) or exe
self.final_argv0 = argv[0] self.final_argv0 = argv[0]
if self.is_prewarmed:
fe = self.final_env()
self.prewarmed_child = fast_data_types.get_boss().prewarm(slave, self.argv, self.cwd, fe, stdin)
pid = self.prewarmed_child.child_process_pid
else:
pid = fast_data_types.spawn( pid = fast_data_types.spawn(
self.final_exe, self.cwd, tuple(argv), env, master, slave, self.final_exe, self.cwd, tuple(argv), env, master, slave, stdin_read_fd, stdin_write_fd,
stdin_read_fd, stdin_write_fd, ready_read_fd, ready_write_fd, tuple(handled_signals)) ready_read_fd, ready_write_fd, tuple(handled_signals))
os.close(slave) os.close(slave)
self.pid = pid self.pid = pid
self.child_fd = master self.child_fd = master
if not self.is_prewarmed:
if stdin is not None: if stdin is not None:
os.close(stdin_read_fd) os.close(stdin_read_fd)
fast_data_types.thread_write(stdin_write_fd, stdin) fast_data_types.thread_write(stdin_write_fd, stdin)
@ -306,6 +327,9 @@ class Child:
return pid return pid
def mark_terminal_ready(self) -> None: def mark_terminal_ready(self) -> None:
if self.is_prewarmed:
fast_data_types.get_boss().prewarm.mark_child_as_ready(self.prewarmed_child.child_id)
else:
os.close(self.terminal_ready_fd) os.close(self.terminal_ready_fd)
self.terminal_ready_fd = -1 self.terminal_ready_fd = -1

View File

@ -15,7 +15,6 @@ from typing import (
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast
) )
from kitty.child import remove_cloexec
from kitty.constants import kitty_exe from kitty.constants import kitty_exe
from kitty.entry_points import main as main_entry_point from kitty.entry_points import main as main_entry_point
from kitty.fast_data_types import ( from kitty.fast_data_types import (
@ -27,9 +26,7 @@ if TYPE_CHECKING:
from _typeshed import ReadableBuffer, WriteableBuffer from _typeshed import ReadableBuffer, WriteableBuffer
hangup_events = select.POLLHUP error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
error_events = select.POLLERR | select.POLLNVAL
basic_events = hangup_events | error_events
class PrewarmProcessFailed(Exception): class PrewarmProcessFailed(Exception):
@ -98,11 +95,11 @@ class PrewarmProcess:
os.set_blocking(self.write_to_process_fd, False) os.set_blocking(self.write_to_process_fd, False)
os.set_blocking(self.read_from_process_fd, False) os.set_blocking(self.read_from_process_fd, False)
self.poll = select.poll() self.poll = select.poll()
self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events) self.poll.register(self.process.stdout.fileno(), select.POLLIN)
def poll_to_send(self, yes: bool = True) -> None: def poll_to_send(self, yes: bool = True) -> None:
if yes: if yes:
self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events) self.poll.register(self.write_to_process_fd, select.POLLOUT)
else: else:
self.poll.unregister(self.write_to_process_fd) self.poll.unregister(self.write_to_process_fd)
@ -143,7 +140,7 @@ class PrewarmProcess:
st = time.monotonic() st = time.monotonic()
while time.monotonic() - st < 2: while time.monotonic() - st < 2:
for (fd, event) in self.poll.poll(0.2): for (fd, event) in self.poll.poll(0.2):
if event & basic_events: if event & error_events:
raise PrewarmProcessFailed('Failed doing I/O with prewarm process') raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
if fd == self.read_from_process_fd and event & select.POLLIN: if fd == self.read_from_process_fd and event & select.POLLIN:
d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE) d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
@ -171,8 +168,8 @@ class PrewarmProcess:
while time.monotonic() - st < timeout and output_buf: while time.monotonic() - st < timeout and output_buf:
self.poll_to_send(bool(output_buf)) self.poll_to_send(bool(output_buf))
for (fd, event) in self.poll.poll(0.2): for (fd, event) in self.poll.poll(0.2):
if event & basic_events: if event & error_events:
raise PrewarmProcessFailed('Failed doing I/O with prewarm process') raise PrewarmProcessFailed('Failed doing I/O with prewarm process: {event}')
if fd == self.write_to_process_fd and event & select.POLLOUT: if fd == self.write_to_process_fd and event & select.POLLOUT:
n = os.write(self.write_to_process_fd, output_buf) n = os.write(self.write_to_process_fd, output_buf)
output_buf = output_buf[n:] output_buf = output_buf[n:]
@ -198,6 +195,7 @@ def prewarm() -> None:
reload_kitty_config() reload_kitty_config()
from kittens.runner import all_kitten_names from kittens.runner import all_kitten_names
for kitten in all_kitten_names(): for kitten in all_kitten_names():
with suppress(Exception):
import_module(f'kittens.{kitten}.main') import_module(f'kittens.{kitten}.main')
import_module('kitty.complete') import_module('kitty.complete')
@ -249,12 +247,8 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn: def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
cwd = cmd.get('cwd') cwd = cmd.get('cwd')
if cwd: if cwd:
try:
os.chdir(cwd)
except OSError:
with suppress(OSError): with suppress(OSError):
os.chdir('/') os.chdir(cwd)
os.setsid()
env = cmd.get('env') env = cmd.get('env')
if env is not None: if env is not None:
os.environ.clear() os.environ.clear()
@ -264,13 +258,8 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
sys.argv = list(argv) sys.argv = list(argv)
poll = select.poll() poll = select.poll()
poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP) poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP)
poll.poll() tuple(poll.poll())
os.close(ready_fd) os.close(ready_fd)
tty_name = cmd.get('tty_name')
if tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
main_entry_point() main_entry_point()
raise SystemExit(0) raise SystemExit(0)
@ -285,6 +274,8 @@ def fork(shm_address: str, ready_fd: int) -> int:
pos = shm.tell() pos = shm.tell()
shm.unlink_on_exit = False shm.unlink_on_exit = False
r, w = os.pipe()
os.set_inheritable(r, False)
try: try:
child_pid = os.fork() child_pid = os.fork()
except OSError: except OSError:
@ -293,8 +284,23 @@ def fork(shm_address: str, ready_fd: int) -> int:
pass pass
if child_pid: if child_pid:
# master process # master process
os.close(w)
try:
poll = select.poll()
poll.register(r, select.POLLIN | select.POLLHUP | select.POLLERR)
tuple(poll.poll())
finally:
os.close(r)
return child_pid return child_pid
# child process # child process
os.setsid()
tty_name = cmd.get('tty_name')
if tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
os.write(w, b'1')
os.close(w)
if shm.unlink_on_exit: if shm.unlink_on_exit:
child_main(cmd, ready_fd) child_main(cmd, ready_fd)
else: else:
@ -321,15 +327,16 @@ def main(notify_child_death_fd: int) -> None:
stdout_fd = sys.__stdout__.fileno() stdout_fd = sys.__stdout__.fileno()
os.set_blocking(stdout_fd, False) os.set_blocking(stdout_fd, False)
poll = select.poll() poll = select.poll()
poll.register(stdin_fd, select.POLLIN | basic_events) poll.register(stdin_fd, select.POLLIN)
poll.register(read_signal_fd, select.POLLIN | basic_events) poll.register(read_signal_fd, select.POLLIN)
input_buf = output_buf = child_death_buf = b'' input_buf = output_buf = child_death_buf = b''
child_ready_fds: Dict[int, int] = {} child_ready_fds: Dict[int, int] = {}
child_id_map: Dict[int, int] = {} child_id_map: Dict[int, int] = {}
self_pid = os.getpid() self_pid = os.getpid()
os.setsid()
def check_event(event: int, err_msg: str) -> None: def check_event(event: int, err_msg: str) -> None:
if event & hangup_events: if event & select.POLLHUP:
raise SystemExit(0) raise SystemExit(0)
if event & error_events: if event & error_events:
raise SystemExit(err_msg) raise SystemExit(err_msg)
@ -357,10 +364,10 @@ def main(notify_child_death_fd: int) -> None:
os.write(cfd, b'1') os.write(cfd, b'1')
os.close(cfd) os.close(cfd)
elif cmd == 'fork': elif cmd == 'fork':
read_fd, write_fd = safe_pipe(False) r, w = os.pipe()
remove_cloexec(read_fd) os.set_inheritable(w, False)
try: try:
child_pid = fork(payload, read_fd) child_pid = fork(payload, r)
except Exception as e: except Exception as e:
es = str(e).replace('\n', ' ') es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode() output_buf += f'ERR:{es}\n'.encode()
@ -368,11 +375,11 @@ def main(notify_child_death_fd: int) -> None:
if os.getpid() == self_pid: if os.getpid() == self_pid:
child_id = len(child_id_map) + 1 child_id = len(child_id_map) + 1
child_id_map[child_id] = child_pid child_id_map[child_id] = child_pid
child_ready_fds[child_id] = write_fd child_ready_fds[child_id] = w
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode() output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
finally: finally:
if os.getpid() == self_pid: if os.getpid() == self_pid:
os.close(read_fd) os.close(r)
elif cmd == 'echo': elif cmd == 'echo':
output_buf += f'{payload}\n'.encode() output_buf += f'{payload}\n'.encode()
@ -428,14 +435,16 @@ def main(notify_child_death_fd: int) -> None:
break break
if matched_child_id > -1: if matched_child_id > -1:
del child_id_map[matched_child_id] del child_id_map[matched_child_id]
child_ready_fds.pop(matched_child_id, None) ofd = child_ready_fds.pop(matched_child_id, None)
if ofd is not None:
os.close(ofd)
child_death_buf += f'{pid}\n'.encode() child_death_buf += f'{pid}\n'.encode()
try: try:
while True: while True:
if output_buf: if output_buf:
poll.register(stdout_fd, select.POLLOUT | basic_events) poll.register(stdout_fd, select.POLLOUT)
if child_death_buf: if child_death_buf:
poll.register(notify_child_death_fd, select.POLLOUT | basic_events) poll.register(notify_child_death_fd, select.POLLOUT)
for (q, event) in poll.poll(): for (q, event) in poll.poll():
if q == stdin_fd: if q == stdin_fd:
handle_input(event) handle_input(event)