Get shared memory based data transfer working
This commit is contained in:
parent
4528173ff5
commit
4c392426f6
@ -10,6 +10,7 @@ import re
|
|||||||
import secrets
|
import secrets
|
||||||
import shlex
|
import shlex
|
||||||
import stat
|
import stat
|
||||||
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import time
|
import time
|
||||||
@ -155,13 +156,13 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
|
|||||||
try:
|
try:
|
||||||
with SharedMemory(pwfilename, readonly=True) as shm:
|
with SharedMemory(pwfilename, readonly=True) as shm:
|
||||||
shm.unlink()
|
shm.unlink()
|
||||||
st = os.stat(shm.fileno())
|
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
|
||||||
mode = stat.S_IMODE(st.st_mode)
|
|
||||||
if st.st_uid != os.geteuid() or st.st_gid != os.getegid():
|
|
||||||
raise ValueError('Incorrect owner on pwfile')
|
raise ValueError('Incorrect owner on pwfile')
|
||||||
|
mode = stat.S_IMODE(shm.stats.st_mode)
|
||||||
if mode != stat.S_IREAD:
|
if mode != stat.S_IREAD:
|
||||||
raise ValueError('Incorrect permissions on pwfile')
|
raise ValueError('Incorrect permissions on pwfile')
|
||||||
env_data = json.loads(bytes(shm.buf))
|
sz = struct.unpack('>I', shm.read(struct.calcsize('>I')))[0]
|
||||||
|
env_data = json.loads(shm.read(sz))
|
||||||
if pw != env_data['pw']:
|
if pw != env_data['pw']:
|
||||||
raise ValueError('Incorrect password')
|
raise ValueError('Incorrect password')
|
||||||
if rq_id != request_id:
|
if rq_id != request_id:
|
||||||
@ -226,8 +227,10 @@ def bootstrap_script(
|
|||||||
os.makedirs(ddir, exist_ok=True)
|
os.makedirs(ddir, exist_ok=True)
|
||||||
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
|
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
|
||||||
db = json.dumps(data).encode('utf-8')
|
db = json.dumps(data).encode('utf-8')
|
||||||
with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm:
|
sz = struct.pack('>I', len(db))
|
||||||
shm.buf[:] = db
|
with SharedMemory(size=len(db) + len(sz), mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm:
|
||||||
|
shm.write(sz)
|
||||||
|
shm.write(db)
|
||||||
atexit.register(shm.unlink)
|
atexit.register(shm.unlink)
|
||||||
replacements = {
|
replacements = {
|
||||||
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
|
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
|
||||||
|
|||||||
@ -343,7 +343,7 @@ PyInit_fast_data_types(void) {
|
|||||||
PyModule_AddIntMacro(m, FILE_TRANSFER_CODE);
|
PyModule_AddIntMacro(m, FILE_TRANSFER_CODE);
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
// Apple says its SHM_NAME_MAX but SHM_NAME_MAX is not actually declared in typical CrApple style.
|
// Apple says its SHM_NAME_MAX but SHM_NAME_MAX is not actually declared in typical CrApple style.
|
||||||
// This value is based on experimentation
|
// This value is based on experimentation and from qsharedmemory.cpp in Qt
|
||||||
PyModule_AddIntConstant(m, "SHM_NAME_MAX", 30);
|
PyModule_AddIntConstant(m, "SHM_NAME_MAX", 30);
|
||||||
#else
|
#else
|
||||||
// FreeBSD's man page says this is 1023. Linux says its PATH_MAX.
|
// FreeBSD's man page says this is 1023. Linux says its PATH_MAX.
|
||||||
|
|||||||
89
kitty/shm.py
89
kitty/shm.py
@ -5,6 +5,7 @@
|
|||||||
# multiprocessing.shared_memory. However, it is crippled in various ways, most
|
# multiprocessing.shared_memory. However, it is crippled in various ways, most
|
||||||
# notably using extremely small filenames.
|
# notably using extremely small filenames.
|
||||||
|
|
||||||
|
import errno
|
||||||
import mmap
|
import mmap
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
@ -18,62 +19,85 @@ def make_filename(prefix: str) -> str:
|
|||||||
"Create a random filename for the shared memory object."
|
"Create a random filename for the shared memory object."
|
||||||
# number of random bytes to use for name. Use a largeish value
|
# number of random bytes to use for name. Use a largeish value
|
||||||
# to make double unlink safe.
|
# to make double unlink safe.
|
||||||
safe_length = min(128, SHM_NAME_MAX)
|
|
||||||
if not prefix.startswith('/'):
|
if not prefix.startswith('/'):
|
||||||
# FreeBSD requires name to start with /
|
# FreeBSD requires name to start with /
|
||||||
prefix = '/' + prefix
|
prefix = '/' + prefix
|
||||||
nbytes = (safe_length - len(prefix)) // 2
|
plen = len(prefix.encode('utf-8'))
|
||||||
|
safe_length = min(plen + 64, SHM_NAME_MAX)
|
||||||
|
if safe_length - plen < 2:
|
||||||
|
raise OSError(errno.ENAMETOOLONG, f'SHM filename prefix {prefix} is too long')
|
||||||
|
nbytes = (safe_length - plen) // 2
|
||||||
name = prefix + secrets.token_hex(nbytes)
|
name = prefix + secrets.token_hex(nbytes)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
class SharedMemory:
|
class SharedMemory:
|
||||||
_buf: Optional[memoryview] = None
|
'''
|
||||||
|
Create or access randomly named shared memory.
|
||||||
|
|
||||||
|
WARNING: The actual size of the shared memory wmay be larger than the requested size.
|
||||||
|
'''
|
||||||
_fd: int = -1
|
_fd: int = -1
|
||||||
_name: str = ''
|
_name: str = ''
|
||||||
_mmap: Optional[mmap.mmap] = None
|
_mmap: Optional[mmap.mmap] = None
|
||||||
_size: int = 0
|
_size: int = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False,
|
self, name: str = '', size: int = 0, readonly: bool = False,
|
||||||
mode: int = stat.S_IREAD | stat.S_IWRITE,
|
mode: int = stat.S_IREAD | stat.S_IWRITE,
|
||||||
prefix: str = 'kitty-'
|
prefix: str = 'kitty-'
|
||||||
):
|
):
|
||||||
if not size >= 0:
|
if size < 0:
|
||||||
raise ValueError("'size' must be a positive integer")
|
raise TypeError("'size' must be a non-negative integer")
|
||||||
if create:
|
if size and name:
|
||||||
|
raise TypeError('Cannot specify both name and size')
|
||||||
|
if not name:
|
||||||
flags = os.O_CREAT | os.O_EXCL
|
flags = os.O_CREAT | os.O_EXCL
|
||||||
if size <= 0:
|
if not size:
|
||||||
raise ValueError("'size' must be > 0")
|
raise TypeError("'size' must be > 0")
|
||||||
else:
|
else:
|
||||||
flags = 0
|
flags = 0
|
||||||
flags |= os.O_RDONLY if readonly else os.O_RDWR
|
flags |= os.O_RDONLY if readonly else os.O_RDWR
|
||||||
if name is None and not flags & os.O_EXCL:
|
|
||||||
raise ValueError("'name' can only be None if create=True")
|
|
||||||
|
|
||||||
if name is None:
|
tries = 30
|
||||||
while True:
|
while not name and tries > 0:
|
||||||
name = make_filename(prefix)
|
tries -= 1
|
||||||
try:
|
q = make_filename(prefix)
|
||||||
self._fd = shm_open(name, flags, mode)
|
try:
|
||||||
except FileExistsError:
|
self._fd = shm_open(q, flags, mode)
|
||||||
continue
|
name = q
|
||||||
break
|
except FileExistsError:
|
||||||
else:
|
continue
|
||||||
|
if tries <= 0:
|
||||||
|
raise OSError(f'Failed to create a uniquely named SHM file, try shortening the prefix from: {prefix}')
|
||||||
|
if self._fd < 0:
|
||||||
self._fd = shm_open(name, flags)
|
self._fd = shm_open(name, flags)
|
||||||
self._name = name
|
self._name = name
|
||||||
try:
|
try:
|
||||||
if create and size:
|
if flags & os.O_CREAT and size:
|
||||||
os.ftruncate(self._fd, size)
|
os.ftruncate(self._fd, size)
|
||||||
stats = os.fstat(self._fd)
|
self.stats = os.fstat(self._fd)
|
||||||
size = stats.st_size
|
size = self.stats.st_size
|
||||||
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
|
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
|
||||||
except OSError:
|
except OSError:
|
||||||
self.unlink()
|
self.unlink()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
self._size = size
|
self._size = size
|
||||||
self._buf = memoryview(self._mmap)
|
|
||||||
|
def read(self, sz: int = 0) -> bytes:
|
||||||
|
if sz <= 0:
|
||||||
|
sz = self.size
|
||||||
|
return self.mmap.read(sz)
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> None:
|
||||||
|
self.mmap.write(data)
|
||||||
|
|
||||||
|
def tell(self) -> int:
|
||||||
|
return self.mmap.tell()
|
||||||
|
|
||||||
|
def seek(self, pos: int, whence: int = os.SEEK_SET) -> None:
|
||||||
|
self.mmap.seek(pos, whence)
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
try:
|
try:
|
||||||
@ -95,25 +119,22 @@ class SharedMemory:
|
|||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mmap(self) -> mmap.mmap:
|
||||||
|
ans = self._mmap
|
||||||
|
if ans is None:
|
||||||
|
raise RuntimeError('Cannot access the mmap of a closed shared memory object')
|
||||||
|
return ans
|
||||||
|
|
||||||
def fileno(self) -> int:
|
def fileno(self) -> int:
|
||||||
return self._fd
|
return self._fd
|
||||||
|
|
||||||
@property
|
|
||||||
def buf(self) -> memoryview:
|
|
||||||
ans = self._buf
|
|
||||||
if ans is None:
|
|
||||||
raise RuntimeError('Cannot access the buffer of a closed shared memory object')
|
|
||||||
return ans
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
|
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Closes access to the shared memory from this instance but does
|
"""Closes access to the shared memory from this instance but does
|
||||||
not destroy the shared memory block."""
|
not destroy the shared memory block."""
|
||||||
if self._buf is not None:
|
|
||||||
self._buf.release()
|
|
||||||
self._buf = None
|
|
||||||
if self._mmap is not None:
|
if self._mmap is not None:
|
||||||
self._mmap.close()
|
self._mmap.close()
|
||||||
self._mmap = None
|
self._mmap = None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user