detect when socket child closes its tty

This commit is contained in:
Kovid Goyal 2022-07-03 13:20:53 +05:30
parent 16e59784c6
commit 73795b5257
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 41 additions and 22 deletions

View File

@ -110,7 +110,7 @@ class PrewarmProcess:
del self.from_worker
if self.worker_pid > 0:
if wait_for_child_death(self.worker_pid) is None:
log_error('Prewarm process failed to quite gracefully, killing it')
log_error('Prewarm process failed to quit gracefully, killing it')
os.kill(self.worker_pid, signal.SIGKILL)
os.waitpid(self.worker_pid, 0)
@ -408,8 +408,10 @@ class SocketChild:
return False
def fork(self, all_non_child_fds: Iterable[int]) -> None:
r, w = safe_pipe()
self.pid = os.fork()
if self.pid > 0:
os.close(w)
# master process
if self.stdin > -1:
os.close(self.stdin)
@ -420,9 +422,14 @@ class SocketChild:
if self.stderr > -1:
os.close(self.stderr)
self.stderr = -1
poll = select.poll()
poll.register(r, select.POLLIN)
tuple(poll.poll())
os.close(r)
self.handle_creation()
return
# child process
os.close(r)
os.setsid()
remove_signal_handlers()
if self.tty_name:
@ -441,6 +448,7 @@ class SocketChild:
os.dup2(self.stdout, sys.__stdout__.fileno())
if self.stderr > -1:
os.dup2(self.stderr, sys.__stderr__.fileno())
os.close(w)
for fd in all_non_child_fds:
if fd > -1:
os.close(fd)
@ -698,6 +706,7 @@ def exec_main(stdin_read: int, stdout_write: int, death_notify_write: int, unix_
main(stdin_read, stdout_write, death_notify_write, unix_socket)
finally:
set_options(None)
unix_socket.close()
def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[PrewarmProcess]:
@ -721,7 +730,8 @@ def fork_prewarm_process(opts: Options, use_exec: bool = False) -> Optional[Prew
child_pid = os.fork()
if child_pid:
# master
unix_socket.close()
if not use_exec:
unix_socket.close()
os.close(stdin_read)
os.close(stdout_write)
os.close(death_notify_write)

View File

@ -319,6 +319,7 @@ read_child_data(void) {
if (!parse_long(from_child_buf, &cp)) { print_error("Could not parse child pid from prewarm socket", 0); return false; }
if (cp == 0) { print_error("Got zero child pid from prewarm socket", 0); return false; }
child_pid = cp;
if (child_slave_fd > -1) { safe_close(child_slave_fd); child_slave_fd = -1; }
memset(from_child_buf, 0, (p - from_child_buf) + 1);
from_child_buf_pos -= (p - from_child_buf) + 1;
if (from_child_buf_pos) memmove(from_child_buf, p + 1, from_child_buf_pos);
@ -397,48 +398,56 @@ read_or_transfer_from_self_tty(void) {
static void
loop(void) {
#define fail(s) { print_error(s, errno); return; }
#define check_fd(which, name) { if (poll_data[which].revents & POLLERR) { pe("File descriptor %s failed", #name); return; } if (poll_data[which].revents & POLLHUP) { pe("File descriptor %s hungup", #name); return; } }
#define pd(which) poll_data[which##_idx]
#define check_fd(name) { if (pd(name).revents & POLLERR) { pe("File descriptor %s failed", #name); return; } if (pd(name).revents & POLLHUP) { pe("File descriptor %s hungup", #name); return; } }
struct pollfd poll_data[4];
poll_data[0].fd = self_ttyfd; poll_data[0].events = POLLIN;
poll_data[1].fd = child_master_fd; poll_data[1].events = POLLIN;
poll_data[2].fd = socket_fd; poll_data[2].events = POLLIN;
poll_data[3].fd = signal_read_fd; poll_data[3].events = POLLIN;
enum { self_ttyfd_idx, socket_fd_idx, signal_read_fd_idx, child_master_fd_idx };
#define init(name) pd(name).fd = name; pd(name).events = POLLIN;
init(self_ttyfd); init(socket_fd); init(signal_read_fd); init(child_master_fd);
#undef init
size_t num_to_poll = arraysz(poll_data);
while (true) {
int ret;
// self_ttyfd
poll_data[0].events = (to_child_tty.sz < IO_BUZ_SZ - 1 ? POLLIN : 0) | (from_child_tty.sz ? POLLOUT : 0);
pd(self_ttyfd).events = (to_child_tty.sz < IO_BUZ_SZ - 1 ? POLLIN : 0) | (from_child_tty.sz ? POLLOUT : 0);
pd(socket_fd).events = POLLIN | (launch_msg.iov_len ? POLLOUT : 0);
// child_master_fd
poll_data[1].events = (from_child_tty.sz < IO_BUZ_SZ - 1 ? POLLIN : 0) | (to_child_tty.sz ? POLLOUT : 0);
// socket_fd
poll_data[2].events = POLLIN | (launch_msg.iov_len ? POLLOUT : 0);
pd(child_master_fd).events = (from_child_tty.sz < IO_BUZ_SZ - 1 ? POLLIN : 0) | (to_child_tty.sz ? POLLOUT : 0);
for (size_t i = 0; i < arraysz(poll_data); i++) poll_data[i].revents = 0;
while ((ret = poll(poll_data, arraysz(poll_data), -1)) == -1) { if (errno != EINTR) fail("poll() failed"); }
while ((ret = poll(poll_data, num_to_poll, -1)) == -1) { if (errno != EINTR) fail("poll() failed"); }
if (!ret) continue;
if (poll_data[1].revents & POLLIN) if (!read_or_transfer_from_child_tty()) fail("reading from child tty failed");
if (poll_data[0].revents & POLLOUT) if (!from_child_to_self()) fail("writing to self tty failed");
if (poll_data[0].revents & POLLIN) if (!read_or_transfer_from_self_tty()) fail("reading from self tty failed");
if (poll_data[1].revents & POLLOUT) if (!from_self_to_child()) fail("writing to child tty failed");
if (pd(child_master_fd).revents & POLLIN) if (!read_or_transfer_from_child_tty()) fail("reading from child tty failed");
if (pd(self_ttyfd).revents & POLLOUT) if (!from_child_to_self()) fail("writing to self tty failed");
if (pd(self_ttyfd).revents & POLLIN) if (!read_or_transfer_from_self_tty()) fail("reading from self tty failed");
if (pd(child_master_fd).revents & POLLOUT) if (!from_self_to_child()) fail("writing to child tty failed");
if (pd(child_master_fd).revents & POLLHUP) {
// child has closed its tty, wait for exit code from prewarm zygote
safe_close(child_master_fd); child_master_fd = -1;
num_to_poll--;
}
check_fd(0, self_ttyfd); check_fd(1, child_master_fd); check_fd(3, signal_read_fd);
check_fd(self_ttyfd);
if (child_master_fd > -1) check_fd(child_master_fd);
check_fd(signal_read_fd);
// socket_fd
if (poll_data[2].revents & POLLERR) {
if (pd(socket_fd).revents & POLLERR) {
print_error("File descriptor socket_fd failed", 0); return;
}
if (poll_data[2].revents & POLLIN) {
if (pd(socket_fd).revents & POLLIN) {
if (!read_child_data()) fail("reading information about child failed");
}
if (poll_data[2].revents & POLLHUP) {
if (pd(socket_fd).revents & POLLHUP) {
if (from_child_buf[0]) { parse_int(from_child_buf, &exit_status); }
return;
}
if (poll_data[2].revents & POLLOUT) {
if (pd(socket_fd).revents & POLLOUT) {
if (!send_launch_msg()) fail("sending launch message failed");
}
}
#undef pd
#undef check_fd
#undef fail
}