diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 8e0ab4074..2377e7335 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -10,7 +10,6 @@ import re import secrets import shlex import stat -import struct import sys import tarfile import time @@ -164,8 +163,7 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: mode = stat.S_IMODE(shm.stats.st_mode) if mode != stat.S_IREAD: raise ValueError('Incorrect permissions on pwfile') - sz = struct.unpack('>I', shm.read(struct.calcsize('>I')))[0] - env_data = json.loads(shm.read(sz)) + env_data = json.loads(shm.read_data_with_size()) if pw != env_data['pw']: raise ValueError('Incorrect password') if rq_id != request_id: @@ -229,11 +227,9 @@ def bootstrap_script( ddir = os.path.join(cache_dir(), 'ssh') os.makedirs(ddir, exist_ok=True) data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname} - db = json.dumps(data).encode('utf-8') - sz = struct.pack('>I', len(db)) - with SharedMemory(size=len(db) + len(sz), mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm: - shm.write(sz) - shm.write(db) + db = json.dumps(data) + with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm: + shm.write_data_with_size(db) shm.flush() atexit.register(shm.unlink) replacements = { diff --git a/kitty/shm.py b/kitty/shm.py index 82eee46bf..4d5c04b4b 100644 --- a/kitty/shm.py +++ b/kitty/shm.py @@ -10,7 +10,8 @@ import mmap import os import secrets import stat -from typing import Optional +import struct +from typing import Optional, Union from kitty.fast_data_types import SHM_NAME_MAX, shm_open, shm_unlink @@ -42,6 +43,8 @@ class SharedMemory: _name: str = '' _mmap: Optional[mmap.mmap] = None _size: int = 0 + size_fmt = '!I' + num_bytes_for_size = struct.calcsize(size_fmt) def __init__( self, name: str = '', size: int = 0, readonly: bool = False, @@ -104,6 +107,17 @@ class SharedMemory: def flush(self) -> None: self.mmap.flush() + def write_data_with_size(self, data: Union[str, bytes]) -> None: + if isinstance(data, str): + data = data.encode('utf-8') + sz = struct.pack(self.size_fmt, len(data)) + self.write(sz) + self.write(data) + + def read_data_with_size(self) -> bytes: + sz = struct.unpack(self.size_fmt, self.read(self.num_bytes_for_size))[0] + return self.read(sz) + def __del__(self) -> None: try: self.close()