Get shared memory based data transfer working
This commit is contained in:
parent
4528173ff5
commit
4c392426f6
@ -10,6 +10,7 @@ import re
|
||||
import secrets
|
||||
import shlex
|
||||
import stat
|
||||
import struct
|
||||
import sys
|
||||
import tarfile
|
||||
import time
|
||||
@ -155,13 +156,13 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
|
||||
try:
|
||||
with SharedMemory(pwfilename, readonly=True) as shm:
|
||||
shm.unlink()
|
||||
st = os.stat(shm.fileno())
|
||||
mode = stat.S_IMODE(st.st_mode)
|
||||
if st.st_uid != os.geteuid() or st.st_gid != os.getegid():
|
||||
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
|
||||
raise ValueError('Incorrect owner on pwfile')
|
||||
mode = stat.S_IMODE(shm.stats.st_mode)
|
||||
if mode != stat.S_IREAD:
|
||||
raise ValueError('Incorrect permissions on pwfile')
|
||||
env_data = json.loads(bytes(shm.buf))
|
||||
sz = struct.unpack('>I', shm.read(struct.calcsize('>I')))[0]
|
||||
env_data = json.loads(shm.read(sz))
|
||||
if pw != env_data['pw']:
|
||||
raise ValueError('Incorrect password')
|
||||
if rq_id != request_id:
|
||||
@ -226,8 +227,10 @@ def bootstrap_script(
|
||||
os.makedirs(ddir, exist_ok=True)
|
||||
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
|
||||
db = json.dumps(data).encode('utf-8')
|
||||
with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm:
|
||||
shm.buf[:] = db
|
||||
sz = struct.pack('>I', len(db))
|
||||
with SharedMemory(size=len(db) + len(sz), mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm:
|
||||
shm.write(sz)
|
||||
shm.write(db)
|
||||
atexit.register(shm.unlink)
|
||||
replacements = {
|
||||
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
|
||||
|
||||
@ -343,7 +343,7 @@ PyInit_fast_data_types(void) {
|
||||
PyModule_AddIntMacro(m, FILE_TRANSFER_CODE);
|
||||
#ifdef __APPLE__
|
||||
// Apple says its SHM_NAME_MAX but SHM_NAME_MAX is not actually declared in typical CrApple style.
|
||||
// This value is based on experimentation
|
||||
// This value is based on experimentation and from qsharedmemory.cpp in Qt
|
||||
PyModule_AddIntConstant(m, "SHM_NAME_MAX", 30);
|
||||
#else
|
||||
// FreeBSD's man page says this is 1023. Linux says its PATH_MAX.
|
||||
|
||||
89
kitty/shm.py
89
kitty/shm.py
@ -5,6 +5,7 @@
|
||||
# multiprocessing.shared_memory. However, it is crippled in various ways, most
|
||||
# notably using extremely small filenames.
|
||||
|
||||
import errno
|
||||
import mmap
|
||||
import os
|
||||
import secrets
|
||||
@ -18,62 +19,85 @@ def make_filename(prefix: str) -> str:
|
||||
"Create a random filename for the shared memory object."
|
||||
# number of random bytes to use for name. Use a largeish value
|
||||
# to make double unlink safe.
|
||||
safe_length = min(128, SHM_NAME_MAX)
|
||||
if not prefix.startswith('/'):
|
||||
# FreeBSD requires name to start with /
|
||||
prefix = '/' + prefix
|
||||
nbytes = (safe_length - len(prefix)) // 2
|
||||
plen = len(prefix.encode('utf-8'))
|
||||
safe_length = min(plen + 64, SHM_NAME_MAX)
|
||||
if safe_length - plen < 2:
|
||||
raise OSError(errno.ENAMETOOLONG, f'SHM filename prefix {prefix} is too long')
|
||||
nbytes = (safe_length - plen) // 2
|
||||
name = prefix + secrets.token_hex(nbytes)
|
||||
return name
|
||||
|
||||
|
||||
class SharedMemory:
|
||||
_buf: Optional[memoryview] = None
|
||||
'''
|
||||
Create or access randomly named shared memory.
|
||||
|
||||
WARNING: The actual size of the shared memory wmay be larger than the requested size.
|
||||
'''
|
||||
_fd: int = -1
|
||||
_name: str = ''
|
||||
_mmap: Optional[mmap.mmap] = None
|
||||
_size: int = 0
|
||||
|
||||
def __init__(
|
||||
self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False,
|
||||
self, name: str = '', size: int = 0, readonly: bool = False,
|
||||
mode: int = stat.S_IREAD | stat.S_IWRITE,
|
||||
prefix: str = 'kitty-'
|
||||
):
|
||||
if not size >= 0:
|
||||
raise ValueError("'size' must be a positive integer")
|
||||
if create:
|
||||
if size < 0:
|
||||
raise TypeError("'size' must be a non-negative integer")
|
||||
if size and name:
|
||||
raise TypeError('Cannot specify both name and size')
|
||||
if not name:
|
||||
flags = os.O_CREAT | os.O_EXCL
|
||||
if size <= 0:
|
||||
raise ValueError("'size' must be > 0")
|
||||
if not size:
|
||||
raise TypeError("'size' must be > 0")
|
||||
else:
|
||||
flags = 0
|
||||
flags |= os.O_RDONLY if readonly else os.O_RDWR
|
||||
if name is None and not flags & os.O_EXCL:
|
||||
raise ValueError("'name' can only be None if create=True")
|
||||
|
||||
if name is None:
|
||||
while True:
|
||||
name = make_filename(prefix)
|
||||
try:
|
||||
self._fd = shm_open(name, flags, mode)
|
||||
except FileExistsError:
|
||||
continue
|
||||
break
|
||||
else:
|
||||
tries = 30
|
||||
while not name and tries > 0:
|
||||
tries -= 1
|
||||
q = make_filename(prefix)
|
||||
try:
|
||||
self._fd = shm_open(q, flags, mode)
|
||||
name = q
|
||||
except FileExistsError:
|
||||
continue
|
||||
if tries <= 0:
|
||||
raise OSError(f'Failed to create a uniquely named SHM file, try shortening the prefix from: {prefix}')
|
||||
if self._fd < 0:
|
||||
self._fd = shm_open(name, flags)
|
||||
self._name = name
|
||||
try:
|
||||
if create and size:
|
||||
if flags & os.O_CREAT and size:
|
||||
os.ftruncate(self._fd, size)
|
||||
stats = os.fstat(self._fd)
|
||||
size = stats.st_size
|
||||
self.stats = os.fstat(self._fd)
|
||||
size = self.stats.st_size
|
||||
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
|
||||
except OSError:
|
||||
self.unlink()
|
||||
raise
|
||||
|
||||
self._size = size
|
||||
self._buf = memoryview(self._mmap)
|
||||
|
||||
def read(self, sz: int = 0) -> bytes:
|
||||
if sz <= 0:
|
||||
sz = self.size
|
||||
return self.mmap.read(sz)
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
self.mmap.write(data)
|
||||
|
||||
def tell(self) -> int:
|
||||
return self.mmap.tell()
|
||||
|
||||
def seek(self, pos: int, whence: int = os.SEEK_SET) -> None:
|
||||
self.mmap.seek(pos, whence)
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
@ -95,25 +119,22 @@ class SharedMemory:
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def mmap(self) -> mmap.mmap:
|
||||
ans = self._mmap
|
||||
if ans is None:
|
||||
raise RuntimeError('Cannot access the mmap of a closed shared memory object')
|
||||
return ans
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self._fd
|
||||
|
||||
@property
|
||||
def buf(self) -> memoryview:
|
||||
ans = self._buf
|
||||
if ans is None:
|
||||
raise RuntimeError('Cannot access the buffer of a closed shared memory object')
|
||||
return ans
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes access to the shared memory from this instance but does
|
||||
not destroy the shared memory block."""
|
||||
if self._buf is not None:
|
||||
self._buf.release()
|
||||
self._buf = None
|
||||
if self._mmap is not None:
|
||||
self._mmap.close()
|
||||
self._mmap = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user