Start work on socket based prewarm

This commit is contained in:
Kovid Goyal 2022-06-13 21:34:49 +05:30
parent 7a31c7ff50
commit b222ab1bf6
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 195 additions and 20 deletions

View File

@ -186,17 +186,17 @@ spawn(PyObject *self UNUSED, PyObject *args) {
static PyObject*
establish_controlling_tty(PyObject *self UNUSED, PyObject *args) {
const char *ttyname;
int stdin_fd, stdout_fd, stderr_fd;
if (!PyArg_ParseTuple(args, "siii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL;
int stdin_fd = -1, stdout_fd = -1, stderr_fd = -1;
if (!PyArg_ParseTuple(args, "s|iii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL;
int tfd = safe_open(ttyname, O_RDWR, 0);
if (tfd == -1) return PyErr_SetFromErrnoWithFilename(PyExc_OSError, ttyname);
#ifdef TIOCSCTTY
// On BSD open() does not establish the controlling terminal
if (ioctl(tfd, TIOCSCTTY, 0) == -1) return PyErr_SetFromErrno(PyExc_OSError);
#endif
if (dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (stdin_fd > -1 && dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (stdout_fd > -1 && dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (stderr_fd > -1 && dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
safe_close(tfd, __FILE__, __LINE__);
Py_RETURN_NONE;
}

View File

@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None:
pass
def establish_controlling_tty(ttyname: str, stdin: int, stdout: int, stderr: int) -> None:
def establish_controlling_tty(ttyname: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None:
pass

View File

@ -6,6 +6,7 @@ import json
import os
import select
import signal
import socket
import sys
import time
import warnings
@ -27,7 +28,7 @@ from kitty.fast_data_types import (
from kitty.options.types import Options
from kitty.shm import SharedMemory
from kitty.types import SignalInfo
from kitty.utils import log_error
from kitty.utils import log_error, random_unix_socket
if TYPE_CHECKING:
from _typeshed import ReadableBuffer, WriteableBuffer
@ -69,6 +70,7 @@ class PrewarmProcess:
to_prewarm_stdin: int,
from_prewarm_stdout: int,
from_prewarm_death_notify: int,
unix_socket_name: str,
) -> None:
self.children: Dict[int, Child] = {}
self.worker_pid = prewarm_process_pid
@ -77,6 +79,7 @@ class PrewarmProcess:
self.read_from_process_fd = from_prewarm_stdout
self.poll = select.poll()
self.poll.register(self.read_from_process_fd, select.POLLIN)
self.unix_socket_name = unix_socket_name
def take_from_worker_fd(self, create_file: bool = False) -> int:
if create_file:
@ -256,7 +259,7 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
cwd = cmd.get('cwd')
if cwd:
with suppress(OSError):
@ -268,10 +271,11 @@ def child_main(cmd: Dict[str, Any], ready_fd: int) -> NoReturn:
argv = cmd.get('argv')
if argv:
sys.argv = list(argv)
poll = select.poll()
poll.register(ready_fd, select.POLLIN)
tuple(poll.poll())
os.close(ready_fd)
if ready_fd > -1:
poll = select.poll()
poll.register(ready_fd, select.POLLIN)
tuple(poll.poll())
os.close(ready_fd)
main_entry_point()
raise SystemExit(0)
@ -291,9 +295,14 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
try:
child_pid = os.fork()
except OSError:
os.close(r)
os.close(w)
os.close(ready_fd_read)
os.close(ready_fd_write)
if sz:
with SharedMemory(shm_address, unlink_on_exit=True):
pass
raise
if child_pid:
# master process
os.close(w)
@ -308,7 +317,8 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
os.close(r)
os.close(ready_fd_write)
for fd in all_non_child_fds:
os.close(fd)
if fd > -1:
os.close(fd)
os.setsid()
tty_name = cmd.get('tty_name')
if tty_name:
@ -331,17 +341,123 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
sys.stdin = sys.__stdin__
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
class SocketClosed(Exception):
pass
class SocketChild:
def __init__(self, conn: socket.socket, addr: bytes):
self.fd = conn.fileno()
self.addr = addr
self.conn = conn
self.input_buf = self.output_buf = b''
self.fds: List[int] = []
self.child_id = -1
self.cwd = self.tty_name = ''
self.env: Dict[str, str] = {}
self.argv: List[str] = []
self.stdin = self.stdout = self.stderr = -1
self.pid = -1
def read(self) -> bool:
import array
fds = array.array("i") # Array of ints
maxfds = 3
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize))
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
# Append data, ignoring any truncated integers at the end.
fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
self.fds += list(fds)
if not msg:
raise SocketClosed('socket unexpectedly closed')
self.input_buf += msg
while (idx := self.input_buf.find(b'\0')) > -1:
line = self.input_buf[:idx].decode('utf-8')
self.input_buf = self.input_buf[idx+1:]
cmd, _, payload = line.partition(':')
if cmd == 'finish':
if self.stdin > -1:
self.stdin = self.fds[self.stdin]
if self.stdout > -1:
self.stdout = self.fds[self.stdout]
if self.stderr > -1:
self.stderr = self.fds[self.stderr]
return True
elif cmd == 'cwd':
self.cwd = payload
elif cmd == 'tty_name':
self.tty_name = payload
elif cmd == 'env':
k, _, v = payload.partition('=')
self.env[k] = v
elif cmd == 'argv':
self.argv.append(payload)
elif cmd == 'stdin':
self.stdin = int(payload)
elif cmd == 'stdout':
self.stdout = int(payload)
elif cmd == 'stderr':
self.stderr = int(payload)
return False
def fork(self, all_non_child_fds: Iterable[int]) -> None:
self.pid = os.fork()
if self.pid > 0:
# master process
if self.stdin > -1:
os.close(self.stdin)
self.stdin = -1
if self.stdout > -1:
os.close(self.stdout)
self.stdout = -1
if self.stderr > -1:
os.close(self.stderr)
self.stderr = -1
return
# child process
os.setsid()
remove_signal_handlers()
if self.tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(
self.tty_name,
sys.__stdin__.fileno() if self.stdin == -1 else -1,
sys.__stdout__.fileno() if self.stdout == -1 else -1,
sys.__stderr__.fileno() if self.stderr == -1 else -1)
# the std streams fds are in all_non_child_fds already
# so they will be closed there
if self.stdin > -1:
os.dup2(self.stdin, sys.__stdin__.fileno())
if self.stdout > -1:
os.dup2(self.stdout, sys.__stdout__.fileno())
if self.stderr > -1:
os.dup2(self.stderr, sys.__stderr__.fileno())
for fd in all_non_child_fds:
if fd > -1:
os.close(fd)
child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv})
raise SystemExit(0)
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
os.set_blocking(notify_child_death_fd, False)
os.set_blocking(stdin_fd, False)
os.set_blocking(stdout_fd, False)
signal_read_fd = install_signal_handlers(signal.SIGCHLD, signal.SIGUSR1)[0]
os.set_blocking(unix_socket.fileno(), False)
unix_socket.listen(5)
poll = select.poll()
poll.register(stdin_fd, select.POLLIN)
poll.register(signal_read_fd, select.POLLIN)
poll.register(unix_socket.fileno(), select.POLLIN)
input_buf = output_buf = child_death_buf = b''
child_ready_fds: Dict[int, int] = {}
child_pid_map: Dict[int, int] = {}
socket_children: Dict[int, SocketChild] = {}
child_id_counter = count()
self_pid = os.getpid()
# runpy issues a warning when running modules that have already been
@ -355,6 +471,14 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
yield stdout_fd
# the signal fds are closed by remove_signal_handlers()
yield from child_ready_fds.values()
for sc in socket_children.values():
yield sc.fd
if sc.stdin > -1:
yield sc.stdin
if sc.stdout > -1:
yield sc.stdout
if sc.stderr > -1:
yield sc.stderr
def check_event(event: int, err_msg: str) -> None:
if event & select.POLLHUP:
@ -448,11 +572,25 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
if not pid:
break
child_id = child_pid_map.pop(pid, None)
if child_id is not None:
if child_id is None:
for sc in socket_children.values():
if sc.pid == pid:
sc.conn.sendall(f'{status}'.encode('ascii'))
sc.conn.shutdown(socket.SHUT_RDWR)
sc.conn.close()
break
else:
handle_child_death(child_id, pid)
read_signals(signal_read_fd, handle_signal)
def handle_socket_client(event: int) -> None:
check_event(event, 'UNIX socket fd listener failed')
conn, addr = unix_socket.accept()
sc = SocketChild(conn, addr)
socket_children[sc.fd] = sc
poll.register(sc.fd, select.POLLIN)
try:
while True:
if output_buf:
@ -468,6 +606,24 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
handle_signals(event)
elif q == notify_child_death_fd:
handle_notify_child_death(event)
elif q == unix_socket.fileno():
handle_socket_client(event)
else:
scq = socket_children.get(q)
if scq is not None:
if event & select.POLLIN:
try:
if scq.read():
poll.unregister(scq.fd)
scq.conn.shutdown(socket.SHUT_RD)
scq.fork(get_all_non_child_fds())
scq.child_id = child_pid_map[scq.pid] = next(child_id_counter)
except SocketClosed:
socket_children.pop(q)
continue
if event & error_events:
socket_children.pop(q)
continue
except (KeyboardInterrupt, EOFError, BrokenPipeError):
if os.getpid() == self_pid:
raise SystemExit(1)
@ -485,7 +641,17 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int) -> None:
os.close(fmd)
def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int) -> None:
def get_socket_name(unix_socket: socket.socket) -> str:
sname = unix_socket.getsockname()
if isinstance(sname, bytes):
sname = sname.decode('utf-8')
assert isinstance(sname, str)
if sname.startswith('\0'):
sname = '@' + sname[1:]
return sname
def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_socket: Optional[socket.socket] = None) -> None:
os.setsid()
# SIGUSR1 is used for reloading kitty config, we rely on the parent process
# to inform us of that
@ -495,9 +661,11 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int) -> No
os.set_inheritable(stdout_write, False)
os.set_inheritable(death_notify_write, False)
running_in_kitty(False)
if unix_socket is None:
unix_socket = random_unix_socket()
os.write(stdout_write, f'{get_socket_name(unix_socket)}\n'.encode('utf-8'))
try:
main(stdin_read, stdout_write, death_notify_write)
main(stdin_read, stdout_write, death_notify_write, unix_socket)
finally:
set_options(None)
@ -513,14 +681,21 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew
pass_fds=(stdin_read, stdout_write, death_notify_write))
child_pid = tp.pid
tp.returncode = 0 # prevent a warning when the popen object is deleted with the process still running
os.set_blocking(stdout_read, True)
with open(stdout_read, 'rb', closefd=False) as f:
socket_name = f.readline().decode('utf-8').rstrip()
os.set_blocking(stdout_read, False)
else:
unix_socket = random_unix_socket()
socket_name = get_socket_name(unix_socket)
child_pid = os.fork()
if child_pid:
# master
unix_socket.close()
os.close(stdin_read)
os.close(stdout_write)
os.close(death_notify_write)
p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read)
p = PrewarmProcess(child_pid, stdin_write, stdout_read, death_notify_read, socket_name)
if use_exec:
p.reload_kitty_config()
return p
@ -529,5 +704,5 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew
os.close(stdout_read)
os.close(death_notify_read)
set_options(opts)
exec_main(stdin_read, stdout_write, death_notify_write)
exec_main(stdin_read, stdout_write, death_notify_write, unix_socket)
raise SystemExit(0)