Avoid needing to call os.getpid() repeatedly
This commit is contained in:
parent
4e29c0c16b
commit
8332cd2f79
@ -263,6 +263,15 @@ class MemoryViewReadWrapper(io.TextIOWrapper):
|
|||||||
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
|
super().__init__(cast(IO[bytes], MemoryViewReadWrapperBytes(mw)), encoding='utf-8', errors='replace')
|
||||||
|
|
||||||
|
|
||||||
|
parent_tty_name = ''
|
||||||
|
is_zygote = True
|
||||||
|
|
||||||
|
|
||||||
|
def debug(*a: Any) -> None:
|
||||||
|
with open(parent_tty_name, 'w') as f:
|
||||||
|
print(*a, file=f)
|
||||||
|
|
||||||
|
|
||||||
def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
|
def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
|
||||||
cwd = cmd.get('cwd')
|
cwd = cmd.get('cwd')
|
||||||
if cwd:
|
if cwd:
|
||||||
@ -285,6 +294,7 @@ def child_main(cmd: Dict[str, Any], ready_fd: int = -1) -> NoReturn:
|
|||||||
|
|
||||||
|
|
||||||
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
||||||
|
global is_zygote
|
||||||
sz = pos = 0
|
sz = pos = 0
|
||||||
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
|
||||||
data = shm.read_data_with_size()
|
data = shm.read_data_with_size()
|
||||||
@ -317,6 +327,7 @@ def fork(shm_address: str, all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
|
|||||||
os.close(r)
|
os.close(r)
|
||||||
return child_pid, ready_fd_write
|
return child_pid, ready_fd_write
|
||||||
# child process
|
# child process
|
||||||
|
is_zygote = False
|
||||||
remove_signal_handlers()
|
remove_signal_handlers()
|
||||||
os.close(r)
|
os.close(r)
|
||||||
os.close(ready_fd_write)
|
os.close(ready_fd_write)
|
||||||
@ -408,6 +419,7 @@ class SocketChild:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def fork(self, all_non_child_fds: Iterable[int]) -> None:
|
def fork(self, all_non_child_fds: Iterable[int]) -> None:
|
||||||
|
global is_zygote
|
||||||
r, w = safe_pipe()
|
r, w = safe_pipe()
|
||||||
self.pid = os.fork()
|
self.pid = os.fork()
|
||||||
if self.pid > 0:
|
if self.pid > 0:
|
||||||
@ -429,6 +441,7 @@ class SocketChild:
|
|||||||
self.handle_creation()
|
self.handle_creation()
|
||||||
return
|
return
|
||||||
# child process
|
# child process
|
||||||
|
is_zygote = False
|
||||||
os.close(r)
|
os.close(r)
|
||||||
os.setsid()
|
os.setsid()
|
||||||
remove_signal_handlers()
|
remove_signal_handlers()
|
||||||
@ -475,6 +488,8 @@ class SocketChild:
|
|||||||
|
|
||||||
|
|
||||||
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
|
def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: socket.socket) -> None:
|
||||||
|
global parent_tty_name
|
||||||
|
parent_tty_name = os.ttyname(sys.stdout.fileno())
|
||||||
os.set_blocking(notify_child_death_fd, False)
|
os.set_blocking(notify_child_death_fd, False)
|
||||||
os.set_blocking(stdin_fd, False)
|
os.set_blocking(stdin_fd, False)
|
||||||
os.set_blocking(stdout_fd, False)
|
os.set_blocking(stdout_fd, False)
|
||||||
@ -491,7 +506,6 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
|||||||
socket_pid_map: Dict[int, SocketChild] = {}
|
socket_pid_map: Dict[int, SocketChild] = {}
|
||||||
socket_children: Dict[int, SocketChild] = {}
|
socket_children: Dict[int, SocketChild] = {}
|
||||||
child_id_counter = count()
|
child_id_counter = count()
|
||||||
self_pid = os.getpid()
|
|
||||||
# runpy issues a warning when running modules that have already been
|
# runpy issues a warning when running modules that have already been
|
||||||
# imported. Ignore it.
|
# imported. Ignore it.
|
||||||
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
|
warnings.filterwarnings('ignore', category=RuntimeWarning, module='runpy')
|
||||||
@ -548,7 +562,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
|||||||
es = str(e).replace('\n', ' ')
|
es = str(e).replace('\n', ' ')
|
||||||
output_buf += f'ERR:{es}\n'.encode()
|
output_buf += f'ERR:{es}\n'.encode()
|
||||||
else:
|
else:
|
||||||
if os.getpid() == self_pid:
|
if is_zygote:
|
||||||
child_id = next(child_id_counter)
|
child_id = next(child_id_counter)
|
||||||
child_pid_map[child_pid] = child_id
|
child_pid_map[child_pid] = child_id
|
||||||
child_ready_fds[child_id] = ready_fd_write
|
child_ready_fds[child_id] = ready_fd_write
|
||||||
@ -635,7 +649,7 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
|||||||
except SocketClosed:
|
except SocketClosed:
|
||||||
socket_children.pop(q, None)
|
socket_children.pop(q, None)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
if os.getpid() == self_pid:
|
if is_zygote:
|
||||||
socket_children.pop(q, None)
|
socket_children.pop(q, None)
|
||||||
print_error(f'Failed to fork socket child with error: {e}')
|
print_error(f'Failed to fork socket child with error: {e}')
|
||||||
else:
|
else:
|
||||||
@ -643,8 +657,9 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
|||||||
if event & error_events:
|
if event & error_events:
|
||||||
socket_children.pop(q, None)
|
socket_children.pop(q, None)
|
||||||
|
|
||||||
|
keep_type_checker_happy = True
|
||||||
try:
|
try:
|
||||||
while True:
|
while is_zygote and keep_type_checker_happy:
|
||||||
if output_buf:
|
if output_buf:
|
||||||
poll.register(stdout_fd, select.POLLOUT)
|
poll.register(stdout_fd, select.POLLOUT)
|
||||||
if child_death_buf:
|
if child_death_buf:
|
||||||
@ -663,16 +678,16 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
|||||||
else:
|
else:
|
||||||
handle_socket_launch(q, event)
|
handle_socket_launch(q, event)
|
||||||
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
except (KeyboardInterrupt, EOFError, BrokenPipeError):
|
||||||
if os.getpid() == self_pid:
|
if is_zygote:
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
if os.getpid() == self_pid:
|
if is_zygote:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if os.getpid() == self_pid:
|
if is_zygote:
|
||||||
remove_signal_handlers()
|
remove_signal_handlers()
|
||||||
for fmd in child_ready_fds.values():
|
for fmd in child_ready_fds.values():
|
||||||
with suppress(OSError):
|
with suppress(OSError):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user