Use a process supervisor for socket workers

This simplifies the code and also allows SIGTSTP to work as the worker
process is no longer in an orphaned process group.
This commit is contained in:
Kovid Goyal 2022-07-13 09:36:12 +05:30
parent dda28efd66
commit 9c30cd8891
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 182 additions and 224 deletions

View File

@ -194,10 +194,9 @@ establish_controlling_tty(PyObject *self UNUSED, PyObject *args) {
if (stdin_fd > -1 && safe_dup2(tfd, stdin_fd) == -1) fail();
if (stdout_fd > -1 && safe_dup2(tfd, stdout_fd) == -1) fail();
if (stderr_fd > -1 && safe_dup2(tfd, stderr_fd) == -1) fail();
cleanup();
#undef cleanup
#undef fail
Py_RETURN_NONE;
return PyLong_FromLong(tfd);
}
static PyMethodDef module_methods[] = {

View File

@ -1395,7 +1395,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None:
pass
def establish_controlling_tty(tty_name: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> None:
def establish_controlling_tty(tty_name: str, stdin: int = -1, stdout: int = -1, stderr: int = -1) -> int:
pass

View File

@ -47,6 +47,9 @@ def restore_python_signal_handlers() -> None:
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
signal.signal(signal.SIGUSR1, signal.SIG_DFL)
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
signal.signal(signal.SIGTSTP, signal.SIG_DFL)
signal.signal(signal.SIGTTIN, signal.SIG_DFL)
signal.signal(signal.SIGTTOU, signal.SIG_DFL)
def print_error(*a: Any) -> None:
@ -354,7 +357,7 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl
if tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
open(establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()), 'w').close()
os.close(w)
if shm.unlink_on_exit:
child_main(cmd, ready_fd_read)
@ -377,201 +380,217 @@ def verify_socket_creds(conn: socket.socket) -> bool:
return uid == os.geteuid() and gid == os.getegid()
class SocketChild:
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
self.registered = True
self.poll = poll
self.addr = addr
self.conn = conn
self.winsize = 8
self.poll.register(self.conn.fileno(), select.POLLIN)
self.input_buf = self.output_buf = b''
self.fds: List[int] = []
self.child_id = -1
class SocketChildData:
def __init__(self) -> None:
self.cwd = self.tty_name = ''
self.env: Dict[str, str] = {}
self.argv: List[str] = []
self.stdin = self.stdout = self.stderr = -1
self.pid = -1
self.closed = False
self.launch_msg_read = False
self.env: Dict[str, str] = {}
def unregister_from_poll(self) -> None:
if self.registered:
fd = self.conn.fileno()
if fd > -1:
self.poll.unregister(self.conn.fileno())
self.registered = False
def read(self) -> None:
import fcntl
import termios
msg = self.conn.recv(io.DEFAULT_BUFFER_SIZE)
def fork_socket_child(child_data: SocketChildData, tty_fd: int, stdio_fds: Dict[str, int], free_non_child_resources: Callable[[], None]) -> int:
    """Fork the worker process for a socket client and make it the tty's foreground job.

    In the parent (the process that called this function) the pid returned by
    safer_fork() is returned immediately.

    In the child the following happens, in this order:

    1. The child moves itself into a new process group whose id is its own pid.
    2. That process group is made the foreground process group of ``tty_fd``.
    3. Signal handlers are reset via restore_python_signal_handlers().
    4. Each of stdin/stdout/stderr is replaced, using dup2, by the matching
       entry of ``stdio_fds`` or by ``tty_fd`` when that entry is -1.
    5. ``free_non_child_resources()`` is called and then child_main() is run
       with the cwd, env and argv taken from ``child_data``.

    :param child_data: launch parameters (cwd, env, argv) for the worker.
    :param tty_fd: file descriptor of the tty used for job control and as the
        fallback for any std stream that was not supplied.
    :param stdio_fds: mapping with the keys 'stdin', 'stdout' and 'stderr';
        a value of -1 means "use the tty".
    :param free_non_child_resources: callback invoked in the child just before
        child_main().
    :return: the child's pid, in the parent.
    """
    # see https://www.gnu.org/software/libc/manual/html_node/Launching-Jobs.html
    child_pid = safer_fork()
    if child_pid:
        return child_pid
    # child process
    os.setpgid(0, 0)
    os.tcsetpgrp(tty_fd, os.getpgid(0))
    restore_python_signal_handlers()
    # the std streams fds are closed in free_non_child_resources()
    for which in ('stdin', 'stdout', 'stderr'):
        fd = stdio_fds[which] if stdio_fds[which] > -1 else tty_fd
        os.dup2(fd, getattr(sys, which).fileno())
    free_non_child_resources()
    # NOTE(review): child_main() is presumably not expected to return; if it
    # did, this function would fall through and return None in the child.
    child_main({'cwd': child_data.cwd, 'env': child_data.env, 'argv': child_data.argv})
def fork_socket_child_supervisor(conn: socket.socket, free_non_child_resources: Callable[[], None]) -> None:
import array
import fcntl
import termios
global is_zygote
if safer_fork():
conn.close()
return
is_zygote = False
os.setsid()
restore_python_signal_handlers()
free_non_child_resources()
# See https://www.gnu.org/software/libc/manual/html_node/Initializing-the-Shell.html
signal_read_fd = install_signal_handlers(
signal.SIGCHLD, signal.SIGUSR1, signal.SIGINT, signal.SIGTSTP, signal.SIGTTIN, signal.SIGTTOU, signal.SIGQUIT
)[0]
poll = select.poll()
poll.register(signal_read_fd, select.POLLIN)
from_socket_buf = b''
to_socket_buf = b''
keep_going = True
child_pid = -1
socket_fd = conn.fileno()
launch_msg_read = False
os.set_blocking(socket_fd, False)
received_fds: List[int] = []
stdio_positions = dict.fromkeys(('stdin', 'stdout', 'stderr'), -1)
stdio_fds = dict.fromkeys(('stdin', 'stdout', 'stderr'), -1)
winsize = 8
exit_after_write = False
child_data = SocketChildData()
def handle_signal(siginfo: SignalInfo) -> None:
nonlocal to_socket_buf, exit_after_write, child_pid
if siginfo.si_signo != signal.SIGCHLD or siginfo.si_code not in (CLD_KILLED, CLD_EXITED, CLD_STOPPED):
return
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG | os.WUNTRACED)
except ChildProcessError:
pid = 0
if not pid:
break
if pid != child_pid:
continue
to_socket_buf += struct.pack('q', status)
if not os.WIFSTOPPED(status):
exit_after_write = True
child_pid = -1
def write_to_socket() -> None:
nonlocal keep_going, to_socket_buf, keep_going
buf = memoryview(to_socket_buf)
while buf:
try:
n = os.write(socket_fd, buf)
except OSError:
n = 0
if n == 0:
keep_going = False
return
buf = buf[n:]
to_socket_buf = bytes(buf)
if exit_after_write and not to_socket_buf:
keep_going = False
def read_winsize() -> None:
nonlocal from_socket_buf
msg = conn.recv(io.DEFAULT_BUFFER_SIZE)
if not msg:
return
self.input_buf += msg
data = memoryview(self.input_buf)
while len(data) >= self.winsize:
record, data = data[:self.winsize], data[self.winsize:]
with open(os.open(self.tty_name, os.O_RDWR | os.O_CLOEXEC | os.O_NOCTTY, 0), 'rb') as f:
from_socket_buf += msg
data = memoryview(from_socket_buf)
while len(data) >= winsize:
record, data = data[:winsize], data[winsize:]
with open(os.open(os.ctermid(), os.O_RDWR | os.O_CLOEXEC | os.O_NOCTTY, 0), 'rb') as f:
fcntl.ioctl(f.fileno(), termios.TIOCSWINSZ, record)
self.input_buf = bytes(data)
from_socket_buf = bytes(data)
def read_launch_msg(self) -> bool:
import array
fds = array.array("i") # Array of ints
def read_launch_msg() -> bool:
nonlocal keep_going, from_socket_buf, launch_msg_read, winsize
try:
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024)
msg, ancdata, flags, addr = conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024)
except OSError as e:
if e.errno == errno.ENOMEM:
# macOS does this when no ancillary data is present
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE)
msg, ancdata, flags, addr = conn.recvmsg(io.DEFAULT_BUFFER_SIZE)
else:
raise
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
fds = array.array("i") # Array of ints
# Append data, ignoring any truncated integers at the end.
fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
self.fds += list(fds)
received_fds.extend(fds)
if not msg:
return False
self.input_buf += msg
while (idx := self.input_buf.find(b'\0')) > -1:
line = self.input_buf[:idx].decode('utf-8')
self.input_buf = self.input_buf[idx+1:]
from_socket_buf += msg
while (idx := from_socket_buf.find(b'\0')) > -1:
line = from_socket_buf[:idx].decode('utf-8')
from_socket_buf = from_socket_buf[idx+1:]
cmd, _, payload = line.partition(':')
if cmd == 'finish':
self.launch_msg_read = True
for x in self.fds:
os.set_inheritable(x, x is not self.fds[0])
for x in received_fds:
os.set_inheritable(x, True)
os.set_blocking(x, True)
if self.stdin > -1:
self.stdin = self.fds[self.stdin]
if self.stdout > -1:
self.stdout = self.fds[self.stdout]
if self.stderr > -1:
self.stderr = self.fds[self.stderr]
del self.fds[:]
for k, pos in stdio_positions.items():
if pos > -1:
stdio_fds[k] = received_fds[pos]
del received_fds[:]
return True
elif cmd == 'cwd':
self.cwd = payload
child_data.cwd = payload
elif cmd == 'env':
k, _, v = payload.partition('=')
self.env[k] = v
child_data.env[k] = v
elif cmd == 'argv':
self.argv.append(payload)
elif cmd == 'stdin':
self.stdin = int(payload)
elif cmd == 'stdout':
self.stdout = int(payload)
elif cmd == 'stderr':
self.stderr = int(payload)
child_data.argv.append(payload)
elif cmd in stdio_positions:
stdio_positions[cmd] = int(payload)
elif cmd == 'tty_name':
self.tty_name = payload
child_data.tty_name = payload
elif cmd == 'winsize':
self.winsize = int(payload)
winsize = int(payload)
return False
def fork(self, free_non_child_resources: Callable[[], None]) -> None:
global is_zygote
r, w = safe_pipe()
self.pid = safer_fork()
if self.pid > 0:
# master process
os.close(w)
if self.stdin > -1:
os.close(self.stdin)
self.stdin = -1
if self.stdout > -1:
os.close(self.stdout)
self.stdout = -1
if self.stderr > -1:
os.close(self.stderr)
self.stderr = -1
poll = select.poll()
poll.register(r, select.POLLIN)
tuple(poll.poll())
os.close(r)
self.handle_creation()
return
# child process
is_zygote = False
os.close(r)
os.setsid()
restore_python_signal_handlers()
if self.tty_name:
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(
self.tty_name,
sys.__stdin__.fileno() if self.stdin < 0 else -1,
sys.__stdout__.fileno() if self.stdout < 0 else -1,
sys.__stderr__.fileno() if self.stderr < 0 else -1)
# the std streams fds are closed in free_non_child_resources(), see
# SocketChild.close()
if self.stdin > -1:
os.dup2(self.stdin, sys.__stdin__.fileno())
if self.stdout > -1:
os.dup2(self.stdout, sys.__stdout__.fileno())
if self.stderr > -1:
os.dup2(self.stderr, sys.__stderr__.fileno())
os.close(w)
free_non_child_resources()
child_main({'cwd': self.cwd, 'env': self.env, 'argv': self.argv})
raise SystemExit(0)
def free_non_child_resources2() -> None:
    """Close the supervisor-side descriptors that the worker must not keep.

    Closes every descriptor still listed in ``received_fds``, closes each
    std stream descriptor recorded in ``stdio_fds`` (marking it as -1
    afterwards) and finally closes the client connection.
    """
    for leftover_fd in received_fds:
        os.close(leftover_fd)
    # iterate over a snapshot of the keys as the mapping is updated in the loop
    for stream_name in tuple(stdio_fds):
        stream_fd = stdio_fds[stream_name]
        if stream_fd < 0:
            continue
        os.close(stream_fd)
        stdio_fds[stream_name] = -1
    conn.close()
