Get shared memory based data transfer working

This commit is contained in:
Kovid Goyal 2022-03-10 10:46:04 +05:30
parent 4528173ff5
commit 4c392426f6
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 65 additions and 41 deletions

View File

@ -10,6 +10,7 @@ import re
import secrets
import shlex
import stat
import struct
import sys
import tarfile
import time
@ -155,13 +156,13 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
try:
with SharedMemory(pwfilename, readonly=True) as shm:
shm.unlink()
st = os.stat(shm.fileno())
mode = stat.S_IMODE(st.st_mode)
if st.st_uid != os.geteuid() or st.st_gid != os.getegid():
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD:
raise ValueError('Incorrect permissions on pwfile')
env_data = json.loads(bytes(shm.buf))
sz = struct.unpack('>I', shm.read(struct.calcsize('>I')))[0]
env_data = json.loads(shm.read(sz))
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
@ -226,8 +227,10 @@ def bootstrap_script(
os.makedirs(ddir, exist_ok=True)
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
db = json.dumps(data).encode('utf-8')
with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm:
shm.buf[:] = db
sz = struct.pack('>I', len(db))
with SharedMemory(size=len(db) + len(sz), mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm:
shm.write(sz)
shm.write(db)
atexit.register(shm.unlink)
replacements = {
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,

View File

@ -343,7 +343,7 @@ PyInit_fast_data_types(void) {
PyModule_AddIntMacro(m, FILE_TRANSFER_CODE);
#ifdef __APPLE__
// Apple says its SHM_NAME_MAX but SHM_NAME_MAX is not actually declared in typical CrApple style.
// This value is based on experimentation
// This value is based on experimentation and from qsharedmemory.cpp in Qt
PyModule_AddIntConstant(m, "SHM_NAME_MAX", 30);
#else
// FreeBSD's man page says this is 1023. Linux says its PATH_MAX.

View File

@ -5,6 +5,7 @@
# multiprocessing.shared_memory. However, it is crippled in various ways, most
# notably using extremely small filenames.
import errno
import mmap
import os
import secrets
@ -18,62 +19,85 @@ def make_filename(prefix: str) -> str:
"Create a random filename for the shared memory object."
# number of random bytes to use for name. Use a largeish value
# to make double unlink safe.
safe_length = min(128, SHM_NAME_MAX)
if not prefix.startswith('/'):
# FreeBSD requires name to start with /
prefix = '/' + prefix
nbytes = (safe_length - len(prefix)) // 2
plen = len(prefix.encode('utf-8'))
safe_length = min(plen + 64, SHM_NAME_MAX)
if safe_length - plen < 2:
raise OSError(errno.ENAMETOOLONG, f'SHM filename prefix {prefix} is too long')
nbytes = (safe_length - plen) // 2
name = prefix + secrets.token_hex(nbytes)
return name
class SharedMemory:
_buf: Optional[memoryview] = None
'''
Create or access randomly named shared memory.
WARNING: The actual size of the shared memory wmay be larger than the requested size.
'''
_fd: int = -1
_name: str = ''
_mmap: Optional[mmap.mmap] = None
_size: int = 0
def __init__(
self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False,
self, name: str = '', size: int = 0, readonly: bool = False,
mode: int = stat.S_IREAD | stat.S_IWRITE,
prefix: str = 'kitty-'
):
if not size >= 0:
raise ValueError("'size' must be a positive integer")
if create:
if size < 0:
raise TypeError("'size' must be a non-negative integer")
if size and name:
raise TypeError('Cannot specify both name and size')
if not name:
flags = os.O_CREAT | os.O_EXCL
if size <= 0:
raise ValueError("'size' must be > 0")
if not size:
raise TypeError("'size' must be > 0")
else:
flags = 0
flags |= os.O_RDONLY if readonly else os.O_RDWR
if name is None and not flags & os.O_EXCL:
raise ValueError("'name' can only be None if create=True")
if name is None:
while True:
name = make_filename(prefix)
try:
self._fd = shm_open(name, flags, mode)
except FileExistsError:
continue
break
else:
tries = 30
while not name and tries > 0:
tries -= 1
q = make_filename(prefix)
try:
self._fd = shm_open(q, flags, mode)
name = q
except FileExistsError:
continue
if tries <= 0:
raise OSError(f'Failed to create a uniquely named SHM file, try shortening the prefix from: {prefix}')
if self._fd < 0:
self._fd = shm_open(name, flags)
self._name = name
try:
if create and size:
if flags & os.O_CREAT and size:
os.ftruncate(self._fd, size)
stats = os.fstat(self._fd)
size = stats.st_size
self.stats = os.fstat(self._fd)
size = self.stats.st_size
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
except OSError:
self.unlink()
raise
self._size = size
self._buf = memoryview(self._mmap)
def read(self, sz: int = 0) -> bytes:
if sz <= 0:
sz = self.size
return self.mmap.read(sz)
def write(self, data: bytes) -> None:
self.mmap.write(data)
def tell(self) -> int:
return self.mmap.tell()
def seek(self, pos: int, whence: int = os.SEEK_SET) -> None:
self.mmap.seek(pos, whence)
def __del__(self) -> None:
try:
@ -95,25 +119,22 @@ class SharedMemory:
def name(self) -> str:
return self._name
@property
def mmap(self) -> mmap.mmap:
ans = self._mmap
if ans is None:
raise RuntimeError('Cannot access the mmap of a closed shared memory object')
return ans
def fileno(self) -> int:
return self._fd
@property
def buf(self) -> memoryview:
ans = self._buf
if ans is None:
raise RuntimeError('Cannot access the buffer of a closed shared memory object')
return ans
def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
def close(self) -> None:
"""Closes access to the shared memory from this instance but does
not destroy the shared memory block."""
if self._buf is not None:
self._buf.release()
self._buf = None
if self._mmap is not None:
self._mmap.close()
self._mmap = None