Avoid needing to call os.getpid() repeatedly

This commit is contained in:
Kovid Goyal 2022-07-03 21:19:29 +05:30
parent 4e29c0c16b
commit 8332cd2f79
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -263,6 +263,15 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
parent_tty_name = ''
is_zygote = True
def debug(*a: Any) -> None:
with open(parent_tty_name, 'w') as f:
print(*a, file=f)
def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
cwd = cmd.get('cwd')
if cwd:
@ -285,6 +294,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
global is_zygote
sz = pos = 0
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
data = shm.read_data_with_size()
@ -317,6 +327,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
os.close(r)
return child_pid, ready_fd_write
# child process
is_zygote = False
remove_signal_handlers()
os.close(r)
os.close(ready_fd_write)
@ -408,6 +419,7 @@ class SocketChild:
return False
def fork(self, all_non_child_fds: Iterable[int]) -> None:
global is_zygote
r, w = safe_pipe()
self.pid = os.fork()
if self.pid > 0:
@ -429,6 +441,7 @@ class SocketChild:
self.handle_creation()
return
# child process
is_zygote = False
os.close(r)
os.setsid()
remove_signal_handlers()
@ -475,6 +488,8 @@ class SocketChild:
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
global parent_tty_name
parent_tty_name = os.ttyname(sys.stdout.fileno())
os.set_blocking(notify_child_death_fd, False)
os.set_blocking(stdin_fd, False)
os.set_blocking(stdout_fd, False)
@ -491,7 +506,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
socket_pid_map: Dict[int, SocketChild] = {}
socket_children: Dict[int, SocketChild] = {}
child_id_counter = count()
self_pid = os.getpid()
# runpy issues a warning when running modules that have already been
# imported. Ignore it.
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
@ -548,7 +562,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
es = str(e).replace('\n', ' ')
output_buf += f'ERR:{es}\n'.encode()
else:
if os.getpid() == self_pid:
if is_zygote:
child_id = next(child_id_counter)
child_pid_map[child_pid] = child_id
child_ready_fds[child_id] = ready_fd_write
@ -635,7 +649,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
except SocketClosed:
socket_children.pop(q, None)
except OSError as e:
if os.getpid() == self_pid:
if is_zygote:
socket_children.pop(q, None)
print_error(f'Failed to fork socket child with error: {e}')
else:
@ -643,8 +657,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
if event & error_events:
socket_children.pop(q, None)
keep_type_checker_happy = True
try:
while True:
while is_zygote and keep_type_checker_happy:
if output_buf:
poll.register(stdout_fd, select.POLLOUT)
if child_death_buf:
@ -663,16 +678,16 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
else:
handle_socket_launch(q, event)
except (KeyboardInterrupt, EOFError, BrokenPipeError):
if os.getpid() == self_pid:
if is_zygote:
raise SystemExit(1)
raise
except Exception:
if os.getpid() == self_pid:
if is_zygote:
import traceback
traceback.print_exc()
raise
finally:
if os.getpid() == self_pid:
if is_zygote:
remove_signal_handlers()
for fmd in child_ready_fds.values():
with suppress(OSError):