def handle_stop(self, status: int) -> None:
if self.closed:
return
try:
self.conn.sendall(struct.pack('q', status))
except OSError as e:
print_error(f'Failed to send exit status of socket child with error: {e}')
def launch_child() -> None:
    """Start the worker process once the full launch message has been read.

    Flushes the std streams, obtains a tty file descriptor from
    establish_controlling_tty() for the tty named in ``child_data`` and forks
    the worker via fork_socket_child(). In the supervisor the worker is then
    placed in its own process group, that group is made the foreground group
    of the tty, the supervisor's copies of the std stream descriptors and the
    tty descriptor are closed, and the worker's pid is queued (packed as a
    native 'q' integer) in ``to_socket_buf`` for sending to the client.

    :raises SystemExit: if fork_socket_child() returns a falsy pid, i.e. it
        returned inside the forked child instead of running the worker.
    """
    nonlocal to_socket_buf, child_pid
    # flush before forking so buffered output is not duplicated in the child
    sys.__stdout__.flush()
    sys.__stderr__.flush()
    tty_fd = establish_controlling_tty(child_data.tty_name)
    child_pid = fork_socket_child(child_data, tty_fd, stdio_fds, free_non_child_resources2)
    if child_pid:
        # this is also done in the child process, but we don't
        # know when, so do it here as well
        os.setpgid(child_pid, child_pid)
        os.tcsetpgrp(tty_fd, child_pid)
        # the supervisor's copies of these descriptors are closed here
        for fd in stdio_fds.values():
            if fd > -1:
                os.close(fd)
        os.close(tty_fd)
    else:
        raise SystemExit('fork_socket_child() returned in the child process')
    # the first message queued for the client is the pid of the worker
    to_socket_buf += struct.pack('q', child_pid)
