diff --git a/kitty/child.c b/kitty/child.c index 0169b0038..d8a60b063 100644 --- a/kitty/child.c +++ b/kitty/child.c @@ -128,14 +128,14 @@ spawn(PyObject *self UNUSED, PyObject *args) { safe_close(tfd, __FILE__, __LINE__); // Redirect stdin/stdout/stderr to the pty - if (dup2(slave, 1) == -1) exit_on_err("dup2() failed for fd number 1"); - if (dup2(slave, 2) == -1) exit_on_err("dup2() failed for fd number 2"); + if (safe_dup2(slave, 1) == -1) exit_on_err("dup2() failed for fd number 1"); + if (safe_dup2(slave, 2) == -1) exit_on_err("dup2() failed for fd number 2"); if (stdin_read_fd > -1) { - if (dup2(stdin_read_fd, 0) == -1) exit_on_err("dup2() failed for fd number 0"); + if (safe_dup2(stdin_read_fd, 0) == -1) exit_on_err("dup2() failed for fd number 0"); safe_close(stdin_read_fd, __FILE__, __LINE__); safe_close(stdin_write_fd, __FILE__, __LINE__); } else { - if (dup2(slave, 0) == -1) exit_on_err("dup2() failed for fd number 0"); + if (safe_dup2(slave, 0) == -1) exit_on_err("dup2() failed for fd number 0"); } safe_close(slave, __FILE__, __LINE__); safe_close(master, __FILE__, __LINE__); @@ -185,19 +185,13 @@ spawn(PyObject *self UNUSED, PyObject *args) { static PyObject* establish_controlling_tty(PyObject *self UNUSED, PyObject *args) { - const char *ttyname; - int stdin_fd = -1, stdout_fd = -1, stderr_fd = -1; - if (!PyArg_ParseTuple(args, "s|iii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL; - int tfd = safe_open(ttyname, O_RDWR, 0); - if (tfd == -1) return PyErr_SetFromErrnoWithFilename(PyExc_OSError, ttyname); -#ifdef TIOCSCTTY - // On BSD open() does not establish the controlling terminal - if (ioctl(tfd, TIOCSCTTY, 0) == -1) return PyErr_SetFromErrno(PyExc_OSError); -#endif - if (stdin_fd > -1 && dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); - if (stdout_fd > -1 && dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); - if (stderr_fd > -1 && dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError); - safe_close(tfd, __FILE__, __LINE__); + int tty_fd, stdin_fd = -1, stdout_fd = -1, stderr_fd = -1; + if (!PyArg_ParseTuple(args, "i|iii", &tty_fd, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL; + if (ioctl(tty_fd, TIOCSCTTY, 0) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); } + if (stdin_fd > -1 && safe_dup2(tty_fd, stdin_fd) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); } + if (stdout_fd > -1 && safe_dup2(tty_fd, stdout_fd) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); } + if (stderr_fd > -1 && safe_dup2(tty_fd, stderr_fd) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); } + safe_close(tty_fd, __FILE__, __LINE__); Py_RETURN_NONE; } diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index d096283e8..dbe164b3e 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None: pass -def establish_controlling_tty(ttyname: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None: +def establish_controlling_tty(tty_fd: int, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None: pass diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 79681f84b..ebaef2a77 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -351,7 +351,7 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl if tty_name: sys.__stdout__.flush() sys.__stderr__.flush() - establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) + establish_controlling_tty(os.open(tty_name, os.O_RDWR | os.O_CLOEXEC, 0), sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) os.close(w) if shm.unlink_on_exit: child_main(cmd, ready_fd_read) @@ -389,7 +389,7 @@ class SocketChild: self.input_buf = self.output_buf = b'' self.fds: List[int] = [] self.child_id = -1 - self.cwd = self.tty_name = '' + self.cwd = '' self.env: Dict[str, str] = {} self.argv: List[str] = [] self.stdin = self.stdout = self.stderr = -1 @@ -428,6 +428,9 @@ class SocketChild: self.input_buf = self.input_buf[idx+1:] cmd, _, payload = line.partition(':') if cmd == 'finish': + for x in self.fds: + os.set_inheritable(x, x is not self.fds[0]) + os.set_blocking(x, True) if self.stdin > -1: self.stdin = self.fds[self.stdin] if self.stdout > -1: @@ -437,8 +440,6 @@ class SocketChild: return True elif cmd == 'cwd': self.cwd = payload - elif cmd == 'tty_name': - self.tty_name = payload elif cmd == 'env': k, _, v = payload.partition('=') self.env[k] = v @@ -480,14 +481,14 @@ class SocketChild: os.close(r) os.setsid() restore_python_signal_handlers() - if self.tty_name: + if self.fds: sys.__stdout__.flush() sys.__stderr__.flush() establish_controlling_tty( - self.tty_name, - sys.__stdin__.fileno() if self.stdin == -1 else -1, - sys.__stdout__.fileno() if self.stdout == -1 else -1, - sys.__stderr__.fileno() if self.stderr == -1 else -1) + self.fds[0], + sys.__stdin__.fileno() if self.stdin < 0 else -1, + sys.__stdout__.fileno() if self.stdout < 0 else -1, + sys.__stderr__.fileno() if self.stderr < 0 else -1) # the std streams fds are in all_non_child_fds already # so they will be closed there if self.stdin > -1: diff --git a/kitty/safe-wrappers.h b/kitty/safe-wrappers.h index c6444f359..8fa19ce8c 100644 --- a/kitty/safe-wrappers.h +++ b/kitty/safe-wrappers.h @@ -37,3 +37,10 @@ safe_close(int fd, const char* file UNUSED, const int line UNUSED) { #endif while(close(fd) != 0 && errno == EINTR); } + +static inline int +safe_dup2(int a, int b) { + int ret; + while((ret = dup2(a, b)) < 0 && errno == EINTR); + return ret; +} diff --git a/prewarm-launcher.c b/prewarm-launcher.c index a3902dc85..3ce7a619a 100644 --- a/prewarm-launcher.c +++ b/prewarm-launcher.c @@ -102,6 +102,13 @@ safe_close(int fd) { while(close(fd) != 0 && errno == EINTR); } +static inline int +safe_dup2(int a, int b) { + int ret; + while((ret = dup2(a, b)) < 0 && errno == EINTR); + return ret; +} + static inline bool safe_tcsetattr(int fd, int actions, const struct termios *tp) { int ret = 0; @@ -166,7 +173,6 @@ is_prewarmable(int argc, char *argv[]) { } static int child_master_fd = -1, child_slave_fd = -1; -static char child_tty_name[PATH_MAX] = {0}; static struct winsize self_winsize = {0}; static struct termios self_termios = {0}, restore_termios = {0}; static bool termios_needs_restore = false; @@ -213,7 +219,7 @@ get_termios_state(void) { static bool open_pty(void) { - while (openpty(&child_master_fd, &child_slave_fd, child_tty_name, &self_termios, &self_winsize) == -1) { + while (openpty(&child_master_fd, &child_slave_fd, NULL, &self_termios, &self_winsize) == -1) { if (errno != EINTR) return false; } set_blocking(child_master_fd, false); @@ -255,7 +261,7 @@ setup_signal_handler(void) { static void setup_stdio_handles(void) { - int pos = 0; + int pos = 1; if (!isatty(STDIN_FILENO)) stdin_pos = pos++; if (!isatty(STDOUT_FILENO)) stdout_pos = pos++; if (!isatty(STDERR_FILENO)) stderr_pos = pos++; @@ -290,27 +296,22 @@ static bool create_launch_msg(int argc, char *argv[]) { #define w(prefix, data) { if (!write_item_to_launch_msg(prefix, data)) return false; } static char buf[4*PATH_MAX]; - w("tty_name", child_tty_name); if (getcwd(buf, sizeof(buf))) { w("cwd", buf); } for (int i = 0; i < argc; i++) w("argv", argv[i]); char **s = environ; for (; *s; s++) w("env", *s); - int num_fds = 0, fds[8]; + int num_fds = 0, fds[4]; + fds[num_fds++] = child_slave_fd; #define sio(which, x) if (which##_pos > -1) { snprintf(buf, sizeof(buf), "%d", which##_pos); w(#which, buf); fds[num_fds++] = x; } sio(stdin, STDIN_FILENO); sio(stdout, STDOUT_FILENO); sio(stderr, STDERR_FILENO); #undef sio w("finish", ""); - if (num_fds) { - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&launch_msg_container); - cmsg->cmsg_len = CMSG_LEN(sizeof(fds[0]) * num_fds); - memcpy(CMSG_DATA(cmsg), fds, num_fds * sizeof(fds[0])); - launch_msg_container.msg_controllen = cmsg->cmsg_len; - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - } else { - launch_msg_container.msg_controllen = 0; - launch_msg_container.msg_control = 0; - } + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&launch_msg_container); + cmsg->cmsg_len = CMSG_LEN(sizeof(fds[0]) * num_fds); + memcpy(CMSG_DATA(cmsg), fds, num_fds * sizeof(fds[0])); + launch_msg_container.msg_controllen = cmsg->cmsg_len; + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; return true; #undef w } @@ -351,6 +352,16 @@ read_child_data(void) { return true; } +static void +close_sent_fds(void) { + if (child_slave_fd > -1) { safe_close(child_slave_fd); child_slave_fd = -1; } +#define redirect(which, mode) { int fd = safe_open("/dev/null", mode | O_CLOEXEC, 0); if (fd > -1) { safe_dup2(fd, which); safe_close(fd); } } + if (stdin_pos > -1) redirect(STDIN_FILENO, O_RDONLY); + if (stdout_pos > -1) redirect(STDOUT_FILENO, O_WRONLY); + if (stderr_pos > -1) redirect(STDERR_FILENO, O_WRONLY); +#undef redirect +} + static bool send_launch_msg(void) { ssize_t n; @@ -362,7 +373,10 @@ send_launch_msg(void) { // some bytes sent, null out the control msg data as it is already sent launch_msg_container.msg_controllen = 0; launch_msg_container.msg_control = NULL; - if ((size_t)n > launch_msg.iov_len) launch_msg.iov_len = 0; + if ((size_t)n > launch_msg.iov_len) { + launch_msg.iov_len = 0; + close_sent_fds(); + } else launch_msg.iov_len -= n; launch_msg.iov_base = (char*)launch_msg.iov_base + n; return true;