Avoid needing to call os.getpid() repeatedly

This commit is contained in:
Kovid Goyal 2022-07-03 21:19:29 +05:30
parent 4e29c0c16b
commit 8332cd2f79
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -263,6 +263,15 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace') super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
parent_tty_name = ''
is_zygote = True
def debug(*a: Any) -> None:
with open(parent_tty_name, 'w') as f:
print(*a, file=f)
def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn: def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
cwd = cmd.get('cwd') cwd = cmd.get('cwd')
if cwd: if cwd:
@ -285,6 +294,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]: def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
global is_zygote
sz = pos = 0 sz = pos = 0
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm: with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
data = shm.read_data_with_size() data = shm.read_data_with_size()
@ -317,6 +327,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
os.close(r) os.close(r)
return child_pid, ready_fd_write return child_pid, ready_fd_write
# child process # child process
is_zygote = False
remove_signal_handlers() remove_signal_handlers()
os.close(r) os.close(r)
os.close(ready_fd_write) os.close(ready_fd_write)
@ -408,6 +419,7 @@ class SocketChild:
return False return False
def fork(self, all_non_child_fds: Iterable[int]) -> None: def fork(self, all_non_child_fds: Iterable[int]) -> None:
global is_zygote
r, w = safe_pipe() r, w = safe_pipe()
self.pid = os.fork() self.pid = os.fork()
if self.pid > 0: if self.pid > 0:
@ -429,6 +441,7 @@ class SocketChild:
self.handle_creation() self.handle_creation()
return return
# child process # child process
is_zygote = False
os.close(r) os.close(r)
os.setsid() os.setsid()
remove_signal_handlers() remove_signal_handlers()
@ -475,6 +488,8 @@ class SocketChild:
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None: def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
global parent_tty_name
parent_tty_name = os.ttyname(sys.stdout.fileno())
os.set_blocking(notify_child_death_fd, False) os.set_blocking(notify_child_death_fd, False)
os.set_blocking(stdin_fd, False) os.set_blocking(stdin_fd, False)
os.set_blocking(stdout_fd, False) os.set_blocking(stdout_fd, False)
@ -491,7 +506,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
socket_pid_map: Dict[int, SocketChild] = {} socket_pid_map: Dict[int, SocketChild] = {}
socket_children: Dict[int, SocketChild] = {} socket_children: Dict[int, SocketChild] = {}
child_id_counter = count() child_id_counter = count()
self_pid = os.getpid()
# runpy issues a warning when running modules that have already been # runpy issues a warning when running modules that have already been
# imported. Ignore it. # imported. Ignore it.
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy') warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
@ -548,7 +562,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
es = str(e).replace('\n', ' ') es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode() output_buf += f'ERR:{es}\n'.encode()
else: else:
if os.getpid() == self_pid: if is_zygote:
child_id = next(child_id_counter) child_id = next(child_id_counter)
child_pid_map[child_pid] = child_id child_pid_map[child_pid] = child_id
child_ready_fds[child_id] = ready_fd_write child_ready_fds[child_id] = ready_fd_write
@ -635,7 +649,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
except SocketClosed: except SocketClosed:
socket_children.pop(q, None) socket_children.pop(q, None)
except OSError as e: except OSError as e:
if os.getpid() == self_pid: if is_zygote:
socket_children.pop(q, None) socket_children.pop(q, None)
print_error(f'Failed to fork socket child with error: {e}') print_error(f'Failed to fork socket child with error: {e}')
else: else:
@ -643,8 +657,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
if event & error_events: if event & error_events:
socket_children.pop(q, None) socket_children.pop(q, None)
keep_type_checker_happy = True
try: try:
while True: while is_zygote and keep_type_checker_happy:
if output_buf: if output_buf:
poll.register(stdout_fd, select.POLLOUT) poll.register(stdout_fd, select.POLLOUT)
if child_death_buf: if child_death_buf:
@ -663,16 +678,16 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
else: else:
handle_socket_launch(q, event) handle_socket_launch(q, event)
except (KeyboardInterrupt, EOFError, BrokenPipeError): except (KeyboardInterrupt, EOFError, BrokenPipeError):
if os.getpid() == self_pid: if is_zygote:
raise SystemExit(1) raise SystemExit(1)
raise raise
except Exception: except Exception:
if os.getpid() == self_pid: if is_zygote:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
finally: finally:
if os.getpid() == self_pid: if is_zygote:
remove_signal_handlers() remove_signal_handlers()
for fmd in child_ready_fds.values(): for fmd in child_ready_fds.values():
with suppress(OSError): with suppress(OSError):