Cleanup closing of socket child

This commit is contained in:
Kovid Goyal 2022-07-03 21:42:34 +05:30
parent 8332cd2f79
commit dbb084da7a
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -362,8 +362,11 @@ class SocketClosed(Exception):
class SocketChild: class SocketChild:
def __init__(self, conn: socket.socket, addr: bytes): def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
self.fd = conn.fileno() self.fd = conn.fileno()
poll.register(self.fd, select.POLLIN)
self.registered = True
self.poll = poll
self.addr = addr self.addr = addr
self.conn = conn self.conn = conn
self.input_buf = self.output_buf = b'' self.input_buf = self.output_buf = b''
@ -375,6 +378,11 @@ class SocketChild:
self.stdin = self.stdout = self.stderr = -1 self.stdin = self.stdout = self.stderr = -1
self.pid = -1 self.pid = -1
def unregister_from_poll(self) -> None:
if self.registered:
self.poll.unregister(self.fd)
self.registered = False
def read(self) -> bool: def read(self) -> bool:
import array import array
fds = array.array("i") # Array of ints fds = array.array("i") # Array of ints
@ -511,6 +519,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy') warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
prewarm() prewarm()
def remove_socket_child(sc: SocketChild) -> None:
socket_children.pop(sc.fd, None)
sc.unregister_from_poll()
socket_pid_map.pop(sc.pid, None)
sc.conn.close()
def get_all_non_child_fds() -> Iterator[int]: def get_all_non_child_fds() -> Iterator[int]:
yield notify_child_death_fd yield notify_child_death_fd
yield stdin_fd yield stdin_fd
@ -620,9 +634,12 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
break break
child_id = child_pid_map.pop(pid, None) child_id = child_pid_map.pop(pid, None)
if child_id is None: if child_id is None:
sc = socket_pid_map.pop(pid, None) sc = socket_pid_map.get(pid)
if sc is not None: if sc is not None:
sc.handle_death(status) try:
sc.handle_death(status)
finally:
remove_socket_child(sc)
else: else:
handle_child_death(child_id, pid) handle_child_death(child_id, pid)
@ -631,9 +648,8 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
def handle_socket_client(event: int) -> None: def handle_socket_client(event: int) -> None:
check_event(event, 'UNIX socket fd listener failed') check_event(event, 'UNIX socket fd listener failed')
conn, addr = unix_socket.accept() conn, addr = unix_socket.accept()
sc = SocketChild(conn, addr) sc = SocketChild(conn, addr, poll)
socket_children[sc.fd] = sc socket_children[sc.fd] = sc
poll.register(sc.fd, select.POLLIN)
def handle_socket_launch(fd: int, event: int) -> None: def handle_socket_launch(fd: int, event: int) -> None:
scq = socket_children.get(q) scq = socket_children.get(q)
@ -642,20 +658,21 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
if event & select.POLLIN: if event & select.POLLIN:
try: try:
if scq.read(): if scq.read():
poll.unregister(scq.fd) scq.unregister_from_poll()
scq.fork(get_all_non_child_fds()) scq.fork(get_all_non_child_fds())
socket_pid_map[scq.pid] = scq socket_pid_map[scq.pid] = scq
scq.child_id = next(child_id_counter) scq.child_id = next(child_id_counter)
except SocketClosed: except SocketClosed:
socket_children.pop(q, None) if is_zygote:
remove_socket_child(scq)
except OSError as e: except OSError as e:
if is_zygote: if is_zygote:
socket_children.pop(q, None) remove_socket_child(scq)
print_error(f'Failed to fork socket child with error: {e}') print_error(f'Failed to fork socket child with error: {e}')
else: else:
raise raise
if event & error_events: if is_zygote and (event & error_events):
socket_children.pop(q, None) remove_socket_child(scq)
keep_type_checker_happy = True keep_type_checker_happy = True
try: try: