Avoid passing around pty paths, instead send the pty fd

This commit is contained in:
Kovid Goyal 2022-07-07 20:07:04 +05:30
parent cf8113ea24
commit 72f3e8cd40
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 60 additions and 44 deletions

View File

@ -128,14 +128,14 @@ spawn(PyObject *self UNUSED, PyObject *args) {
safe_close(tfd, __FILE__, __LINE__); safe_close(tfd, __FILE__, __LINE__);
// Redirect stdin/stdout/stderr to the pty // Redirect stdin/stdout/stderr to the pty
if (dup2(slave, 1) == -1) exit_on_err("dup2() failed for fd number 1"); if (safe_dup2(slave, 1) == -1) exit_on_err("dup2() failed for fd number 1");
if (dup2(slave, 2) == -1) exit_on_err("dup2() failed for fd number 2"); if (safe_dup2(slave, 2) == -1) exit_on_err("dup2() failed for fd number 2");
if (stdin_read_fd > -1) { if (stdin_read_fd > -1) {
if (dup2(stdin_read_fd, 0) == -1) exit_on_err("dup2() failed for fd number 0"); if (safe_dup2(stdin_read_fd, 0) == -1) exit_on_err("dup2() failed for fd number 0");
safe_close(stdin_read_fd, __FILE__, __LINE__); safe_close(stdin_read_fd, __FILE__, __LINE__);
safe_close(stdin_write_fd, __FILE__, __LINE__); safe_close(stdin_write_fd, __FILE__, __LINE__);
} else { } else {
if (dup2(slave, 0) == -1) exit_on_err("dup2() failed for fd number 0"); if (safe_dup2(slave, 0) == -1) exit_on_err("dup2() failed for fd number 0");
} }
safe_close(slave, __FILE__, __LINE__); safe_close(slave, __FILE__, __LINE__);
safe_close(master, __FILE__, __LINE__); safe_close(master, __FILE__, __LINE__);
@ -185,19 +185,13 @@ spawn(PyObject *self UNUSED, PyObject *args) {
static PyObject* static PyObject*
establish_controlling_tty(PyObject *self UNUSED, PyObject *args) { establish_controlling_tty(PyObject *self UNUSED, PyObject *args) {
const char *ttyname; int tty_fd, stdin_fd = -1, stdout_fd = -1, stderr_fd = -1;
int stdin_fd = -1, stdout_fd = -1, stderr_fd = -1; if (!PyArg_ParseTuple(args, "i|iii", &tty_fd, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL;
if (!PyArg_ParseTuple(args, "s|iii", &ttyname, &stdin_fd, &stdout_fd, &stderr_fd)) return NULL; if (ioctl(tty_fd, TIOCSCTTY, 0) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); }
int tfd = safe_open(ttyname, O_RDWR, 0); if (stdin_fd > -1 && safe_dup2(tty_fd, stdin_fd) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); }
if (tfd == -1) return PyErr_SetFromErrnoWithFilename(PyExc_OSError, ttyname); if (stdout_fd > -1 && safe_dup2(tty_fd, stdout_fd) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); }
#ifdef TIOCSCTTY if (stderr_fd > -1 && safe_dup2(tty_fd, stderr_fd) == -1) { safe_close(tty_fd, __FILE__, __LINE__); return PyErr_SetFromErrno(PyExc_OSError); }
// On BSD open() does not establish the controlling terminal safe_close(tty_fd, __FILE__, __LINE__);
if (ioctl(tfd, TIOCSCTTY, 0) == -1) return PyErr_SetFromErrno(PyExc_OSError);
#endif
if (stdin_fd > -1 && dup2(tfd, stdin_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (stdout_fd > -1 && dup2(tfd, stdout_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
if (stderr_fd > -1 && dup2(tfd, stderr_fd) == -1) return PyErr_SetFromErrno(PyExc_OSError);
safe_close(tfd, __FILE__, __LINE__);
Py_RETURN_NONE; Py_RETURN_NONE;
} }

View File

@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None:
pass pass
def establish_controlling_tty(ttyname: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None: def establish_controlling_tty(tty_fd: int, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None:
pass pass

View File

@ -351,7 +351,7 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl
if tty_name: if tty_name:
sys.__stdout__.flush() sys.__stdout__.flush()
sys.__stderr__.flush() sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) establish_controlling_tty(os.open(tty_name, os.O_RDWR | os.O_CLOEXEC, 0), sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
os.close(w) os.close(w)
if shm.unlink_on_exit: if shm.unlink_on_exit:
child_main(cmd, ready_fd_read) child_main(cmd, ready_fd_read)
@ -389,7 +389,7 @@ class SocketChild:
self.input_buf = self.output_buf = b'' self.input_buf = self.output_buf = b''
self.fds: List[int] = [] self.fds: List[int] = []
self.child_id = -1 self.child_id = -1
self.cwd = self.tty_name = '' self.cwd = ''
self.env: Dict[str, str] = {} self.env: Dict[str, str] = {}
self.argv: List[str] = [] self.argv: List[str] = []
self.stdin = self.stdout = self.stderr = -1 self.stdin = self.stdout = self.stderr = -1
@ -428,6 +428,9 @@ class SocketChild:
self.input_buf = self.input_buf[idx+1:] self.input_buf = self.input_buf[idx+1:]
cmd, _, payload = line.partition(':') cmd, _, payload = line.partition(':')
if cmd == 'finish': if cmd == 'finish':
for x in self.fds:
os.set_inheritable(x, x is not self.fds[0])
os.set_blocking(x, True)
if self.stdin > -1: if self.stdin > -1:
self.stdin = self.fds[self.stdin] self.stdin = self.fds[self.stdin]
if self.stdout > -1: if self.stdout > -1:
@ -437,8 +440,6 @@ class SocketChild:
return True return True
elif cmd == 'cwd': elif cmd == 'cwd':
self.cwd = payload self.cwd = payload
elif cmd == 'tty_name':
self.tty_name = payload
elif cmd == 'env': elif cmd == 'env':
k, _, v = payload.partition('=') k, _, v = payload.partition('=')
self.env[k] = v self.env[k] = v
@ -480,14 +481,14 @@ class SocketChild:
os.close(r) os.close(r)
os.setsid() os.setsid()
restore_python_signal_handlers() restore_python_signal_handlers()
if self.tty_name: if self.fds:
sys.__stdout__.flush() sys.__stdout__.flush()
sys.__stderr__.flush() sys.__stderr__.flush()
establish_controlling_tty( establish_controlling_tty(
self.tty_name, self.fds[0],
sys.__stdin__.fileno() if self.stdin == -1 else -1, sys.__stdin__.fileno() if self.stdin < 0 else -1,
sys.__stdout__.fileno() if self.stdout == -1 else -1, sys.__stdout__.fileno() if self.stdout < 0 else -1,
sys.__stderr__.fileno() if self.stderr == -1 else -1) sys.__stderr__.fileno() if self.stderr < 0 else -1)
# the std streams fds are in all_non_child_fds already # the std streams fds are in all_non_child_fds already
# so they will be closed there # so they will be closed there
if self.stdin > -1: if self.stdin > -1:

View File

@ -37,3 +37,10 @@ safe_close(int fd, const char* file UNUSED, const int line UNUSED) {
#endif #endif
while(close(fd) != 0 && errno == EINTR); while(close(fd) != 0 && errno == EINTR);
} }
static inline int
safe_dup2(int a, int b) {
int ret;
while((ret = dup2(a, b)) < 0 && errno == EINTR);
return ret;
}

View File

@ -102,6 +102,13 @@ safe_close(int fd) {
while(close(fd) != 0 && errno == EINTR); while(close(fd) != 0 && errno == EINTR);
} }
static inline int
safe_dup2(int a, int b) {
int ret;
while((ret = dup2(a, b)) < 0 && errno == EINTR);
return ret;
}
static inline bool static inline bool
safe_tcsetattr(int fd, int actions, const struct termios *tp) { safe_tcsetattr(int fd, int actions, const struct termios *tp) {
int ret = 0; int ret = 0;
@ -166,7 +173,6 @@ is_prewarmable(int argc, char *argv[]) {
} }
static int child_master_fd = -1, child_slave_fd = -1; static int child_master_fd = -1, child_slave_fd = -1;
static char child_tty_name[PATH_MAX] = {0};
static struct winsize self_winsize = {0}; static struct winsize self_winsize = {0};
static struct termios self_termios = {0}, restore_termios = {0}; static struct termios self_termios = {0}, restore_termios = {0};
static bool termios_needs_restore = false; static bool termios_needs_restore = false;
@ -213,7 +219,7 @@ get_termios_state(void) {
static bool static bool
open_pty(void) { open_pty(void) {
while (openpty(&child_master_fd, &child_slave_fd, child_tty_name, &self_termios, &self_winsize) == -1) { while (openpty(&child_master_fd, &child_slave_fd, NULL, &self_termios, &self_winsize) == -1) {
if (errno != EINTR) return false; if (errno != EINTR) return false;
} }
set_blocking(child_master_fd, false); set_blocking(child_master_fd, false);
@ -255,7 +261,7 @@ setup_signal_handler(void) {
static void static void
setup_stdio_handles(void) { setup_stdio_handles(void) {
int pos = 0; int pos = 1;
if (!isatty(STDIN_FILENO)) stdin_pos = pos++; if (!isatty(STDIN_FILENO)) stdin_pos = pos++;
if (!isatty(STDOUT_FILENO)) stdout_pos = pos++; if (!isatty(STDOUT_FILENO)) stdout_pos = pos++;
if (!isatty(STDERR_FILENO)) stderr_pos = pos++; if (!isatty(STDERR_FILENO)) stderr_pos = pos++;
@ -290,27 +296,22 @@ static bool
create_launch_msg(int argc, char *argv[]) { create_launch_msg(int argc, char *argv[]) {
#define w(prefix, data) { if (!write_item_to_launch_msg(prefix, data)) return false; } #define w(prefix, data) { if (!write_item_to_launch_msg(prefix, data)) return false; }
static char buf[4*PATH_MAX]; static char buf[4*PATH_MAX];
w("tty_name", child_tty_name);
if (getcwd(buf, sizeof(buf))) { w("cwd", buf); } if (getcwd(buf, sizeof(buf))) { w("cwd", buf); }
for (int i = 0; i < argc; i++) w("argv", argv[i]); for (int i = 0; i < argc; i++) w("argv", argv[i]);
char **s = environ; char **s = environ;
for (; *s; s++) w("env", *s); for (; *s; s++) w("env", *s);
int num_fds = 0, fds[8]; int num_fds = 0, fds[4];
fds[num_fds++] = child_slave_fd;
#define sio(which, x) if (which##_pos > -1) { snprintf(buf, sizeof(buf), "%d", which##_pos); w(#which, buf); fds[num_fds++] = x; } #define sio(which, x) if (which##_pos > -1) { snprintf(buf, sizeof(buf), "%d", which##_pos); w(#which, buf); fds[num_fds++] = x; }
sio(stdin, STDIN_FILENO); sio(stdout, STDOUT_FILENO); sio(stderr, STDERR_FILENO); sio(stdin, STDIN_FILENO); sio(stdout, STDOUT_FILENO); sio(stderr, STDERR_FILENO);
#undef sio #undef sio
w("finish", ""); w("finish", "");
if (num_fds) {
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&launch_msg_container); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&launch_msg_container);
cmsg->cmsg_len = CMSG_LEN(sizeof(fds[0]) * num_fds); cmsg->cmsg_len = CMSG_LEN(sizeof(fds[0]) * num_fds);
memcpy(CMSG_DATA(cmsg), fds, num_fds * sizeof(fds[0])); memcpy(CMSG_DATA(cmsg), fds, num_fds * sizeof(fds[0]));
launch_msg_container.msg_controllen = cmsg->cmsg_len; launch_msg_container.msg_controllen = cmsg->cmsg_len;
cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS; cmsg->cmsg_type = SCM_RIGHTS;
} else {
launch_msg_container.msg_controllen = 0;
launch_msg_container.msg_control = 0;
}
return true; return true;
#undef w #undef w
} }
@ -351,6 +352,16 @@ read_child_data(void) {
return true; return true;
} }
static void
close_sent_fds(void) {
if (child_slave_fd > -1) { safe_close(child_slave_fd); child_slave_fd = -1; }
#define redirect(which, mode) { int fd = safe_open("/dev/null", mode | O_CLOEXEC, 0); if (fd > -1) { safe_dup2(fd, which); safe_close(fd); } }
if (stdin_pos > -1) redirect(STDIN_FILENO, O_RDONLY);
if (stdout_pos > -1) redirect(STDOUT_FILENO, O_WRONLY);
if (stderr_pos > -1) redirect(STDERR_FILENO, O_WRONLY);
#undef redirect
}
static bool static bool
send_launch_msg(void) { send_launch_msg(void) {
ssize_t n; ssize_t n;
@ -362,7 +373,10 @@ send_launch_msg(void) {
// some bytes sent, null out the control msg data as it is already sent // some bytes sent, null out the control msg data as it is already sent
launch_msg_container.msg_controllen = 0; launch_msg_container.msg_controllen = 0;
launch_msg_container.msg_control = NULL; launch_msg_container.msg_control = NULL;
if ((size_t)n > launch_msg.iov_len) launch_msg.iov_len = 0; if ((size_t)n > launch_msg.iov_len) {
launch_msg.iov_len = 0;
close_sent_fds();
}
else launch_msg.iov_len -= n; else launch_msg.iov_len -= n;
launch_msg.iov_base = (char*)launch_msg.iov_base + n; launch_msg.iov_base = (char*)launch_msg.iov_base + n;
return true; return true;