Get prewarm working
Needed to wait in the control process for child to set its controlling terminal so that closing the slave fd in kitty is safe.
This commit is contained in:
parent
0c870c5fcd
commit
116128ebb5
@ -194,11 +194,24 @@ class ProcessDesc(TypedDict):
|
|||||||
cmdline: Optional[Sequence[str]]
|
cmdline: Optional[Sequence[str]]
|
||||||
|
|
||||||
|
|
||||||
|
def is_prewarmable(argv: Sequence[str]) -> bool:
|
||||||
|
if len(argv) < 3 or os.path.basename(argv[0]) != 'kitty':
|
||||||
|
return False
|
||||||
|
if argv[1][:1] not in '@+':
|
||||||
|
return False
|
||||||
|
if argv[1][0] == '@':
|
||||||
|
return True
|
||||||
|
if argv[1] == '+':
|
||||||
|
return argv[2] != 'open'
|
||||||
|
return argv[1] != '+open'
|
||||||
|
|
||||||
|
|
||||||
class Child:
|
class Child:
|
||||||
|
|
||||||
child_fd: Optional[int] = None
|
child_fd: Optional[int] = None
|
||||||
pid: Optional[int] = None
|
pid: Optional[int] = None
|
||||||
forked = False
|
forked = False
|
||||||
|
is_prewarmed = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -262,6 +275,8 @@ class Child:
|
|||||||
self.forked = True
|
self.forked = True
|
||||||
master, slave = openpty()
|
master, slave = openpty()
|
||||||
stdin, self.stdin = self.stdin, None
|
stdin, self.stdin = self.stdin, None
|
||||||
|
self.is_prewarmed = is_prewarmable(self.argv)
|
||||||
|
if not self.is_prewarmed:
|
||||||
ready_read_fd, ready_write_fd = os.pipe()
|
ready_read_fd, ready_write_fd = os.pipe()
|
||||||
remove_cloexec(ready_read_fd)
|
remove_cloexec(ready_read_fd)
|
||||||
if stdin is not None:
|
if stdin is not None:
|
||||||
@ -290,12 +305,18 @@ class Child:
|
|||||||
argv[0] = (f'-{exe.split("/")[-1]}')
|
argv[0] = (f'-{exe.split("/")[-1]}')
|
||||||
self.final_exe = which(exe) or exe
|
self.final_exe = which(exe) or exe
|
||||||
self.final_argv0 = argv[0]
|
self.final_argv0 = argv[0]
|
||||||
|
if self.is_prewarmed:
|
||||||
|
fe = self.final_env()
|
||||||
|
self.prewarmed_child = fast_data_types.get_boss().prewarm(slave, self.argv, self.cwd, fe, stdin)
|
||||||
|
pid = self.prewarmed_child.child_process_pid
|
||||||
|
else:
|
||||||
pid = fast_data_types.spawn(
|
pid = fast_data_types.spawn(
|
||||||
self.final_exe, self.cwd, tuple(argv), env, master, slave,
|
self.final_exe, self.cwd, tuple(argv), env, master, slave, stdin_read_fd, stdin_write_fd,
|
||||||
stdin_read_fd, stdin_write_fd, ready_read_fd, ready_write_fd, tuple(handled_signals))
|
ready_read_fd, ready_write_fd, tuple(handled_signals))
|
||||||
os.close(slave)
|
os.close(slave)
|
||||||
self.pid = pid
|
self.pid = pid
|
||||||
self.child_fd = master
|
self.child_fd = master
|
||||||
|
if not self.is_prewarmed:
|
||||||
if stdin is not None:
|
if stdin is not None:
|
||||||
os.close(stdin_read_fd)
|
os.close(stdin_read_fd)
|
||||||
fast_data_types.thread_write(stdin_write_fd, stdin)
|
fast_data_types.thread_write(stdin_write_fd, stdin)
|
||||||
@ -306,6 +327,9 @@ class Child:
|
|||||||
return pid
|
return pid
|
||||||
|
|
||||||
def mark_terminal_ready(self) -> None:
|
def mark_terminal_ready(self) -> None:
|
||||||
|
if self.is_prewarmed:
|
||||||
|
fast_data_types.get_boss().prewarm.mark_child_as_ready(self.prewarmed_child.child_id)
|
||||||
|
else:
|
||||||
os.close(self.terminal_ready_fd)
|
os.close(self.terminal_ready_fd)
|
||||||
self.terminal_ready_fd = -1
|
self.terminal_ready_fd = -1
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from typing import (
|
|||||||
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast
|
IO, TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union, cast
|
||||||
)
|
)
|
||||||
|
|
||||||
from kitty.child import remove_cloexec
|
|
||||||
from kitty.constants import kitty_exe
|
from kitty.constants import kitty_exe
|
||||||
from kitty.entry_points import main as main_entry_point
|
from kitty.entry_points import main as main_entry_point
|
||||||
from kitty.fast_data_types import (
|
from kitty.fast_data_types import (
|
||||||
@ -27,9 +26,7 @@ if TYPE_CHECKING:
|
|||||||
from _typeshed import ReadableBuffer, WriteableBuffer
|
from _typeshed import ReadableBuffer, WriteableBuffer
|
||||||
|
|
||||||
|
|
||||||
hangup_events = select.POLLHUP
|
error_events = select.POLLERR | select.POLLNVAL | select.POLLHUP
|
||||||
error_events = select.POLLERR | select.POLLNVAL
|
|
||||||
basic_events = hangup_events | error_events
|
|
||||||
|
|
||||||
|
|
||||||
class PrewarmProcessFailed(Exception):
|
class PrewarmProcessFailed(Exception):
|
||||||
@ -98,11 +95,11 @@ class PrewarmProcess:
|
|||||||
os.set_blocking(self.write_to_process_fd, False)
|
os.set_blocking(self.write_to_process_fd, False)
|
||||||
os.set_blocking(self.read_from_process_fd, False)
|
os.set_blocking(self.read_from_process_fd, False)
|
||||||
self.poll = select.poll()
|
self.poll = select.poll()
|
||||||
self.poll.register(self.process.stdout.fileno(), select.POLLIN | basic_events)
|
self.poll.register(self.process.stdout.fileno(), select.POLLIN)
|
||||||
|
|
||||||
def poll_to_send(self, yes: bool = True) -> None:
|
def poll_to_send(self, yes: bool = True) -> None:
|
||||||
if yes:
|
if yes:
|
||||||
self.poll.register(self.write_to_process_fd, select.POLLOUT | basic_events)
|
self.poll.register(self.write_to_process_fd, select.POLLOUT)
|
||||||
else:
|
else:
|
||||||
self.poll.unregister(self.write_to_process_fd)
|
self.poll.unregister(self.write_to_process_fd)
|
||||||
|
|
||||||
@ -143,7 +140,7 @@ class PrewarmProcess:
|
|||||||
st = time.monotonic()
|
st = time.monotonic()
|
||||||
while time.monotonic() - st < 2:
|
while time.monotonic() - st < 2:
|
||||||
for (fd, event) in self.poll.poll(0.2):
|
for (fd, event) in self.poll.poll(0.2):
|
||||||
if event & basic_events:
|
if event & error_events:
|
||||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
||||||
if fd == self.read_from_process_fd and event & select.POLLIN:
|
if fd == self.read_from_process_fd and event & select.POLLIN:
|
||||||
d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
|
d = os.read(self.read_from_process_fd, io.DEFAULT_BUFFER_SIZE)
|
||||||
@ -171,8 +168,8 @@ class PrewarmProcess:
|
|||||||
while time.monotonic() - st < timeout and output_buf:
|
while time.monotonic() - st < timeout and output_buf:
|
||||||
self.poll_to_send(bool(output_buf))
|
self.poll_to_send(bool(output_buf))
|
||||||
for (fd, event) in self.poll.poll(0.2):
|
for (fd, event) in self.poll.poll(0.2):
|
||||||
if event & basic_events:
|
if event & error_events:
|
||||||
raise PrewarmProcessFailed('Failed doing I/O with prewarm process')
|
raise PrewarmProcessFailed('Failed doing I/O with prewarm process: {event}')
|
||||||
if fd == self.write_to_process_fd and event & select.POLLOUT:
|
if fd == self.write_to_process_fd and event & select.POLLOUT:
|
||||||
n = os.write(self.write_to_process_fd, output_buf)
|
n = os.write(self.write_to_process_fd, output_buf)
|
||||||
output_buf = output_buf[n:]
|
output_buf = output_buf[n:]
|
||||||
@ -198,6 +195,7 @@ def prewarm() -> None:
|
|||||||
reload_kitty_config()
|
reload_kitty_config()
|
||||||
from kittens.runner import all_kitten_names
|
from kittens.runner import all_kitten_names
|
||||||
for kitten in all_kitten_names():
|
for kitten in all_kitten_names():
|
||||||
|
with suppress(Exception):
|
||||||
import_module(f'kittens.{kitten}.main')
|
import_module(f'kittens.{kitten}.main')
|
||||||
import_module('kitty.complete')
|
import_module('kitty.complete')
|
||||||
|
|
||||||
@ -249,12 +247,8 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
|
|||||||
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||||
cwd = cmd.get('cwd')
|
cwd = cmd.get('cwd')
|
||||||
if cwd:
|
if cwd:
|
||||||
try:
|
|
||||||
os.chdir(cwd)
|
|
||||||
except OSError:
|
|
||||||
with suppress(OSError):
|
with suppress(OSError):
|
||||||
os.chdir('/')
|
os.chdir(cwd)
|
||||||
os.setsid()
|
|
||||||
env = cmd.get('env')
|
env = cmd.get('env')
|
||||||
if env is not None:
|
if env is not None:
|
||||||
os.environ.clear()
|
os.environ.clear()
|
||||||
@ -264,13 +258,8 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
|||||||
sys.argv = list(argv)
|
sys.argv = list(argv)
|
||||||
poll = select.poll()
|
poll = select.poll()
|
||||||
poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP)
|
poll.register(ready_fd, select.POLLIN | select.POLLERR | select.POLLHUP)
|
||||||
poll.poll()
|
tuple(poll.poll())
|
||||||
os.close(ready_fd)
|
os.close(ready_fd)
|
||||||
tty_name = cmd.get('tty_name')
|
|
||||||
if tty_name:
|
|
||||||
sys.__stdout__.flush()
|
|
||||||
sys.__stderr__.flush()
|
|
||||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
|
||||||
main_entry_point()
|
main_entry_point()
|
||||||
raise SystemExit(0)
|
raise SystemExit(0)
|
||||||
|
|
||||||
@ -285,6 +274,8 @@ def fork(shm_address: str, ready_fd: int) -> int:
|
|||||||
pos = shm.tell()
|
pos = shm.tell()
|
||||||
shm.unlink_on_exit = False
|
shm.unlink_on_exit = False
|
||||||
|
|
||||||
|
r, w = os.pipe()
|
||||||
|
os.set_inheritable(r, False)
|
||||||
try:
|
try:
|
||||||
child_pid = os.fork()
|
child_pid = os.fork()
|
||||||
except OSError:
|
except OSError:
|
||||||
@ -293,8 +284,23 @@ def fork(shm_address: str, ready_fd: int) -> int:
|
|||||||
pass
|
pass
|
||||||
if child_pid:
|
if child_pid:
|
||||||
# master process
|
# master process
|
||||||
|
os.close(w)
|
||||||
|
try:
|
||||||
|
poll = select.poll()
|
||||||
|
poll.register(r, select.POLLIN | select.POLLHUP | select.POLLERR)
|
||||||
|
tuple(poll.poll())
|
||||||
|
finally:
|
||||||
|
os.close(r)
|
||||||
return child_pid
|
return child_pid
|
||||||
# child process
|
# child process
|
||||||
|
os.setsid()
|
||||||
|
tty_name = cmd.get('tty_name')
|
||||||
|
if tty_name:
|
||||||
|
sys.__stdout__.flush()
|
||||||
|
sys.__stderr__.flush()
|
||||||
|
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||||
|
os.write(w, b'1')
|
||||||
|
os.close(w)
|
||||||
if shm.unlink_on_exit:
|
if shm.unlink_on_exit:
|
||||||
child_main(cmd, ready_fd)
|
child_main(cmd, ready_fd)
|
||||||
else:
|
else:
|
||||||
@ -321,15 +327,16 @@ def main(notify_child_death_fd: int) -> None:
|
|||||||
stdout_fd = sys.__stdout__.fileno()
|
stdout_fd = sys.__stdout__.fileno()
|
||||||
os.set_blocking(stdout_fd, False)
|
os.set_blocking(stdout_fd, False)
|
||||||
poll = select.poll()
|
poll = select.poll()
|
||||||
poll.register(stdin_fd, select.POLLIN | basic_events)
|
poll.register(stdin_fd, select.POLLIN)
|
||||||
poll.register(read_signal_fd, select.POLLIN | basic_events)
|
poll.register(read_signal_fd, select.POLLIN)
|
||||||
input_buf = output_buf = child_death_buf = b''
|
input_buf = output_buf = child_death_buf = b''
|
||||||
child_ready_fds: Dict[int, int] = {}
|
child_ready_fds: Dict[int, int] = {}
|
||||||
child_id_map: Dict[int, int] = {}
|
child_id_map: Dict[int, int] = {}
|
||||||
self_pid = os.getpid()
|
self_pid = os.getpid()
|
||||||
|
os.setsid()
|
||||||
|
|
||||||
def check_event(event: int, err_msg: str) -> None:
|
def check_event(event: int, err_msg: str) -> None:
|
||||||
if event & hangup_events:
|
if event & select.POLLHUP:
|
||||||
raise SystemExit(0)
|
raise SystemExit(0)
|
||||||
if event & error_events:
|
if event & error_events:
|
||||||
raise SystemExit(err_msg)
|
raise SystemExit(err_msg)
|
||||||
@ -357,10 +364,10 @@ def main(notify_child_death_fd: int) -> None:
|
|||||||
os.write(cfd, b'1')
|
os.write(cfd, b'1')
|
||||||
os.close(cfd)
|
os.close(cfd)
|
||||||
elif cmd == 'fork':
|
elif cmd == 'fork':
|
||||||
read_fd, write_fd = safe_pipe(False)
|
r, w = os.pipe()
|
||||||
remove_cloexec(read_fd)
|
os.set_inheritable(w, False)
|
||||||
try:
|
try:
|
||||||
child_pid = fork(payload, read_fd)
|
child_pid = fork(payload, r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es = str(e).replace('\n', ' ')
|
es = str(e).replace('\n', ' ')
|
||||||
output_buf += f'ERR:{es}\n'.encode()
|
output_buf += f'ERR:{es}\n'.encode()
|
||||||
@ -368,11 +375,11 @@ def main(notify_child_death_fd: int) -> None:
|
|||||||
if os.getpid() == self_pid:
|
if os.getpid() == self_pid:
|
||||||
child_id = len(child_id_map) + 1
|
child_id = len(child_id_map) + 1
|
||||||
child_id_map[child_id] = child_pid
|
child_id_map[child_id] = child_pid
|
||||||
child_ready_fds[child_id] = write_fd
|
child_ready_fds[child_id] = w
|
||||||
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
output_buf += f'CHILD:{child_id}:{child_pid}\n'.encode()
|
||||||
finally:
|
finally:
|
||||||
if os.getpid() == self_pid:
|
if os.getpid() == self_pid:
|
||||||
os.close(read_fd)
|
os.close(r)
|
||||||
elif cmd == 'echo':
|
elif cmd == 'echo':
|
||||||
output_buf += f'{payload}\n'.encode()
|
output_buf += f'{payload}\n'.encode()
|
||||||
|
|
||||||
@ -428,14 +435,16 @@ def main(notify_child_death_fd: int) -> None:
|
|||||||
break
|
break
|
||||||
if matched_child_id > -1:
|
if matched_child_id > -1:
|
||||||
del child_id_map[matched_child_id]
|
del child_id_map[matched_child_id]
|
||||||
child_ready_fds.pop(matched_child_id, None)
|
ofd = child_ready_fds.pop(matched_child_id, None)
|
||||||
|
if ofd is not None:
|
||||||
|
os.close(ofd)
|
||||||
child_death_buf += f'{pid}\n'.encode()
|
child_death_buf += f'{pid}\n'.encode()
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if output_buf:
|
if output_buf:
|
||||||
poll.register(stdout_fd, select.POLLOUT | basic_events)
|
poll.register(stdout_fd, select.POLLOUT)
|
||||||
if child_death_buf:
|
if child_death_buf:
|
||||||
poll.register(notify_child_death_fd, select.POLLOUT | basic_events)
|
poll.register(notify_child_death_fd, select.POLLOUT)
|
||||||
for (q, event) in poll.poll():
|
for (q, event) in poll.poll():
|
||||||
if q == stdin_fd:
|
if q == stdin_fd:
|
||||||
handle_input(event)
|
handle_input(event)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user