Get prewarm working

Needed to wait in the control process for child to set its controlling
terminal so that closing the slave fd in kitty is safe.
This commit is contained in:
Kovid Goyal 2022-06-08 11:33:18 +05:30
parent 0c870c5fcd
commit 116128ebb5
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 83 additions and 50 deletions

View File

@ -194,11 +194,24 @@ class ProcessDesc(TypedDict):
cmdline: Optional[Sequence[str]]
def is_prewarmable(argv: Sequence[str]) -> bool:
if len(argv) < 3 or os.path.basename(argv[0]) != 'kitty':
return False
if argv[1][:1] not in '@+':
return False
if argv[1][0] == '@':
return True
if argv[1] == '+':
return argv[2] != 'open'
return argv[1] != '+open'
class Child:
child_fd: Optional[int] = None
pid: Optional[int] = None
forked = False
is_prewarmed = False
def __init__(
self,
@ -262,14 +275,16 @@ class Child:
self.forked = True
master, slave = openpty()
stdin, self.stdin = self.stdin, None
ready_read_fd, ready_write_fd = os.pipe()
remove_cloexec(ready_read_fd)
if stdin is not None:
stdin_read_fd, stdin_write_fd = os.pipe()
remove_cloexec(stdin_read_fd)
else:
stdin_read_fd = stdin_write_fd = -1
env = tuple(f'{k}={v}' for k, v in self.final_env().items())
self.is_prewarmed = is_prewarmable(self.argv)
if not self.is_prewarmed:
ready_read_fd, ready_write_fd = os.pipe()
remove_cloexec(ready_read_fd)
if stdin is not None:
stdin_read_fd, stdin_write_fd = os.pipe()
remove_cloexec(stdin_read_fd)
else:
stdin_read_fd = stdin_write_fd = -1
env = tuple(f'{k}={v}' for k, v in self.final_env().items())
argv = list(self.argv)
exe = argv[0]
if is_macos and exe == shell_path:
@ -290,24 +305,33 @@ class Child:
argv[0] = (f'-{exe.split("/")[-1]}')
self.final_exe = which(exe) or exe
self.final_argv0 = argv[0]
pid = fast_data_types.spawn(
self.final_exe, self.cwd, tuple(argv), env, master, slave,
stdin_read_fd, stdin_write_fd, ready_read_fd, ready_write_fd, tuple(handled_signals))
if self.is_prewarmed:
fe = self.final_env()
self.prewarmed_child = fast_data_types.get_boss().prewarm(slave, self.argv, self.cwd, fe, stdin)
pid = self.prewarmed_child.child_process_pid
else:
pid = fast_data_types.spawn(
self.final_exe, self.cwd, tuple(argv), env, master, slave, stdin_read_fd, stdin_write_fd,
ready_read_fd, ready_write_fd, tuple(handled_signals))
os.close(slave)
self.pid = pid
self.child_fd = master
if stdin is not None:
os.close(stdin_read_fd)
fast_data_types.thread_write(stdin_write_fd, stdin)
os.close(ready_read_fd)
self.terminal_ready_fd = ready_write_fd
if not self.is_prewarmed:
if stdin is not None:
os.close(stdin_read_fd)
fast_data_types.thread_write(stdin_write_fd, stdin)
os.close(ready_read_fd)
self.terminal_ready_fd = ready_write_fd
if self.child_fd is not None:
remove_blocking(self.child_fd)
return pid
def mark_terminal_ready(self) -> None:
os.close(self.terminal_ready_fd)
self.terminal_ready_fd = -1
if self.is_prewarmed:
fast_data_types.get_boss().prewarm.mark_child_as_ready(self.prewarmed_child.child_id)
else:
os.close(self.terminal_ready_fd)
self.terminal_ready_fd = -1
def cmdline_of_pid(self, pid: int) -> List[str]:
try:

View File

@ -15,7 +15,6 @@ from typing import (
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast
)
from kitty.child import remove_cloexec
from kitty.constants import kitty_exe
from kitty.entry_points import main as main_entry_point
from kitty.fast_data_types import (
@ -27,9 +26,7 @@ if TYPE_CHECKING:
from _typeshed import ReadableBuffer, WriteableBuffer
hangup_events = select.POLLHUP
error_events = select.POLLERR | select.POLLNVAL
basic_events = hangup_events | error_events
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
class PrewarmProcessFailed(Exception):
@ -98,11 +95,11 @@ class PrewarmProcess:
os.set_blocking(self.write_to_process_fd, False)
os.set_blocking(self.read_from_process_fd, False)
self.poll = select.poll()
self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events)
self.poll.register(self.process.stdout.fileno(), select.POLLIN)
def poll_to_send(self, yes: bool = True) -> None:
if yes:
self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events)
self.poll.register(self.write_to_process_fd, select.POLLOUT)
else:
self.poll.unregister(self.write_to_process_fd)
@ -143,7 +140,7 @@ class PrewarmProcess:
st = time.monotonic()
while time.monotonic() - st < 2:
for (fd, event) in self.poll.poll(0.2):
if event & basic_events:
if event & error_events:
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
if fd == self.read_from_process_fd and event & select.POLLIN:
d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
@ -171,8 +168,8 @@ class PrewarmProcess:
while time.monotonic() - st < timeout and output_buf:
self.poll_to_send(bool(output_buf))
for (fd, event) in self.poll.poll(0.2):
if event & basic_events:
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
if event & error_events:
raise PrewarmProcessFailed('Failed doing I/O with prewarm process: {event}')
if fd == self.write_to_process_fd and event & select.POLLOUT:
n = os.write(self.write_to_process_fd, output_buf)
output_buf = output_buf[n:]
@ -198,7 +195,8 @@ def prewarm() -> None:
reload_kitty_config()
from kittens.runner import all_kitten_names
for kitten in all_kitten_names():
import_module(f'kittens.{kitten}.main')
with suppress(Exception):
import_module(f'kittens.{kitten}.main')
import_module('kitty.complete')
@ -249,12 +247,8 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
cwd = cmd.get('cwd')
if cwd:
try:
with suppress(OSError):
os.chdir(cwd)
except OSError:
with suppress(OSError):
os.chdir('/')
os.setsid()
env = cmd.get('env')
if env is not None:
os.environ.clear()
@ -264,13 +258,8 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
sys.argv = list(argv)
poll = select.poll()
poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP)
poll.poll()
tuple(poll.poll())
os.close(ready_fd)
tty_name = cmd.get('tty_name')
if tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
main_entry_point()
raise SystemExit(0)
@ -285,6 +274,8 @@ def fork(shm_address: str, ready_fd: int) -> int:
pos = shm.tell()
shm.unlink_on_exit = False
r, w = os.pipe()
os.set_inheritable(r, False)
try:
child_pid = os.fork()
except OSError:
@ -293,8 +284,23 @@ def fork(shm_address: str, ready_fd: int) -> int:
pass
if child_pid:
# master process
os.close(w)
try:
poll = select.poll()
poll.register(r, select.POLLIN | select.POLLHUP | select.POLLERR)
tuple(poll.poll())
finally:
os.close(r)
return child_pid
# child process
os.setsid()
tty_name = cmd.get('tty_name')
if tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
os.write(w, b'1')
os.close(w)
if shm.unlink_on_exit:
child_main(cmd, ready_fd)
else:
@ -321,15 +327,16 @@ def main(notify_child_death_fd: int) -> None:
stdout_fd = sys.__stdout__.fileno()
os.set_blocking(stdout_fd, False)
poll = select.poll()
poll.register(stdin_fd, select.POLLIN | basic_events)
poll.register(read_signal_fd, select.POLLIN | basic_events)
poll.register(stdin_fd, select.POLLIN)
poll.register(read_signal_fd, select.POLLIN)
input_buf = output_buf = child_death_buf = b''
child_ready_fds: Dict[int, int] = {}
child_id_map: Dict[int, int] = {}
self_pid = os.getpid()
os.setsid()
def check_event(event: int, err_msg: str) -> None:
if event & hangup_events:
if event & select.POLLHUP:
raise SystemExit(0)
if event & error_events:
raise SystemExit(err_msg)
@ -357,10 +364,10 @@ def main(notify_child_death_fd: int) -> None:
os.write(cfd, b'1')
os.close(cfd)
elif cmd == 'fork':
read_fd, write_fd = safe_pipe(False)
remove_cloexec(read_fd)
r, w = os.pipe()
os.set_inheritable(w, False)
try:
child_pid = fork(payload, read_fd)
child_pid = fork(payload, r)
except Exception as e:
es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode()
@ -368,11 +375,11 @@ def main(notify_child_death_fd: int) -> None:
if os.getpid() == self_pid:
child_id = len(child_id_map) + 1
child_id_map[child_id] = child_pid
child_ready_fds[child_id] = write_fd
child_ready_fds[child_id] = w
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
finally:
if os.getpid() == self_pid:
os.close(read_fd)
os.close(r)
elif cmd == 'echo':
output_buf += f'{payload}\n'.encode()
@ -428,14 +435,16 @@ def main(notify_child_death_fd: int) -> None:
break
if matched_child_id > -1:
del child_id_map[matched_child_id]
child_ready_fds.pop(matched_child_id, None)
ofd = child_ready_fds.pop(matched_child_id, None)
if ofd is not None:
os.close(ofd)
child_death_buf += f'{pid}\n'.encode()
try:
while True:
if output_buf:
poll.register(stdout_fd, select.POLLOUT | basic_events)
poll.register(stdout_fd, select.POLLOUT)
if child_death_buf:
poll.register(notify_child_death_fd, select.POLLOUT | basic_events)
poll.register(notify_child_death_fd, select.POLLOUT)
for (q, event) in poll.poll():
if q == stdin_fd:
handle_input(event)