This commit is contained in:
Kovid Goyal 2022-03-10 15:37:10 +05:30
parent f67009f554
commit c23e04fd03
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 19 additions and 9 deletions

View File

@ -10,7 +10,6 @@ import re
import secrets import secrets
import shlex import shlex
import stat import stat
import struct
import sys import sys
import tarfile import tarfile
import time import time
@ -164,8 +163,7 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
mode = stat.S_IMODE(shm.stats.st_mode) mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD: if mode != stat.S_IREAD:
raise ValueError('Incorrect permissions on pwfile') raise ValueError('Incorrect permissions on pwfile')
sz = struct.unpack('>I', shm.read(struct.calcsize('>I')))[0] env_data = json.loads(shm.read_data_with_size())
env_data = json.loads(shm.read(sz))
if pw != env_data['pw']: if pw != env_data['pw']:
raise ValueError('Incorrect password') raise ValueError('Incorrect password')
if rq_id != request_id: if rq_id != request_id:
@ -229,11 +227,9 @@ def bootstrap_script(
ddir = os.path.join(cache_dir(), 'ssh') ddir = os.path.join(cache_dir(), 'ssh')
os.makedirs(ddir, exist_ok=True) os.makedirs(ddir, exist_ok=True)
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname} data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
db = json.dumps(data).encode('utf-8') db = json.dumps(data)
sz = struct.pack('>I', len(db)) with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm:
with SharedMemory(size=len(db) + len(sz), mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm: shm.write_data_with_size(db)
shm.write(sz)
shm.write(db)
shm.flush() shm.flush()
atexit.register(shm.unlink) atexit.register(shm.unlink)
replacements = { replacements = {

View File

@ -10,7 +10,8 @@ import mmap
import os import os
import secrets import secrets
import stat import stat
from typing import Optional import struct
from typing import Optional, Union
from kitty.fast_data_types import SHM_NAME_MAX, shm_open, shm_unlink from kitty.fast_data_types import SHM_NAME_MAX, shm_open, shm_unlink
@ -42,6 +43,8 @@ class SharedMemory:
_name: str = '' _name: str = ''
_mmap: Optional[mmap.mmap] = None _mmap: Optional[mmap.mmap] = None
_size: int = 0 _size: int = 0
size_fmt = '!I'
num_bytes_for_size = struct.calcsize(size_fmt)
def __init__( def __init__(
self, name: str = '', size: int = 0, readonly: bool = False, self, name: str = '', size: int = 0, readonly: bool = False,
@ -104,6 +107,17 @@ class SharedMemory:
def flush(self) -> None: def flush(self) -> None:
self.mmap.flush() self.mmap.flush()
def write_data_with_size(self, data: Union[str, bytes]) -> None:
if isinstance(data, str):
data = data.encode('utf-8')
sz = struct.pack(self.size_fmt, len(data))
self.write(sz)
self.write(data)
def read_data_with_size(self) -> bytes:
sz = struct.unpack(self.size_fmt, self.read(self.num_bytes_for_size))[0]
return self.read(sz)
def __del__(self) -> None: def __del__(self) -> None:
try: try:
self.close() self.close()