diff --git a/prewarm-launcher.c b/prewarm-launcher.c index 79d9d7e94..a3902dc85 100644 --- a/prewarm-launcher.c +++ b/prewarm-launcher.c @@ -47,6 +47,12 @@ // }}} #define IO_BUZ_SZ 8192 +#define remove_i_from_array(array, i, count) { \ + (count)--; \ + if ((i) < (count)) { \ + memmove((array) + (i), (array) + (i) + 1, sizeof((array)[0]) * ((count) - (i))); \ + }} + typedef struct transfer_buf { char *buf; @@ -454,24 +460,31 @@ read_signals(void) { return true; } +struct pollees { + struct pollfd poll_data[8]; + struct { + int self_ttyfd, signal_read_fd, socket_fd, child_master_fd; + } idx; + size_t num_fds; +}; +static struct pollees pollees = {0}; + +#define register_for_poll(which) pollees.poll_data[pollees.num_fds].fd = which; pollees.poll_data[pollees.num_fds].events = POLLIN; pollees.idx.which = pollees.num_fds++; +#define unregister_for_poll(which) if (pollees.idx.which > -1) { remove_i_from_array(pollees.poll_data, pollees.idx.which, pollees.num_fds); pollees.idx.which = -1; } +#define set_poll_events(which, val) if (pollees.idx.which > -1) { pollees.poll_data[pollees.idx.which].events = (val); } +#define poll_revents(which) ((pollees.idx.which > -1) ? pollees.poll_data[pollees.idx.which].revents : 0) + static void loop(void) { #define fail(s) { print_error(s, errno); return; } -#define pd(which) poll_data[which##_idx] -#define check_fd(name) { if (pd(name).revents & POLLERR) { pe("File descriptor %s failed", #name); return; } if (pd(name).revents & POLLHUP) { pe("File descriptor %s hungup", #name); return; } } - struct pollfd poll_data[4]; - enum { self_ttyfd_idx, signal_read_fd_idx, socket_fd_idx, child_master_fd_idx }; -#define init(name) pd(name).fd = name; pd(name).events = POLLIN; - init(self_ttyfd); init(socket_fd); init(signal_read_fd); init(child_master_fd); -#undef init +#define check_fd(name) { if (poll_revents(name) & POLLERR) { pe("File descriptor %s failed", #name); return; } if (poll_revents(name) & POLLHUP) { pe("File descriptor %s hungup", #name); return; } } + register_for_poll(self_ttyfd); register_for_poll(socket_fd); register_for_poll(signal_read_fd); register_for_poll(child_master_fd); while (true) { int ret; - pd(self_ttyfd).events = (to_child_tty.sz < IO_BUZ_SZ ? POLLIN : 0) | (from_child_tty.sz ? POLLOUT : 0); - if (socket_fd > -1) pd(socket_fd).events = POLLIN | (launch_msg.iov_len ? POLLOUT : 0); - else pd(socket_fd).events = 0; - if (child_master_fd > -1) pd(child_master_fd).events = (from_child_tty.sz < IO_BUZ_SZ ? POLLIN : 0) | (to_child_tty.sz ? POLLOUT : 0); - else pd(child_master_fd).events = 0; + set_poll_events(self_ttyfd, (to_child_tty.sz < IO_BUZ_SZ ? POLLIN : 0) | (from_child_tty.sz ? POLLOUT : 0)); + set_poll_events(socket_fd, POLLIN | (launch_msg.iov_len ? POLLOUT : 0)); + set_poll_events(child_master_fd, (from_child_tty.sz < IO_BUZ_SZ ? POLLIN : 0) | (to_child_tty.sz ? POLLOUT : 0)); if (window_size_dirty && child_master_fd > -1 ) { if (!get_window_size()) fail("getting window size for self tty failed"); @@ -479,44 +492,43 @@ loop(void) { window_size_dirty = false; } - for (size_t i = 0; i < arraysz(poll_data); i++) poll_data[i].revents = 0; - while ((ret = poll(poll_data, arraysz(poll_data), -1)) == -1) { if (errno != EINTR) fail("poll() failed"); } + for (size_t i = 0; i < pollees.num_fds; i++) pollees.poll_data[i].revents = 0; + while ((ret = poll(pollees.poll_data, pollees.num_fds, -1)) == -1) { if (errno != EINTR) fail("poll() failed"); } if (!ret) continue; - if (pd(child_master_fd).revents & POLLIN) if (!read_or_transfer_from_child_tty()) fail("reading from child tty failed"); - if (pd(self_ttyfd).revents & POLLOUT) if (!from_child_to_self()) fail("writing to self tty failed"); - if (pd(self_ttyfd).revents & POLLIN) if (!read_or_transfer_from_self_tty()) fail("reading from self tty failed"); - if (pd(child_master_fd).revents & POLLOUT) if (!from_self_to_child()) fail("writing to child tty failed"); - if (pd(child_master_fd).revents & POLLHUP) { + if (poll_revents(child_master_fd) & POLLIN) if (!read_or_transfer_from_child_tty()) fail("reading from child tty failed"); + if (poll_revents(self_ttyfd) & POLLOUT) if (!from_child_to_self()) fail("writing to self tty failed"); + if (poll_revents(self_ttyfd) & POLLIN) if (!read_or_transfer_from_self_tty()) fail("reading from self tty failed"); + if (poll_revents(child_master_fd) & POLLOUT) if (!from_self_to_child()) fail("writing to child tty failed"); + if (poll_revents(child_master_fd) & POLLHUP) { // child has closed its tty, wait for exit code from prewarm zygote safe_close(child_master_fd); child_master_fd = -1; + unregister_for_poll(child_master_fd); if (!child_pid) return; } - check_fd(self_ttyfd); - if (child_master_fd > -1) check_fd(child_master_fd); - check_fd(signal_read_fd); + check_fd(self_ttyfd); check_fd(child_master_fd); check_fd(signal_read_fd); // signal_read_fd - if (pd(signal_read_fd).revents & POLLIN) if (!read_signals()) fail("reading from signal fd failed"); + if (poll_revents(signal_read_fd) & POLLIN) if (!read_signals()) fail("reading from signal fd failed"); // socket_fd - if (pd(socket_fd).revents & POLLERR) { + if (poll_revents(socket_fd) & POLLERR) { print_error("File descriptor socket_fd failed", 0); return; } - if (pd(socket_fd).revents & POLLIN) { + if (poll_revents(socket_fd) & POLLIN) { if (!read_child_data()) fail("reading information about child failed"); } - if (pd(socket_fd).revents & POLLHUP) { + if (poll_revents(socket_fd) & POLLHUP) { if (from_child_buf[0]) { parse_int(from_child_buf, &exit_status); } - child_pid = 0; + child_pid = 0; safe_close(socket_fd); socket_fd = -1; + unregister_for_poll(socket_fd); if (child_master_fd < 0) return; } - if (pd(socket_fd).revents & POLLOUT) { + if (poll_revents(socket_fd) & POLLOUT) { if (!send_launch_msg()) fail("sending launch message failed"); } } -#undef pd #undef check_fd #undef fail }