Get prewarm working
Needed to wait in the control process for child to set its controlling terminal so that closing the slave fd in kitty is safe.
This commit is contained in:
parent
0c870c5fcd
commit
116128ebb5
@ -194,11 +194,24 @@ class ProcessDesc(TypedDict):
|
||||
cmdline: Optional[Sequence[str]]
|
||||
|
||||
|
||||
def is_prewarmable(argv: Sequence[str]) -> bool:
|
||||
if len(argv) < 3 or os.path.basename(argv[0]) != 'kitty':
|
||||
return False
|
||||
if argv[1][:1] not in '@+':
|
||||
return False
|
||||
if argv[1][0] == '@':
|
||||
return True
|
||||
if argv[1] == '+':
|
||||
return argv[2] != 'open'
|
||||
return argv[1] != '+open'
|
||||
|
||||
|
||||
class Child:
|
||||
|
||||
child_fd: Optional[int] = None
|
||||
pid: Optional[int] = None
|
||||
forked = False
|
||||
is_prewarmed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -262,14 +275,16 @@ class Child:
|
||||
self.forked = True
|
||||
master, slave = openpty()
|
||||
stdin, self.stdin = self.stdin, None
|
||||
ready_read_fd, ready_write_fd = os.pipe()
|
||||
remove_cloexec(ready_read_fd)
|
||||
if stdin is not None:
|
||||
stdin_read_fd, stdin_write_fd = os.pipe()
|
||||
remove_cloexec(stdin_read_fd)
|
||||
else:
|
||||
stdin_read_fd = stdin_write_fd = -1
|
||||
env = tuple(f'{k}={v}' for k, v in self.final_env().items())
|
||||
self.is_prewarmed = is_prewarmable(self.argv)
|
||||
if not self.is_prewarmed:
|
||||
ready_read_fd, ready_write_fd = os.pipe()
|
||||
remove_cloexec(ready_read_fd)
|
||||
if stdin is not None:
|
||||
stdin_read_fd, stdin_write_fd = os.pipe()
|
||||
remove_cloexec(stdin_read_fd)
|
||||
else:
|
||||
stdin_read_fd = stdin_write_fd = -1
|
||||
env = tuple(f'{k}={v}' for k, v in self.final_env().items())
|
||||
argv = list(self.argv)
|
||||
exe = argv[0]
|
||||
if is_macos and exe == shell_path:
|
||||
@ -290,24 +305,33 @@ class Child:
|
||||
argv[0] = (f'-{exe.split("/")[-1]}')
|
||||
self.final_exe = which(exe) or exe
|
||||
self.final_argv0 = argv[0]
|
||||
pid = fast_data_types.spawn(
|
||||
self.final_exe, self.cwd, tuple(argv), env, master, slave,
|
||||
stdin_read_fd, stdin_write_fd, ready_read_fd, ready_write_fd, tuple(handled_signals))
|
||||
if self.is_prewarmed:
|
||||
fe = self.final_env()
|
||||
self.prewarmed_child = fast_data_types.get_boss().prewarm(slave, self.argv, self.cwd, fe, stdin)
|
||||
pid = self.prewarmed_child.child_process_pid
|
||||
else:
|
||||
pid = fast_data_types.spawn(
|
||||
self.final_exe, self.cwd, tuple(argv), env, master, slave, stdin_read_fd, stdin_write_fd,
|
||||
ready_read_fd, ready_write_fd, tuple(handled_signals))
|
||||
os.close(slave)
|
||||
self.pid = pid
|
||||
self.child_fd = master
|
||||
if stdin is not None:
|
||||
os.close(stdin_read_fd)
|
||||
fast_data_types.thread_write(stdin_write_fd, stdin)
|
||||
os.close(ready_read_fd)
|
||||
self.terminal_ready_fd = ready_write_fd
|
||||
if not self.is_prewarmed:
|
||||
if stdin is not None:
|
||||
os.close(stdin_read_fd)
|
||||
fast_data_types.thread_write(stdin_write_fd, stdin)
|
||||
os.close(ready_read_fd)
|
||||
self.terminal_ready_fd = ready_write_fd
|
||||
if self.child_fd is not None:
|
||||
remove_blocking(self.child_fd)
|
||||
return pid
|
||||
|
||||
def mark_terminal_ready(self) -> None:
|
||||
os.close(self.terminal_ready_fd)
|
||||
self.terminal_ready_fd = -1
|
||||
if self.is_prewarmed:
|
||||
fast_data_types.get_boss().prewarm.mark_child_as_ready(self.prewarmed_child.child_id)
|
||||
else:
|
||||
os.close(self.terminal_ready_fd)
|
||||
self.terminal_ready_fd = -1
|
||||
|
||||
def cmdline_of_pid(self, pid: int) -> List[str]:
|
||||
try:
|
||||
|
||||
@ -15,7 +15,6 @@ from typing import (
|
||||
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast
|
||||
)
|
||||
|
||||
from kitty.child import remove_cloexec
|
||||
from kitty.constants import kitty_exe
|
||||
from kitty.entry_points import main as main_entry_point
|
||||
from kitty.fast_data_types import (
|
||||
@ -27,9 +26,7 @@ if TYPE_CHECKING:
|
||||
from _typeshed import ReadableBuffer, WriteableBuffer
|
||||
|
||||
|
||||
hangup_events = select.POLLHUP
|
||||
error_events = select.POLLERR | select.POLLNVAL
|
||||
basic_events = hangup_events | error_events
|
||||
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
|
||||
|
||||
|
||||
class PrewarmProcessFailed(Exception):
|
||||
@ -98,11 +95,11 @@ class PrewarmProcess:
|
||||
os.set_blocking(self.write_to_process_fd, False)
|
||||
os.set_blocking(self.read_from_process_fd, False)
|
||||
self.poll = select.poll()
|
||||
self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events)
|
||||
self.poll.register(self.process.stdout.fileno(), select.POLLIN)
|
||||
|
||||
def poll_to_send(self, yes: bool = True) -> None:
|
||||
if yes:
|
||||
self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events)
|
||||
self.poll.register(self.write_to_process_fd, select.POLLOUT)
|
||||
else:
|
||||
self.poll.unregister(self.write_to_process_fd)
|
||||
|
||||
@ -143,7 +140,7 @@ class PrewarmProcess:
|
||||
st = time.monotonic()
|
||||
while time.monotonic() - st < 2:
|
||||
for (fd, event) in self.poll.poll(0.2):
|
||||
if event & basic_events:
|
||||
if event & error_events:
|
||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
||||
if fd == self.read_from_process_fd and event & select.POLLIN:
|
||||
d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
|
||||
@ -171,8 +168,8 @@ class PrewarmProcess:
|
||||
while time.monotonic() - st < timeout and output_buf:
|
||||
self.poll_to_send(bool(output_buf))
|
||||
for (fd, event) in self.poll.poll(0.2):
|
||||
if event & basic_events:
|
||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
||||
if event & error_events:
|
||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process: {event}')
|
||||
if fd == self.write_to_process_fd and event & select.POLLOUT:
|
||||
n = os.write(self.write_to_process_fd, output_buf)
|
||||
output_buf = output_buf[n:]
|
||||
@ -198,7 +195,8 @@ def prewarm() -> None:
|
||||
reload_kitty_config()
|
||||
from kittens.runner import all_kitten_names
|
||||
for kitten in all_kitten_names():
|
||||
import_module(f'kittens.{kitten}.main')
|
||||
with suppress(Exception):
|
||||
import_module(f'kittens.{kitten}.main')
|
||||
import_module('kitty.complete')
|
||||
|
||||
|
||||
@ -249,12 +247,8 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
|
||||
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||
cwd = cmd.get('cwd')
|
||||
if cwd:
|
||||
try:
|
||||
with suppress(OSError):
|
||||
os.chdir(cwd)
|
||||
except OSError:
|
||||
with suppress(OSError):
|
||||
os.chdir('/')
|
||||
os.setsid()
|
||||
env = cmd.get('env')
|
||||
if env is not None:
|
||||
os.environ.clear()
|
||||
@ -264,13 +258,8 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||
sys.argv = list(argv)
|
||||
poll = select.poll()
|
||||
poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP)
|
||||
poll.poll()
|
||||
tuple(poll.poll())
|
||||
os.close(ready_fd)
|
||||
tty_name = cmd.get('tty_name')
|
||||
if tty_name:
|
||||
sys.__stdout__.flush()
|
||||
sys.__stderr__.flush()
|
||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||
main_entry_point()
|
||||
raise SystemExit(0)
|
||||
|
||||
@ -285,6 +274,8 @@ def fork(shm_address: str, ready_fd: int) -> int:
|
||||
pos = shm.tell()
|
||||
shm.unlink_on_exit = False
|
||||
|
||||
r, w = os.pipe()
|
||||
os.set_inheritable(r, False)
|
||||
try:
|
||||
child_pid = os.fork()
|
||||
except OSError:
|
||||
@ -293,8 +284,23 @@ def fork(shm_address: str, ready_fd: int) -> int:
|
||||
pass
|
||||
if child_pid:
|
||||
# master process
|
||||
os.close(w)
|
||||
try:
|
||||
poll = select.poll()
|
||||
poll.register(r, select.POLLIN | select.POLLHUP | select.POLLERR)
|
||||
tuple(poll.poll())
|
||||
finally:
|
||||
os.close(r)
|
||||
return child_pid
|
||||
# child process
|
||||
os.setsid()
|
||||
tty_name = cmd.get('tty_name')
|
||||
if tty_name:
|
||||
sys.__stdout__.flush()
|
||||
sys.__stderr__.flush()
|
||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||
os.write(w, b'1')
|
||||
os.close(w)
|
||||
if shm.unlink_on_exit:
|
||||
child_main(cmd, ready_fd)
|
||||
else:
|
||||
@ -321,15 +327,16 @@ def main(notify_child_death_fd: int) -> None:
|
||||
stdout_fd = sys.__stdout__.fileno()
|
||||
os.set_blocking(stdout_fd, False)
|
||||
poll = select.poll()
|
||||
poll.register(stdin_fd, select.POLLIN | basic_events)
|
||||
poll.register(read_signal_fd, select.POLLIN | basic_events)
|
||||
poll.register(stdin_fd, select.POLLIN)
|
||||
poll.register(read_signal_fd, select.POLLIN)
|
||||
input_buf = output_buf = child_death_buf = b''
|
||||
child_ready_fds: Dict[int, int] = {}
|
||||
child_id_map: Dict[int, int] = {}
|
||||
self_pid = os.getpid()
|
||||
os.setsid()
|
||||
|
||||
def check_event(event: int, err_msg: str) -> None:
|
||||
if event & hangup_events:
|
||||
if event & select.POLLHUP:
|
||||
raise SystemExit(0)
|
||||
if event & error_events:
|
||||
raise SystemExit(err_msg)
|
||||
@ -357,10 +364,10 @@ def main(notify_child_death_fd: int) -> None:
|
||||
os.write(cfd, b'1')
|
||||
os.close(cfd)
|
||||
elif cmd == 'fork':
|
||||
read_fd, write_fd = safe_pipe(False)
|
||||
remove_cloexec(read_fd)
|
||||
r, w = os.pipe()
|
||||
os.set_inheritable(w, False)
|
||||
try:
|
||||
child_pid = fork(payload, read_fd)
|
||||
child_pid = fork(payload, r)
|
||||
except Exception as e:
|
||||
es = str(e).replace('\n', ' ')
|
||||
output_buf += f'ERR:{es}\n'.encode()
|
||||
@ -368,11 +375,11 @@ def main(notify_child_death_fd: int) -> None:
|
||||
if os.getpid() == self_pid:
|
||||
child_id = len(child_id_map) + 1
|
||||
child_id_map[child_id] = child_pid
|
||||
child_ready_fds[child_id] = write_fd
|
||||
child_ready_fds[child_id] = w
|
||||
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
||||
finally:
|
||||
if os.getpid() == self_pid:
|
||||
os.close(read_fd)
|
||||
os.close(r)
|
||||
elif cmd == 'echo':
|
||||
output_buf += f'{payload}\n'.encode()
|
||||
|
||||
@ -428,14 +435,16 @@ def main(notify_child_death_fd: int) -> None:
|
||||
break
|
||||
if matched_child_id > -1:
|
||||
del child_id_map[matched_child_id]
|
||||
child_ready_fds.pop(matched_child_id, None)
|
||||
ofd = child_ready_fds.pop(matched_child_id, None)
|
||||
if ofd is not None:
|
||||
os.close(ofd)
|
||||
child_death_buf += f'{pid}\n'.encode()
|
||||
try:
|
||||
while True:
|
||||
if output_buf:
|
||||
poll.register(stdout_fd, select.POLLOUT | basic_events)
|
||||
poll.register(stdout_fd, select.POLLOUT)
|
||||
if child_death_buf:
|
||||
poll.register(notify_child_death_fd, select.POLLOUT | basic_events)
|
||||
poll.register(notify_child_death_fd, select.POLLOUT)
|
||||
for (q, event) in poll.poll():
|
||||
if q == stdin_fd:
|
||||
handle_input(event)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user