diff --git a/prewarm-launcher.h b/prewarm-launcher.h index 72120d2a7..9cdf10e0d 100644 --- a/prewarm-launcher.h +++ b/prewarm-launcher.h @@ -218,9 +218,9 @@ setup_signal_handler(void) { static void setup_stdio_handles(void) { int pos = 0; - if (isatty(STDIN_FILENO)) stdin_pos = pos++; - if (isatty(STDOUT_FILENO)) stdout_pos = pos++; - if (isatty(STDERR_FILENO)) stderr_pos = pos++; + if (!isatty(STDIN_FILENO)) stdin_pos = pos++; + if (!isatty(STDOUT_FILENO)) stdout_pos = pos++; + if (!isatty(STDERR_FILENO)) stderr_pos = pos++; } static bool @@ -257,15 +257,17 @@ create_launch_msg(int argc, char *argv[]) { for (int i = 0; i < argc; i++) w("argv", argv[i]); char **s = environ; for (; *s; s++) w("env", *s); + int num_fds = 0, fds[8]; +#define sio(which, x) if (which##_pos > -1) { snprintf(buf, sizeof(buf), "%d", which##_pos); w(#which, buf); fds[num_fds++] = x; } + sio(stdin, STDIN_FILENO); sio(stdout, STDOUT_FILENO); sio(stderr, STDERR_FILENO); +#undef sio + w("finish", ""); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&launch_msg_container); cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = 0; -#define sio(which, x) if (which##_pos > -1) { snprintf(buf, sizeof(buf), "%d", which##_pos); w(#which, buf); int fd = x; memcpy(CMSG_DATA(cmsg) + cmsg->cmsg_len, &fd, sizeof(fd)); cmsg->cmsg_len += sizeof(fd); } - sio(stdin, STDIN_FILENO); sio(stdout, STDOUT_FILENO); sio(stderr, STDERR_FILENO); -#undef sio - launch_msg_container.msg_controllen = CMSG_SPACE(cmsg->cmsg_len); - w("finish", ""); + cmsg->cmsg_len = CMSG_LEN(sizeof(fds[0]) * num_fds); + memcpy(CMSG_DATA(cmsg), fds, num_fds * sizeof(fds[0])); + launch_msg_container.msg_controllen = cmsg->cmsg_len; return true; #undef w } @@ -275,7 +277,7 @@ static char from_child_buf[64] = {0}; static size_t from_child_buf_pos = 0; static bool -read_child_data(int socket_fd) { +read_child_data(void) { ssize_t n; if (from_child_buf_pos >= sizeof(from_child_buf) - 2) { print_error("Too much data from prewarm socket", 0); return false; } while ((n = read(socket_fd, from_child_buf, sizeof(from_child_buf) - 2 - from_child_buf_pos)) < 0 && errno == EINTR); @@ -315,7 +317,7 @@ send_launch_msg(void) { static void -loop(int socket_fd) { +loop(void) { #define fail(s) { print_error(s, errno); return; } #define check_fd(which, name) { if (poll_data[which].revents & POLLERR) { pe("File descriptor %s failed", #name); return; } if (poll_data[which].revents & POLLHUP); { pe("File descriptor %s hungup", #name); return; } } struct pollfd poll_data[4]; @@ -330,14 +332,16 @@ loop(int socket_fd) { for (size_t i = 0; i < arraysz(poll_data); i++) poll_data[i].revents = 0; while (poll(poll_data, arraysz(poll_data), -1) == -1) { if (errno != EINTR) fail("poll() failed"); } - check_fd(0, self_ttyfd); check_fd(1, child_master_fd); check_fd(3, signal_read_fd); + check_fd(0, self_ttyfd); + check_fd(1, child_master_fd); + check_fd(3, signal_read_fd); // socket_fd if (poll_data[2].revents & POLLERR) { print_error("File descriptor socket_fd failed", 0); return; } if (poll_data[2].revents & POLLIN) { - if (!read_child_data(socket_fd)) fail("reading information about child failed"); + if (!read_child_data()) fail("reading information about child failed"); } if (poll_data[2].revents & POLLHUP) { if (from_child_buf[0]) { char *p = memchr(from_child_buf, ':', sizeof(from_child_buf)); if (p) parse_int(p+1, &exit_status); } @@ -373,7 +377,7 @@ use_prewarmed_process(int argc, char *argv[]) { if (socket_fd < 0) fail("Failed to connect to prewarm socket"); #undef fail - loop(socket_fd); + loop(); cleanup(); exit(exit_status); }