Get shared memory based data transfer working

This commit is contained in:
Kovid Goyal 2022-03-10 10:46:04 +05:30
parent 4528173ff5
commit 4c392426f6
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 65 additions and 41 deletions

View File

@ -10,6 +10,7 @@ import re
import secrets import secrets
import shlex import shlex
import stat import stat
import struct
import sys import sys
import tarfile import tarfile
import time import time
@ -155,13 +156,13 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
try: try:
with SharedMemory(pwfilename, readonly=True) as shm: with SharedMemory(pwfilename, readonly=True) as shm:
shm.unlink() shm.unlink()
st = os.stat(shm.fileno()) if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
mode = stat.S_IMODE(st.st_mode)
if st.st_uid != os.geteuid() or st.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile') raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD: if mode != stat.S_IREAD:
raise ValueError('Incorrect permissions on pwfile') raise ValueError('Incorrect permissions on pwfile')
env_data = json.loads(bytes(shm.buf)) sz = struct.unpack('>I', shm.read(struct.calcsize('>I')))[0]
env_data = json.loads(shm.read(sz))
if pw != env_data['pw']: if pw != env_data['pw']:
raise ValueError('Incorrect password') raise ValueError('Incorrect password')
if rq_id != request_id: if rq_id != request_id:
@ -226,8 +227,10 @@ def bootstrap_script(
os.makedirs(ddir, exist_ok=True) os.makedirs(ddir, exist_ok=True)
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname} data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
db = json.dumps(data).encode('utf-8') db = json.dumps(data).encode('utf-8')
with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm: sz = struct.pack('>I', len(db))
shm.buf[:] = db with SharedMemory(size=len(db) + len(sz), mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm:
shm.write(sz)
shm.write(db)
atexit.register(shm.unlink) atexit.register(shm.unlink)
replacements = { replacements = {
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,

View File

@ -343,7 +343,7 @@ PyInit_fast_data_types(void) {
PyModule_AddIntMacro(m, FILE_TRANSFER_CODE); PyModule_AddIntMacro(m, FILE_TRANSFER_CODE);
#ifdef __APPLE__ #ifdef __APPLE__
// Apple says its SHM_NAME_MAX but SHM_NAME_MAX is not actually declared in typical CrApple style. // Apple says its SHM_NAME_MAX but SHM_NAME_MAX is not actually declared in typical CrApple style.
// This value is based on experimentation // This value is based on experimentation and from qsharedmemory.cpp in Qt
PyModule_AddIntConstant(m, "SHM_NAME_MAX", 30); PyModule_AddIntConstant(m, "SHM_NAME_MAX", 30);
#else #else
// FreeBSD's man page says this is 1023. Linux says its PATH_MAX. // FreeBSD's man page says this is 1023. Linux says its PATH_MAX.

View File

@ -5,6 +5,7 @@
# multiprocessing.shared_memory. However, it is crippled in various ways, most # multiprocessing.shared_memory. However, it is crippled in various ways, most
# notably using extremely small filenames. # notably using extremely small filenames.
import errno
import mmap import mmap
import os import os
import secrets import secrets
@ -18,62 +19,85 @@ def make_filename(prefix: str) -> str:
"Create a random filename for the shared memory object." "Create a random filename for the shared memory object."
# number of random bytes to use for name. Use a largeish value # number of random bytes to use for name. Use a largeish value
# to make double unlink safe. # to make double unlink safe.
safe_length = min(128, SHM_NAME_MAX)
if not prefix.startswith('/'): if not prefix.startswith('/'):
# FreeBSD requires name to start with / # FreeBSD requires name to start with /
prefix = '/' + prefix prefix = '/' + prefix
nbytes = (safe_length - len(prefix)) // 2 plen = len(prefix.encode('utf-8'))
safe_length = min(plen + 64, SHM_NAME_MAX)
if safe_length - plen < 2:
raise OSError(errno.ENAMETOOLONG, f'SHM filename prefix {prefix} is too long')
nbytes = (safe_length - plen) // 2
name = prefix + secrets.token_hex(nbytes) name = prefix + secrets.token_hex(nbytes)
return name return name
class SharedMemory: class SharedMemory:
_buf: Optional[memoryview] = None '''
Create or access randomly named shared memory.
WARNING: The actual size of the shared memory wmay be larger than the requested size.
'''
_fd: int = -1 _fd: int = -1
_name: str = '' _name: str = ''
_mmap: Optional[mmap.mmap] = None _mmap: Optional[mmap.mmap] = None
_size: int = 0 _size: int = 0
def __init__( def __init__(
self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False, self, name: str = '', size: int = 0, readonly: bool = False,
mode: int = stat.S_IREAD | stat.S_IWRITE, mode: int = stat.S_IREAD | stat.S_IWRITE,
prefix: str = 'kitty-' prefix: str = 'kitty-'
): ):
if not size >= 0: if size < 0:
raise ValueError("'size' must be a positive integer") raise TypeError("'size' must be a non-negative integer")
if create: if size and name:
raise TypeError('Cannot specify both name and size')
if not name:
flags = os.O_CREAT | os.O_EXCL flags = os.O_CREAT | os.O_EXCL
if size <= 0: if not size:
raise ValueError("'size' must be > 0") raise TypeError("'size' must be > 0")
else: else:
flags = 0 flags = 0
flags |= os.O_RDONLY if readonly else os.O_RDWR flags |= os.O_RDONLY if readonly else os.O_RDWR
if name is None and not flags & os.O_EXCL:
raise ValueError("'name' can only be None if create=True")
if name is None: tries = 30
while True: while not name and tries > 0:
name = make_filename(prefix) tries -= 1
q = make_filename(prefix)
try: try:
self._fd = shm_open(name, flags, mode) self._fd = shm_open(q, flags, mode)
name = q
except FileExistsError: except FileExistsError:
continue continue
break if tries <= 0:
else: raise OSError(f'Failed to create a uniquely named SHM file, try shortening the prefix from: {prefix}')
if self._fd < 0:
self._fd = shm_open(name, flags) self._fd = shm_open(name, flags)
self._name = name self._name = name
try: try:
if create and size: if flags & os.O_CREAT and size:
os.ftruncate(self._fd, size) os.ftruncate(self._fd, size)
stats = os.fstat(self._fd) self.stats = os.fstat(self._fd)
size = stats.st_size size = self.stats.st_size
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE) self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
except OSError: except OSError:
self.unlink() self.unlink()
raise raise
self._size = size self._size = size
self._buf = memoryview(self._mmap)
def read(self, sz: int = 0) -> bytes:
if sz <= 0:
sz = self.size
return self.mmap.read(sz)
def write(self, data: bytes) -> None:
self.mmap.write(data)
def tell(self) -> int:
return self.mmap.tell()
def seek(self, pos: int, whence: int = os.SEEK_SET) -> None:
self.mmap.seek(pos, whence)
def __del__(self) -> None: def __del__(self) -> None:
try: try:
@ -95,25 +119,22 @@ class SharedMemory:
def name(self) -> str: def name(self) -> str:
return self._name return self._name
@property
def mmap(self) -> mmap.mmap:
ans = self._mmap
if ans is None:
raise RuntimeError('Cannot access the mmap of a closed shared memory object')
return ans
def fileno(self) -> int: def fileno(self) -> int:
return self._fd return self._fd
@property
def buf(self) -> memoryview:
ans = self._buf
if ans is None:
raise RuntimeError('Cannot access the buffer of a closed shared memory object')
return ans
def __repr__(self) -> str: def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.name!r}, size={self.size})' return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
def close(self) -> None: def close(self) -> None:
"""Closes access to the shared memory from this instance but does """Closes access to the shared memory from this instance but does
not destroy the shared memory block.""" not destroy the shared memory block."""
if self._buf is not None:
self._buf.release()
self._buf = None
if self._mmap is not None: if self._mmap is not None:
self._mmap.close() self._mmap.close()
self._mmap = None self._mmap = None