Start work on socket based prewarm
This commit is contained in:
parent
7a31c7ff50
commit
b222ab1bf6
@ -186,17 +186,17 @@ spawn(PyObject *self UNUSED, PyObject *args) {
|
||||
static PyObject*
|
||||
establish_controlling_tty(PyObject *self UNUSED, PyObject *args) {
|
||||
const char *ttyname;
|
||||
int stdin_fd, stdout_fd, stderr_fd;
|
||||
if (!PyArg_ParseTuple(args, "siii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL;
|
||||
int stdin_fd = -1, stdout_fd = -1, stderr_fd = -1;
|
||||
if (!PyArg_ParseTuple(args, "s|iii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL;
|
||||
int tfd = safe_open(ttyname, O_RDWR, 0);
|
||||
if (tfd == -1) return PyErr_SetFromErrnoWithFilename(PyExc_OSError, ttyname);
|
||||
#ifdef TIOCSCTTY
|
||||
// On BSD open() does not establish the controlling terminal
|
||||
if (ioctl(tfd, TIOCSCTTY, 0) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
#endif
|
||||
if (dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
if (dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
if (dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
if (stdin_fd > -1 && dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
if (stdout_fd > -1 && dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
if (stderr_fd > -1 && dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
|
||||
safe_close(tfd, __FILE__, __LINE__);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def establish_controlling_tty(ttyname: str, stdin: int, stdout: int, stderr: int) -> None:
|
||||
def establish_controlling_tty(ttyname: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
203
kitty/prewarm.py
203
kitty/prewarm.py
@ -6,6 +6,7 @@ import json
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
@ -27,7 +28,7 @@ from kitty.fast_data_types import (
|
||||
from kitty.options.types import Options
|
||||
from kitty.shm import SharedMemory
|
||||
from kitty.types import SignalInfo
|
||||
from kitty.utils import log_error
|
||||
from kitty.utils import log_error, random_unix_socket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import ReadableBuffer, WriteableBuffer
|
||||
@ -69,6 +70,7 @@ class PrewarmProcess:
|
||||
to_prewarm_stdin: int,
|
||||
from_prewarm_stdout: int,
|
||||
from_prewarm_death_notify: int,
|
||||
unix_socket_name: str,
|
||||
) -> None:
|
||||
self.children: Dict[int, Child] = {}
|
||||
self.worker_pid = prewarm_process_pid
|
||||
@ -77,6 +79,7 @@ class PrewarmProcess:
|
||||
self.read_from_process_fd = from_prewarm_stdout
|
||||
self.poll = select.poll()
|
||||
self.poll.register(self.read_from_process_fd, select.POLLIN)
|
||||
self.unix_socket_name = unix_socket_name
|
||||
|
||||
def take_from_worker_fd(self, create_file: bool = False) -> int:
|
||||
if create_file:
|
||||
@ -256,7 +259,7 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
|
||||
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
|
||||
|
||||
|
||||
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||
def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
|
||||
cwd = cmd.get('cwd')
|
||||
if cwd:
|
||||
with suppress(OSError):
|
||||
@ -268,10 +271,11 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
|
||||
argv = cmd.get('argv')
|
||||
if argv:
|
||||
sys.argv = list(argv)
|
||||
poll = select.poll()
|
||||
poll.register(ready_fd, select.POLLIN)
|
||||
tuple(poll.poll())
|
||||
os.close(ready_fd)
|
||||
if ready_fd > -1:
|
||||
poll = select.poll()
|
||||
poll.register(ready_fd, select.POLLIN)
|
||||
tuple(poll.poll())
|
||||
os.close(ready_fd)
|
||||
main_entry_point()
|
||||
raise SystemExit(0)
|
||||
|
||||
@ -291,9 +295,14 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||
try:
|
||||
child_pid = os.fork()
|
||||
except OSError:
|
||||
os.close(r)
|
||||
os.close(w)
|
||||
os.close(ready_fd_read)
|
||||
os.close(ready_fd_write)
|
||||
if sz:
|
||||
with SharedMemory(shm_address, unlink_on_exit=True):
|
||||
pass
|
||||
raise
|
||||
if child_pid:
|
||||
# master process
|
||||
os.close(w)
|
||||
@ -308,7 +317,8 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||
os.close(r)
|
||||
os.close(ready_fd_write)
|
||||
for fd in all_non_child_fds:
|
||||
os.close(fd)
|
||||
if fd > -1:
|
||||
os.close(fd)
|
||||
os.setsid()
|
||||
tty_name = cmd.get('tty_name')
|
||||
if tty_name:
|
||||
@ -331,17 +341,123 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||
sys.stdin = sys.__stdin__
|
||||
|
||||
|
||||
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
class SocketClosed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SocketChild:
|
||||
|
||||
def __init__(self, conn: socket.socket, addr: bytes):
|
||||
self.fd = conn.fileno()
|
||||
self.addr = addr
|
||||
self.conn = conn
|
||||
self.input_buf = self.output_buf = b''
|
||||
self.fds: List[int] = []
|
||||
self.child_id = -1
|
||||
self.cwd = self.tty_name = ''
|
||||
self.env: Dict[str, str] = {}
|
||||
self.argv: List[str] = []
|
||||
self.stdin = self.stdout = self.stderr = -1
|
||||
self.pid = -1
|
||||
|
||||
def read(self) -> bool:
|
||||
import array
|
||||
fds = array.array("i") # Array of ints
|
||||
maxfds = 3
|
||||
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize))
|
||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
|
||||
# Append data, ignoring any truncated integers at the end.
|
||||
fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
|
||||
self.fds += list(fds)
|
||||
if not msg:
|
||||
raise SocketClosed('socket unexpectedly closed')
|
||||
self.input_buf += msg
|
||||
while (idx := self.input_buf.find(b'\0')) > -1:
|
||||
line = self.input_buf[:idx].decode('utf-8')
|
||||
self.input_buf = self.input_buf[idx+1:]
|
||||
cmd, _, payload = line.partition(':')
|
||||
if cmd == 'finish':
|
||||
if self.stdin > -1:
|
||||
self.stdin = self.fds[self.stdin]
|
||||
if self.stdout > -1:
|
||||
self.stdout = self.fds[self.stdout]
|
||||
if self.stderr > -1:
|
||||
self.stderr = self.fds[self.stderr]
|
||||
return True
|
||||
elif cmd == 'cwd':
|
||||
self.cwd = payload
|
||||
elif cmd == 'tty_name':
|
||||
self.tty_name = payload
|
||||
elif cmd == 'env':
|
||||
k, _, v = payload.partition('=')
|
||||
self.env[k] = v
|
||||
elif cmd == 'argv':
|
||||
self.argv.append(payload)
|
||||
elif cmd == 'stdin':
|
||||
self.stdin = int(payload)
|
||||
elif cmd == 'stdout':
|
||||
self.stdout = int(payload)
|
||||
elif cmd == 'stderr':
|
||||
self.stderr = int(payload)
|
||||
|
||||
return False
|
||||
|
||||
def fork(self, all_non_child_fds: Iterable[int]) -> None:
|
||||
self.pid = os.fork()
|
||||
if self.pid > 0:
|
||||
# master process
|
||||
if self.stdin > -1:
|
||||
os.close(self.stdin)
|
||||
self.stdin = -1
|
||||
if self.stdout > -1:
|
||||
os.close(self.stdout)
|
||||
self.stdout = -1
|
||||
if self.stderr > -1:
|
||||
os.close(self.stderr)
|
||||
self.stderr = -1
|
||||
return
|
||||
# child process
|
||||
os.setsid()
|
||||
remove_signal_handlers()
|
||||
if self.tty_name:
|
||||
sys.__stdout__.flush()
|
||||
sys.__stderr__.flush()
|
||||
establish_controlling_tty(
|
||||
self.tty_name,
|
||||
sys.__stdin__.fileno() if self.stdin == -1 else -1,
|
||||
sys.__stdout__.fileno() if self.stdout == -1 else -1,
|
||||
sys.__stderr__.fileno() if self.stderr == -1 else -1)
|
||||
# the std streams fds are in all_non_child_fds already
|
||||
# so they will be closed there
|
||||
if self.stdin > -1:
|
||||
os.dup2(self.stdin, sys.__stdin__.fileno())
|
||||
if self.stdout > -1:
|
||||
os.dup2(self.stdout, sys.__stdout__.fileno())
|
||||
if self.stderr > -1:
|
||||
os.dup2(self.stderr, sys.__stderr__.fileno())
|
||||
for fd in all_non_child_fds:
|
||||
if fd > -1:
|
||||
os.close(fd)
|
||||
child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv})
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
|
||||
os.set_blocking(notify_child_death_fd, False)
|
||||
os.set_blocking(stdin_fd, False)
|
||||
os.set_blocking(stdout_fd, False)
|
||||
signal_read_fd = install_signal_handlers(signal.SIGCHLD, signal.SIGUSR1)[0]
|
||||
os.set_blocking(unix_socket.fileno(), False)
|
||||
unix_socket.listen(5)
|
||||
poll = select.poll()
|
||||
poll.register(stdin_fd, select.POLLIN)
|
||||
poll.register(signal_read_fd, select.POLLIN)
|
||||
poll.register(unix_socket.fileno(), select.POLLIN)
|
||||
input_buf = output_buf = child_death_buf = b''
|
||||
child_ready_fds: Dict[int, int] = {}
|
||||
child_pid_map: Dict[int, int] = {}
|
||||
socket_children: Dict[int, SocketChild] = {}
|
||||
child_id_counter = count()
|
||||
self_pid = os.getpid()
|
||||
# runpy issues a warning when running modules that have already been
|
||||
@ -355,6 +471,14 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
yield stdout_fd
|
||||
# the signal fds are closed by remove_signal_handlers()
|
||||
yield from child_ready_fds.values()
|
||||
for sc in socket_children.values():
|
||||
yield sc.fd
|
||||
if sc.stdin > -1:
|
||||
yield sc.stdin
|
||||
if sc.stdout > -1:
|
||||
yield sc.stdout
|
||||
if sc.stderr > -1:
|
||||
yield sc.stderr
|
||||
|
||||
def check_event(event: int, err_msg: str) -> None:
|
||||
if event & select.POLLHUP:
|
||||
@ -448,11 +572,25 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
if not pid:
|
||||
break
|
||||
child_id = child_pid_map.pop(pid, None)
|
||||
if child_id is not None:
|
||||
if child_id is None:
|
||||
for sc in socket_children.values():
|
||||
if sc.pid == pid:
|
||||
sc.conn.sendall(f'{status}'.encode('ascii'))
|
||||
sc.conn.shutdown(socket.SHUT_RDWR)
|
||||
sc.conn.close()
|
||||
break
|
||||
else:
|
||||
handle_child_death(child_id, pid)
|
||||
|
||||
read_signals(signal_read_fd, handle_signal)
|
||||
|
||||
def handle_socket_client(event: int) -> None:
|
||||
check_event(event, 'UNIX socket fd listener failed')
|
||||
conn, addr = unix_socket.accept()
|
||||
sc = SocketChild(conn, addr)
|
||||
socket_children[sc.fd] = sc
|
||||
poll.register(sc.fd, select.POLLIN)
|
||||
|
||||
try:
|
||||
while True:
|
||||
if output_buf:
|
||||
@ -468,6 +606,24 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
handle_signals(event)
|
||||
elif q == notify_child_death_fd:
|
||||
handle_notify_child_death(event)
|
||||
elif q == unix_socket.fileno():
|
||||
handle_socket_client(event)
|
||||
else:
|
||||
scq = socket_children.get(q)
|
||||
if scq is not None:
|
||||
if event & select.POLLIN:
|
||||
try:
|
||||
if scq.read():
|
||||
poll.unregister(scq.fd)
|
||||
scq.conn.shutdown(socket.SHUT_RD)
|
||||
scq.fork(get_all_non_child_fds())
|
||||
scq.child_id = child_pid_map[scq.pid] = next(child_id_counter)
|
||||
except SocketClosed:
|
||||
socket_children.pop(q)
|
||||
continue
|
||||
if event & error_events:
|
||||
socket_children.pop(q)
|
||||
continue
|
||||
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
||||
if os.getpid() == self_pid:
|
||||
raise SystemExit(1)
|
||||
@ -485,7 +641,17 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
|
||||
os.close(fmd)
|
||||
|
||||
|
||||
def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int) -> None:
|
||||
def get_socket_name(unix_socket: socket.socket) -> str:
|
||||
sname = unix_socket.getsockname()
|
||||
if isinstance(sname, bytes):
|
||||
sname = sname.decode('utf-8')
|
||||
assert isinstance(sname, str)
|
||||
if sname.startswith('\0'):
|
||||
sname = '@' + sname[1:]
|
||||
return sname
|
||||
|
||||
|
||||
def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_socket: Optional[socket.socket] = None) -> None:
|
||||
os.setsid()
|
||||
# SIGUSR1 is used for reloading kitty config, we rely on the parent process
|
||||
# to inform us of that
|
||||
@ -495,9 +661,11 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int) -> No
|
||||
os.set_inheritable(stdout_write, False)
|
||||
os.set_inheritable(death_notify_write, False)
|
||||
running_in_kitty(False)
|
||||
|
||||
if unix_socket is None:
|
||||
unix_socket = random_unix_socket()
|
||||
os.write(stdout_write, f'{get_socket_name(unix_socket)}\n'.encode('utf-8'))
|
||||
try:
|
||||
main(stdin_read, stdout_write, death_notify_write)
|
||||
main(stdin_read, stdout_write, death_notify_write, unix_socket)
|
||||
finally:
|
||||
set_options(None)
|
||||
|
||||
@ -513,14 +681,21 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew
|
||||
pass_fds=(stdin_read, stdout_write, death_notify_write))
|
||||
child_pid = tp.pid
|
||||
tp.returncode = 0 # prevent a warning when the popen object is deleted with the process still running
|
||||
os.set_blocking(stdout_read, True)
|
||||
with open(stdout_read, 'rb', closefd=False) as f:
|
||||
socket_name = f.readline().decode('utf-8').rstrip()
|
||||
os.set_blocking(stdout_read, False)
|
||||
else:
|
||||
unix_socket = random_unix_socket()
|
||||
socket_name = get_socket_name(unix_socket)
|
||||
child_pid = os.fork()
|
||||
if child_pid:
|
||||
# master
|
||||
unix_socket.close()
|
||||
os.close(stdin_read)
|
||||
os.close(stdout_write)
|
||||
os.close(death_notify_write)
|
||||
p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read)
|
||||
p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read, socket_name)
|
||||
if use_exec:
|
||||
p.reload_kitty_config()
|
||||
return p
|
||||
@ -529,5 +704,5 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew
|
||||
os.close(stdout_read)
|
||||
os.close(death_notify_read)
|
||||
set_options(opts)
|
||||
exec_main(stdin_read, stdout_write, death_notify_write)
|
||||
exec_main(stdin_read, stdout_write, death_notify_write, unix_socket)
|
||||
raise SystemExit(0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user