diff --git a/kitty/prewarm.py b/kitty/prewarm.py index b37232090..08c2a131b 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -367,10 +367,6 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl sys.stdin = sys.__stdin__ -class SocketClosed(Exception): - pass - - def verify_socket_creds(conn: socket.socket) -> bool: # needed as abstract unix sockets used on Linux have no permissions and # older BSDs ignore socket file permissions @@ -385,6 +381,7 @@ class SocketChild: self.poll = poll self.addr = addr self.conn = conn + self.winsize = 8 self.poll.register(self.conn.fileno(), select.POLLIN) self.input_buf = self.output_buf = b'' self.fds: List[int] = [] @@ -395,6 +392,7 @@ class SocketChild: self.stdin = self.stdout = self.stderr = -1 self.pid = -1 self.closed = False + self.launch_msg_read = False def unregister_from_poll(self) -> None: if self.registered: @@ -403,7 +401,21 @@ class SocketChild: self.poll.unregister(self.conn.fileno()) self.registered = False - def read(self) -> bool: + def read(self) -> None: + import fcntl + import termios + msg = self.conn.recv(io.DEFAULT_BUFFER_SIZE) + if not msg: + return + self.input_buf += msg + data = memoryview(self.input_buf) + while len(data) >= self.winsize: + record, data = data[:self.winsize], data[self.winsize:] + with open(os.open(self.tty_name, os.O_RDWR | os.O_CLOEXEC | os.O_NOCTTY, 0), 'rb') as f: + fcntl.ioctl(f.fileno(), termios.TIOCSWINSZ, record) + self.input_buf = bytes(data) + + def read_launch_msg(self) -> bool: import array fds = array.array("i") # Array of ints try: @@ -421,13 +433,14 @@ class SocketChild: fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) self.fds += list(fds) if not msg: - raise SocketClosed('socket unexpectedly closed') + return False self.input_buf += msg while (idx := self.input_buf.find(b'\0')) > -1: line = self.input_buf[:idx].decode('utf-8') self.input_buf = self.input_buf[idx+1:] cmd, _, payload = line.partition(':') if cmd == 'finish': + self.launch_msg_read = True for x in self.fds: os.set_inheritable(x, x is not self.fds[0]) os.set_blocking(x, True) @@ -454,6 +467,8 @@ class SocketChild: self.stderr = int(payload) elif cmd == 'tty_name': self.tty_name = payload + elif cmd == 'winsize': + self.winsize = int(payload) return False @@ -506,6 +521,8 @@ class SocketChild: raise SystemExit(0) def handle_death(self, status: int) -> None: + if self.closed: + return if hasattr(os, 'waitstatus_to_exitcode'): status = os.waitstatus_to_exitcode(status) # negative numbers are signals usually and shells report these as @@ -517,12 +534,10 @@ class SocketChild: self.conn.sendall(f'{status}'.encode('ascii')) except OSError as e: print_error(f'Failed to send exit status of socket child with error: {e}') - with suppress(OSError): - self.conn.shutdown(socket.SHUT_RDWR) - with suppress(OSError): - self.conn.close() def handle_creation(self) -> bool: + if self.closed: + return False try: self.conn.sendall(f'{self.pid}:'.encode('ascii')) except OSError as e: @@ -535,7 +550,11 @@ class SocketChild: return self.unregister_from_poll() self.closed = True - self.conn.close() + if is_zygote: + with suppress(OSError): + self.conn.shutdown(socket.SHUT_RDWR) + with suppress(OSError): + self.conn.close() for x in self.fds: os.close(x) del self.fds[:] @@ -694,6 +713,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: if sc is not None: try: sc.handle_death(status) + except Exception: + import traceback + traceback.print_exc() finally: remove_socket_child(sc) else: @@ -711,30 +733,32 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: sc = SocketChild(conn, addr, poll) socket_children[sc.conn.fileno()] = sc - def handle_socket_launch(fd: int, event: int) -> None: + def handle_socket_input(fd: int, event: int) -> None: scq = socket_children.get(q) if scq is None: return if event & select.POLLIN: - try: - if scq.read(): - scq.unregister_from_poll() - scq.fork(free_non_child_resources) - socket_pid_map[scq.pid] = scq - scq.child_id = next(child_id_counter) - except SocketClosed: - if is_zygote: - remove_socket_child(scq) - except OSError: - if is_zygote: - remove_socket_child(scq) - import traceback - tb = traceback.format_exc() - print_error(f'Failed to fork socket child with error: {tb}') - else: - raise + if scq.launch_msg_read: + scq.read() + else: + try: + if scq.read_launch_msg(): + scq.fork(free_non_child_resources) + socket_pid_map[scq.pid] = scq + scq.child_id = next(child_id_counter) + except OSError: + if is_zygote: + remove_socket_child(scq) + import traceback + tb = traceback.format_exc() + print_error(f'Failed to fork socket child with error: {tb}') + else: + raise if is_zygote and (event & error_events): - remove_socket_child(scq) + if event & select.POLLHUP: + scq.unregister_from_poll() + else: + remove_socket_child(scq) keep_type_checker_happy = True try: @@ -755,7 +779,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: elif q == unix_socket.fileno(): handle_socket_client(event) else: - handle_socket_launch(q, event) + handle_socket_input(q, event) except (KeyboardInterrupt, EOFError, BrokenPipeError): if is_zygote: raise SystemExit(1) diff --git a/prewarm-launcher.c b/prewarm-launcher.c index aeb81b572..dc3c0066e 100644 --- a/prewarm-launcher.c +++ b/prewarm-launcher.c @@ -124,6 +124,14 @@ safe_read(int fd, void *buf, size_t n) { return ret; } +static ssize_t +safe_send(int fd, void *buf, size_t n, int flags) { + ssize_t ret = 0; + while((ret = send(fd, buf, n, flags)) ==-1 && errno == EINTR); + return ret; +} + + static ssize_t safe_write(int fd, void *buf, size_t n) { ssize_t ret = 0; @@ -307,6 +315,8 @@ create_launch_msg(int argc, char *argv[]) { #define w(prefix, data) { if (!write_item_to_launch_msg(prefix, data)) return false; } static char buf[4*PATH_MAX]; w("tty_name", child_tty_name); + snprintf(buf, sizeof(buf), "%zu", sizeof(self_winsize)); + w("winsize", buf); if (getcwd(buf, sizeof(buf))) { w("cwd", buf); } for (int i = 0; i < argc; i++) w("argv", argv[i]); char **s = environ; @@ -512,6 +522,40 @@ flush_data(void) { } } +static char sosbuf[2 * sizeof(self_winsize)] = {0}; +static transfer_buf send_on_socket = {.buf=sosbuf}; + +static void +add_window_size_to_buffer(void) { + char *p; + if (send_on_socket.sz % sizeof(self_winsize)) { + // partial send + if (send_on_socket.sz > sizeof(self_winsize)) send_on_socket.sz -= sizeof(self_winsize); // replace second size + p = send_on_socket.buf + send_on_socket.sz; + send_on_socket.sz += sizeof(self_winsize); + } else { + // replace all sizes + p = send_on_socket.buf; + send_on_socket.sz = sizeof(self_winsize); + } + memcpy(p, &self_winsize, sizeof(self_winsize)); +} + +static bool +send_over_socket(void) { + if (!send_on_socket.sz || socket_fd < 0) return true; + ssize_t n = safe_send(socket_fd, send_on_socket.buf, send_on_socket.sz, MSG_NOSIGNAL); + if (n < 0) return false; + if (n) { + if (n >= send_on_socket.sz) send_on_socket.sz = 0; + else { + send_on_socket.sz -= n; + memmove(send_on_socket.buf, send_on_socket.buf + n, send_on_socket.sz); + } + } + return true; +} + static void loop(void) { #define fail(s) { print_error(s, errno); return; } @@ -523,15 +567,17 @@ loop(void) { nfds++; while (keep_going()) { - wf.self_ttyfd.want_read = to_child_tty.sz < IO_BUZ_SZ; wf.self_ttyfd.want_write = from_child_tty.sz > 0; - wf.child_master_fd.want_read = from_child_tty.sz < IO_BUZ_SZ; wf.child_master_fd.want_write = to_child_tty.sz > 0; - wf.socket_fd.want_write = launch_msg.iov_len > 0; - - if (window_size_dirty && child_master_fd > -1 ) { + if (window_size_dirty) { if (!get_window_size()) fail("getting window size for self tty failed"); - if (!safe_winsz(child_master_fd, TIOCSWINSZ, &self_winsize)) fail("setting window size on child pty failed"); + // macOS barfs with ENOTTY if we try to use TIOCSWINSZ from this process, so send it to the zygote + /* if (!safe_winsz(child_master_fd, TIOCSWINSZ, &self_winsize)) fail("setting window size on child pty failed"); */ + add_window_size_to_buffer(); window_size_dirty = false; } + wf.self_ttyfd.want_read = to_child_tty.sz < IO_BUZ_SZ; wf.self_ttyfd.want_write = from_child_tty.sz > 0; + wf.child_master_fd.want_read = from_child_tty.sz < IO_BUZ_SZ; wf.child_master_fd.want_write = to_child_tty.sz > 0; + wf.socket_fd.want_write = launch_msg.iov_len > 0 || send_on_socket.sz > 0; + FD_ZERO(&readable); FD_ZERO(&writable); FD_ZERO(&errorable); #define set(which) if (which > -1) { if (wf.which.want_read) { FD_SET(which, &readable); } if (wf.which.want_write) { FD_SET(which, &writable); } if (wf.which.want_error) { FD_SET(which, &errorable); } } set(self_ttyfd); set(child_master_fd); set(socket_fd); set(signal_read_fd); @@ -553,7 +599,10 @@ loop(void) { if (signal_read_fd > -1 && FD_ISSET(signal_read_fd, &readable)) if (!read_signals()) fail("reading from signal fd failed"); if (socket_fd > -1) { - if (FD_ISSET(socket_fd, &writable)) if (!send_launch_msg()) fail("sending launch message failed"); + if (FD_ISSET(socket_fd, &writable)) { + if (launch_msg.iov_len > 0) { if (!send_launch_msg()) fail("sending launch message failed"); } + else if (send_on_socket.sz > 0) { if (!send_over_socket()) fail("sending on socket failed"); } + } if (FD_ISSET(socket_fd, &readable)) { if (!read_child_data()) fail("reading information about child failed"); if (socket_fd < 0) { // hangup