diff --git a/prewarm-launcher.c b/prewarm-launcher.c index 9aa41fb23..ebbcaa02a 100644 --- a/prewarm-launcher.c +++ b/prewarm-launcher.c @@ -159,7 +159,6 @@ connect_to_socket_synchronously(const char *addr) { if (getsockopt (fd, SOL_SOCKET, SO_ERROR, &socket_error_code, &sizeof_socket_error_code) == -1) return -1; if (socket_error_code != 0) return -1; } - if (fd > -1) set_blocking(fd, false); return fd; } @@ -222,7 +221,6 @@ open_pty(void) { while (openpty(&child_master_fd, &child_slave_fd, NULL, &self_termios, &self_winsize) == -1) { if (errno != EINTR) return false; } - set_blocking(child_master_fd, false); return true; } @@ -234,12 +232,9 @@ handle_signal(int sig_num, siginfo_t *si, void *ucontext) { size_t sz = sizeof(siginfo_t); while (signal_write_fd != -1 && sz) { // as long as sz is less than PIPE_BUF write will either write all or return -1 with EAGAIN - // so we are guaranteed atomic writes - ssize_t ret = write(signal_write_fd, buf, sz); - if (ret <= 0) { - if (errno == EINTR) continue; - break; - } + // so we are guaranteed atomic writes, barring implementation bugs + ssize_t ret = safe_write(signal_write_fd, buf, sz); + if (ret <= 0) break; sz -= ret; buf += ret; } @@ -251,7 +246,7 @@ setup_signal_handler(void) { int fds[2]; if (pipe(fds) != 0) return false; signal_read_fd = fds[0]; signal_write_fd = fds[1]; - set_blocking(signal_read_fd, false); set_blocking(signal_write_fd, false); + set_blocking(signal_write_fd, false); struct sigaction act = {.sa_sigaction=handle_signal, .sa_flags=SA_SIGINFO | SA_RESTART}; #define a(which) if (sigaction(which, &act, NULL) != 0) return false; a(SIGWINCH); a(SIGINT); a(SIGTERM); a(SIGQUIT); a(SIGHUP); @@ -320,14 +315,16 @@ static int exit_status = EXIT_FAILURE; static char from_child_buf[64] = {0}; static size_t from_child_buf_pos = 0; static int pending_signals[32] = {0}; +enum ChildState { CHILD_NOT_STARTED, CHILD_STARTED, CHILD_EXITED }; +static enum ChildState child_state = CHILD_NOT_STARTED; static bool read_child_data(void) { ssize_t n; if (from_child_buf_pos >= sizeof(from_child_buf) - 2) { print_error("Too much data from prewarm socket", 0); return false; } - while ((n = read(socket_fd, from_child_buf, sizeof(from_child_buf) - 2 - from_child_buf_pos)) < 0 && errno == EINTR); + n = safe_read(socket_fd, from_child_buf, sizeof(from_child_buf) - 2 - from_child_buf_pos); if (n < 0) { - print_error("Failed to read from prewarm socket", errno); + if (errno == EIO || errno == EPIPE) { socket_fd = -1; return true; } return false; } if (n) { @@ -339,6 +336,7 @@ read_child_data(void) { if (!parse_long(from_child_buf, &cp)) { print_error("Could not parse child pid from prewarm socket", 0); return false; } if (cp == 0) { print_error("Got zero child pid from prewarm socket", 0); return false; } child_pid = cp; + child_state = CHILD_STARTED; if (child_slave_fd > -1) { safe_close(child_slave_fd); child_slave_fd = -1; } memset(from_child_buf, 0, (p - from_child_buf) + 1); from_child_buf_pos -= (p - from_child_buf) + 1; @@ -348,7 +346,7 @@ read_child_data(void) { } memset(pending_signals, 0, sizeof(pending_signals)); } - } + } else { socket_fd = -1; return true; } return true; } @@ -366,10 +364,8 @@ static bool send_launch_msg(void) { ssize_t n; while ((n = sendmsg(socket_fd, &launch_msg_container, MSG_NOSIGNAL)) < 0 && errno == EINTR); - if (n < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) return true; - return false; - } + if (n < 0) return false; + if (n == 0) { errno = EPIPE; return false; } // some bytes sent, null out the control msg data as it is already sent launch_msg_container.msg_controllen = 0; launch_msg_container.msg_control = NULL; @@ -382,58 +378,63 @@ send_launch_msg(void) { return true; } +struct fd_to_watch { + bool want_read, want_write, want_error; +}; + +struct watched_fds { + struct fd_to_watch self_ttyfd, signal_read_fd, socket_fd, child_master_fd; +}; +static struct watched_fds wf = {0}; + static bool -read_or_transfer(int src_fd, int dest_fd, transfer_buf *t) { - (void)dest_fd; - while(t->sz < IO_BUZ_SZ) { - ssize_t n = safe_read(src_fd, t->buf + t->sz, IO_BUZ_SZ - t->sz); +read_from_tty(int *fd, transfer_buf *t) { + if (*fd < 0) return true; + if (t->sz < IO_BUZ_SZ) { + ssize_t n = safe_read(*fd, t->buf + t->sz, IO_BUZ_SZ - t->sz); if (n < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + if (errno == EPIPE || errno == EIO) { *fd = -1; return true; } return false; } - if (!n) break; + if (n == 0) *fd = -1; // hangup t->sz += n; } return true; } static bool -read_or_transfer_from_child_tty(void) { - if (child_master_fd < 0) return true; - return read_or_transfer(child_master_fd, self_ttyfd, &from_child_tty); +read_from_child_tty(void) { + return read_from_tty(&child_master_fd, &from_child_tty); } static bool -write_from_to(transfer_buf *src, int dest_fd) { - while (src->sz) { - ssize_t n = safe_write(dest_fd, src->buf, src->sz); - if (n < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) return true; - return false; - } +write_to_tty(transfer_buf *src, int *dest_fd) { + if (*dest_fd < 0) return true; + if (src->sz) { + ssize_t n = safe_write(*dest_fd, src->buf, src->sz); + if (n < 0) return false; if (n > 0) { src->sz -= n; memmove(src->buf, src->buf + n, src->sz); - } else break; + } else *dest_fd = -1; } return true; } static bool from_child_to_self(void) { - return write_from_to(&from_child_tty, self_ttyfd); + return write_to_tty(&from_child_tty, &self_ttyfd); } static bool from_self_to_child(void) { - if (child_master_fd < 0) return true; - return write_from_to(&to_child_tty, child_master_fd); + return write_to_tty(&to_child_tty, &child_master_fd); } static bool -read_or_transfer_from_self_tty(void) { - return read_or_transfer(self_ttyfd, child_master_fd, &to_child_tty); +read_from_self_tty(void) { + return read_from_tty(&self_ttyfd, &to_child_tty); } static bool window_size_dirty = false; @@ -442,113 +443,110 @@ static bool read_signals(void) { static char buf[sizeof(siginfo_t) * 8]; static size_t buf_pos = 0; - while(true) { - ssize_t len = safe_read(signal_read_fd, buf + buf_pos, sizeof(buf) - buf_pos); - if (len < 0) { - if (errno == EWOULDBLOCK || errno == EAGAIN) return true; - return false; - } - buf_pos += len; - while (buf_pos >= sizeof(siginfo_t)) { - siginfo_t *sig = (siginfo_t*)buf; - switch(sig->si_signo) { - case SIGWINCH: - window_size_dirty = true; break; - case SIGINT: case SIGTERM: case SIGHUP: case SIGQUIT: - if (child_pid > 0) kill(child_pid, sig->si_signo); - else { - for (size_t i = 0; i < arraysz(pending_signals); i++) { - if (!pending_signals[i]) { - pending_signals[i] = sig->si_signo; - break; - } + ssize_t len = safe_read(signal_read_fd, buf + buf_pos, sizeof(buf) - buf_pos); + if (len < 0) return false; + if (len == 0) return true; + buf_pos = len; + while (buf_pos >= sizeof(siginfo_t)) { + siginfo_t *sig = (siginfo_t*)buf; + switch(sig->si_signo) { + case SIGWINCH: + window_size_dirty = true; break; + case SIGINT: case SIGTERM: case SIGHUP: case SIGQUIT: + if (child_pid > 0) kill(child_pid, sig->si_signo); + else { + for (size_t i = 0; i < arraysz(pending_signals); i++) { + if (!pending_signals[i]) { + pending_signals[i] = sig->si_signo; + break; } } - break; - } - memmove(buf, buf + sizeof(siginfo_t), sizeof(siginfo_t)); - buf_pos -= sizeof(siginfo_t); + } + break; } - if (len == 0) break; + memmove(buf, buf + sizeof(siginfo_t), sizeof(siginfo_t)); + buf_pos -= sizeof(siginfo_t); } return true; } -struct pollees { - struct pollfd poll_data[8]; - struct { - int self_ttyfd, signal_read_fd, socket_fd, child_master_fd; - } idx; - size_t num_fds; -}; -static struct pollees pollees = {0}; +static bool +keep_going(void) { + switch(child_state) { + case CHILD_NOT_STARTED: + return self_ttyfd > -1 && signal_read_fd > -1 && socket_fd > -1 && child_master_fd > -1; + case CHILD_STARTED: + return self_ttyfd > -1 && signal_read_fd > -1 && socket_fd > -1; + case CHILD_EXITED: + return self_ttyfd > -1 && signal_read_fd > -1 && child_master_fd > -1; + } + return false; +} -#define register_for_poll(which) pollees.poll_data[pollees.num_fds].fd = which; pollees.poll_data[pollees.num_fds].events = POLLIN; pollees.idx.which = pollees.num_fds++; -#define unregister_for_poll(which) if (pollees.idx.which > -1) { remove_i_from_array(pollees.poll_data, pollees.idx.which, pollees.num_fds); \ - if (pollees.idx.self_ttyfd > pollees.idx.which) pollees.idx.self_ttyfd--; \ - if (pollees.idx.signal_read_fd > pollees.idx.which) pollees.idx.signal_read_fd--; \ - if (pollees.idx.socket_fd > pollees.idx.which) pollees.idx.socket_fd--; \ - if (pollees.idx.child_master_fd > pollees.idx.which) pollees.idx.child_master_fd--; \ - pollees.idx.which = -1; } -#define set_poll_events(which, val) if (pollees.idx.which > -1) { pollees.poll_data[pollees.idx.which].events = (val); } -#define poll_revents(which) ((pollees.idx.which > -1) ? pollees.poll_data[pollees.idx.which].revents : 0) +static void +flush_data(void) { + if (child_master_fd > -1 && from_child_tty.sz < IO_BUZ_SZ) { + set_blocking(child_master_fd, false); + read_from_child_tty(); + } + if (self_ttyfd > -1 && from_child_tty.sz > 0) { + set_blocking(self_ttyfd, false); + from_child_to_self(); + } +} static void loop(void) { #define fail(s) { print_error(s, errno); return; } -#define check_fd(name) { if (poll_revents(name) & POLLERR) { pe("File descriptor %s failed", #name); return; } if (poll_revents(name) & POLLHUP) { pe("File descriptor %s hungup", #name); return; } } - register_for_poll(self_ttyfd); register_for_poll(signal_read_fd); register_for_poll(socket_fd); register_for_poll(child_master_fd); + int ret, nfds = 0; +#define init(which) wf.which.want_read = true; nfds = MAX(which, nfds); + init(self_ttyfd); init(signal_read_fd); init(socket_fd); init(child_master_fd); +#undef init + fd_set readable, writable, errorable; + nfds++; - while (true) { - int ret; - set_poll_events(self_ttyfd, (to_child_tty.sz < IO_BUZ_SZ ? POLLIN : 0) | (from_child_tty.sz ? POLLOUT : 0)); - set_poll_events(socket_fd, POLLIN | (launch_msg.iov_len ? POLLOUT : 0)); - set_poll_events(child_master_fd, (from_child_tty.sz < IO_BUZ_SZ ? POLLIN : 0) | (to_child_tty.sz ? POLLOUT : 0)); + while (keep_going()) { + wf.self_ttyfd.want_read = to_child_tty.sz < IO_BUZ_SZ; wf.self_ttyfd.want_write = from_child_tty.sz > 0; + wf.child_master_fd.want_read = from_child_tty.sz < IO_BUZ_SZ; wf.child_master_fd.want_write = to_child_tty.sz > 0; + wf.socket_fd.want_write = launch_msg.iov_len > 0; if (window_size_dirty && child_master_fd > -1 ) { if (!get_window_size()) fail("getting window size for self tty failed"); if (!safe_winsz(child_master_fd, TIOCSWINSZ, &self_winsize)) fail("setting window size on child pty failed"); window_size_dirty = false; } - - for (size_t i = 0; i < pollees.num_fds; i++) pollees.poll_data[i].revents = 0; - while ((ret = poll(pollees.poll_data, pollees.num_fds, -1)) == -1) { if (errno != EINTR) fail("poll() failed"); } + FD_ZERO(&readable); FD_ZERO(&writable); FD_ZERO(&errorable); +#define set(which) if (which > -1) { if (wf.which.want_read) { FD_SET(which, &readable); } if (wf.which.want_write) { FD_SET(which, &writable); } if (wf.which.want_error) { FD_SET(which, &errorable); } } + set(self_ttyfd); set(child_master_fd); set(socket_fd); set(signal_read_fd); +#undef set + while ((ret = select(nfds, &readable, &writable, &errorable, NULL)) == -1) { if (errno != EINTR) fail("select() failed"); } if (!ret) continue; - if (poll_revents(child_master_fd) & POLLIN) if (!read_or_transfer_from_child_tty()) fail("reading from child tty failed"); - if (poll_revents(self_ttyfd) & POLLOUT) if (!from_child_to_self()) fail("writing to self tty failed"); - if (poll_revents(self_ttyfd) & POLLIN) if (!read_or_transfer_from_self_tty()) fail("reading from self tty failed"); - if (poll_revents(child_master_fd) & POLLOUT) if (!from_self_to_child()) fail("writing to child tty failed"); - if (poll_revents(child_master_fd) & POLLHUP) { - // child has closed its tty, wait for exit code from prewarm zygote - safe_close(child_master_fd); child_master_fd = -1; - unregister_for_poll(child_master_fd); - if (!child_pid) return; + if (child_master_fd > -1) { + if (FD_ISSET(child_master_fd, &writable)) if (!from_self_to_child()) fail("writing to child tty failed"); + if (FD_ISSET(child_master_fd, &readable)) { + if (!read_from_child_tty()) fail("reading from child tty failed"); + } + } + if (self_ttyfd > -1) { + if (FD_ISSET(self_ttyfd, &readable)) if (!read_from_self_tty()) fail("reading from self tty failed"); + if (FD_ISSET(self_ttyfd, &writable)) if (!from_child_to_self()) fail("writing to self tty failed"); } - check_fd(self_ttyfd); check_fd(child_master_fd); check_fd(signal_read_fd); + if (signal_read_fd > -1 && FD_ISSET(signal_read_fd, &readable)) if (!read_signals()) fail("reading from signal fd failed"); - // signal_read_fd - if (poll_revents(signal_read_fd) & POLLIN) if (!read_signals()) fail("reading from signal fd failed"); - - // socket_fd - if (poll_revents(socket_fd) & POLLERR) { - print_error("File descriptor socket_fd failed", 0); return; - } - if (poll_revents(socket_fd) & POLLIN) { - if (!read_child_data()) fail("reading information about child failed"); - } - if (poll_revents(socket_fd) & POLLHUP) { - if (from_child_buf[0]) { parse_int(from_child_buf, &exit_status); } - child_pid = 0; safe_close(socket_fd); socket_fd = -1; - unregister_for_poll(socket_fd); - if (child_master_fd < 0) return; - } - if (poll_revents(socket_fd) & POLLOUT) { - if (!send_launch_msg()) fail("sending launch message failed"); + if (socket_fd > -1) { + if (FD_ISSET(socket_fd, &writable)) if (!send_launch_msg()) fail("sending launch message failed"); + if (FD_ISSET(socket_fd, &readable)) { + if (!read_child_data()) fail("reading information about child failed"); + if (socket_fd < 0) { // hangup + if (from_child_buf[0]) { parse_int(from_child_buf, &exit_status); } + child_pid = 0; + child_state = CHILD_EXITED; + } + } } } -#undef check_fd #undef fail } @@ -578,9 +576,10 @@ use_prewarmed_process(int argc, char *argv[]) { env_addr = check_socket_addr(env_addr); if (!env_addr) return; self_ttyfd = safe_open(ctermid(NULL), O_RDWR | O_NONBLOCK, 0); + if (self_ttyfd < 0) return; + setup_stdio_handles(); #define fail(s) { print_error(s, errno); cleanup(); return; } if (!setup_signal_handler()) fail("Failed to setup signal handling"); - if (self_ttyfd == -1) fail("Failed to open controlling terminal"); if (!get_window_size()) fail("Failed to get window size of controlling terminal"); if (!get_termios_state()) fail("Failed to get termios state of controlling terminal"); if (!open_pty()) fail("Failed to open slave pty"); @@ -589,7 +588,6 @@ use_prewarmed_process(int argc, char *argv[]) { cfmakeraw(&self_termios); if (!safe_tcsetattr(self_ttyfd, TCSANOW, &self_termios)) fail("Failed to put tty into raw mode"); while (tcsetattr(self_ttyfd, TCSANOW, &self_termios) == -1 && errno == EINTR) {} - setup_stdio_handles(); if (!create_launch_msg(argc, argv)) fail("Failed to open controlling terminal"); socket_fd = connect_to_socket_synchronously(env_addr); if (socket_fd < 0) fail("Failed to connect to prewarm socket"); @@ -599,8 +597,7 @@ use_prewarmed_process(int argc, char *argv[]) { #undef fail loop(); - read_or_transfer_from_child_tty(); - from_child_to_self(); + flush_data(); cleanup(); exit(exit_status); }