From 20962d989f6f2feda5aba3277f5e6b7c6b479e6d Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Thu, 10 Mar 2022 06:55:21 +0530 Subject: [PATCH] Use POSIX shm to pass ssh data to kitty --- kittens/ssh/main.py | 24 +++++++++++++----------- kitty/shm.py | 19 +++++++++++++------ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 28585a9d4..d0522b7bf 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -12,7 +12,6 @@ import shlex import stat import sys import tarfile -import tempfile import time import traceback from base64 import standard_b64decode, standard_b64encode @@ -25,6 +24,7 @@ from typing import ( from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir from kitty.fast_data_types import get_options +from kitty.shm import SharedMemory from kitty.utils import SSHConnectionData from .completion import complete, ssh_options @@ -153,13 +153,15 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: yield fmt_prefix('!invalid ssh data request message') else: try: - with open(os.path.join(cache_dir(), 'ssh', pwfilename), 'rb') as f: - os.unlink(f.name) - st = os.stat(f.fileno()) + with SharedMemory(pwfilename, readonly=True) as shm: + shm.unlink() + st = os.stat(shm.fileno()) mode = stat.S_IMODE(st.st_mode) + if st.st_uid != os.geteuid() or st.st_gid != os.getegid(): + raise ValueError('Incorrect owner on pwfile') if mode != stat.S_IREAD: raise ValueError('Incorrect permissions on pwfile') - env_data = json.load(f) + env_data = json.loads(bytes(shm.buf)) if pw != env_data['pw']: raise ValueError('Incorrect password') if rq_id != request_id: @@ -222,13 +224,13 @@ def bootstrap_script( pw = secrets.token_hex() ddir = os.path.join(cache_dir(), 'ssh') os.makedirs(ddir, exist_ok=True) - with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', suffix='.json', dir=ddir, delete=False) as tf: - data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname} - tf.write(json.dumps(data).encode('utf-8')) - os.fchmod(tf.fileno(), stat.S_IREAD) - atexit.register(safe_remove, tf.name) + data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname} + db = json.dumps(data).encode('utf-8') + with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm: + shm.buf[:] = db + atexit.register(shm.unlink) replacements = { - 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script, + 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script, 'REQUEST_ID': request_id } return prepare_script(ans, replacements), replacements diff --git a/kitty/shm.py b/kitty/shm.py index 279de1eb8..4c610a17d 100644 --- a/kitty/shm.py +++ b/kitty/shm.py @@ -21,6 +21,8 @@ def make_filename(safe_length: int = 14, prefix: str = '/ky-') -> str: class SharedMemory: + _buf: Optional[memoryview] = None + _fd: int = -1 def __init__(self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False, mode: int = 0o600): if not size >= 0: @@ -30,7 +32,8 @@ class SharedMemory: if size <= 0: raise ValueError("'size' must be > 0") else: - flags = os.O_RDONLY if readonly else os.O_RDWR + flags = 0 + flags |= os.O_RDONLY if readonly else os.O_RDWR if name is None and not flags & os.O_EXCL: raise ValueError("'name' can only be None if create=True") @@ -43,19 +46,21 @@ class SharedMemory: continue self._name = name break + else: + self._fd = shm_open(name, flags) self._name = name try: if create and size: os.ftruncate(self._fd, size) stats = os.fstat(self._fd) size = stats.st_size - self._mmap = mmap.mmap(self._fd, size) + self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE) except OSError: self.unlink() raise self.size = size - self._buf: Optional[memoryview] = memoryview(self._mmap) + self._buf = memoryview(self._mmap) def __del__(self) -> None: try: @@ -73,7 +78,6 @@ class SharedMemory: def name(self) -> str: return self._name - @property def fileno(self) -> int: return self._fd @@ -93,7 +97,7 @@ class SharedMemory: if self._buf is not None: self._buf.release() self._buf = None - if self._mmap is not None: + if getattr(self, '_mmap', None) is not None: self._mmap.close() if self._fd >= 0: os.close(self._fd) @@ -106,5 +110,8 @@ class SharedMemory: called once (and only once) across all processes which have access to the shared memory block.""" if self._name: - shm_unlink(self._name) + try: + shm_unlink(self._name) + except FileNotFoundError: + pass self._name = ''