def handle_death(self, status: int) -> None:
if self.closed:
return
try:
self.conn.sendall(struct.pack('q', status))
except OSError as e:
print_error(f'Failed to send exit status of socket child with error: {e}')
def read_from_socket() -> None:
nonlocal launch_msg_read
if launch_msg_read:
read_winsize()
else:
if read_launch_msg():
launch_msg_read = True
launch_child()
def handle_creation(self) -> bool:
if self.closed:
return False
try:
self.conn.sendall(struct.pack('q', self.pid))
except OSError as e:
print_error(f'Failed to send pid of socket child with error: {e}')
return False
return True
def close(self) -> None:
if self.closed:
return
self.unregister_from_poll()
self.closed = True
if is_zygote:
try:
while keep_going:
poll.register(socket_fd, select.POLLIN | (select.POLLOUT if to_socket_buf else 0))
for fd, event in poll.poll():
if event & error_events:
keep_going = False
break
if fd == socket_fd:
if event & select.POLLOUT:
write_to_socket()
if event & select.POLLIN:
read_from_socket()
elif fd == signal_read_fd and event & select.POLLIN:
read_signals(signal_read_fd, handle_signal)
finally:
if child_pid: # supervisor process
with suppress(OSError):
self.conn.shutdown(socket.SHUT_RDWR)
with suppress(OSError):
self.conn.close()
for x in self.fds:
os.close(x)
del self.fds[:]
if self.stdin > -1:
os.close(self.stdin)
self.stdin = -1
if self.stdout > -1:
os.close(self.stdout)
self.stdout = -1
if self.stderr > -1:
os.close(self.stderr)
self.stderr = -1
__del__ = close
conn.shutdown(socket.SHUT_RDWR)
with suppress(OSError):
conn.close()
raise SystemExit(0)
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
@ -591,19 +610,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
input_buf = output_buf = child_death_buf = b''
child_ready_fds: Dict[int, int] = {}
child_pid_map: Dict[int, int] = {}
socket_pid_map: Dict[int, SocketChild] = {}
socket_children: Dict[int, SocketChild] = {}
child_id_counter = count()
# runpy issues a warning when running modules that have already been
# imported. Ignore it.
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
prewarm()
def remove_socket_child(sc: SocketChild) -> None:
socket_children.pop(sc.conn.fileno(), None)
socket_pid_map.pop(sc.pid, None)
sc.close()
def get_all_non_child_fds() -> Iterator[int]:
yield notify_child_death_fd
yield stdin_fd
@ -616,8 +628,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
if fd > -1:
os.close(fd)
unix_socket.close()
for sc in tuple(socket_children.values()):
remove_socket_child(sc)
def check_event(event: int, err_msg: str) -> None:
if event & select.POLLHUP:
@ -711,27 +721,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
pid = 0
if not pid:
break
if os.WIFSTOPPED(status):
sc = socket_pid_map.get(pid)
if sc is not None:
try:
sc.handle_stop(status)
except Exception:
import traceback
traceback.print_exc()
return
child_id = child_pid_map.pop(pid, None)
if child_id is None:
sc = socket_pid_map.get(pid)
if sc is not None:
try:
sc.handle_death(status)
except Exception:
import traceback
traceback.print_exc()
finally:
remove_socket_child(sc)
else:
if child_id is not None:
handle_child_death(child_id, pid)
read_signals(signal_read_fd, handle_signal)
@ -743,35 +734,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
print_error('Connection attempted with invalid credentials ignoring')
conn.close()
return
sc = SocketChild(conn, addr, poll)
socket_children[sc.conn.fileno()] = sc
def handle_socket_input(fd: int, event: int) -> None:
scq = socket_children.get(q)
if scq is None:
return
if event & select.POLLIN:
if scq.launch_msg_read:
scq.read()
else:
try:
if scq.read_launch_msg():
scq.fork(free_non_child_resources)
socket_pid_map[scq.pid] = scq
scq.child_id = next(child_id_counter)
except OSError:
if is_zygote:
remove_socket_child(scq)
import traceback
tb = traceback.format_exc()
print_error(f'Failed to fork socket child with error: {tb}')
else:
raise
if is_zygote and (event & error_events):
if event & select.POLLHUP:
scq.unregister_from_poll()
else:
remove_socket_child(scq)
fork_socket_child_supervisor(conn, free_non_child_resources)
keep_type_checker_happy = True
try:
@ -791,8 +754,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
handle_notify_child_death(event)
elif q == unix_socket.fileno():
handle_socket_client(event)
else:
handle_socket_input(q, event)
except (KeyboardInterrupt, EOFError, BrokenPipeError):
if is_zygote:
raise SystemExit(1)
@ -808,8 +769,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
for fmd in child_ready_fds.values():
with suppress(OSError):
os.close(fmd)
for sc in tuple(socket_children.values()):
remove_socket_child(sc)
def get_socket_name(unix_socket: socket.socket) -> str: