Use POSIX shm to pass ssh data to kitty

This commit is contained in:
Kovid Goyal 2022-03-10 06:55:21 +05:30
parent a1e4b19486
commit 20962d989f
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 26 additions and 17 deletions

View File

@ -12,7 +12,6 @@ import shlex
import stat
import sys
import tarfile
import tempfile
import time
import traceback
from base64 import standard_b64decode, standard_b64encode
@ -25,6 +24,7 @@ from typing import (
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
from kitty.fast_data_types import get_options
from kitty.shm import SharedMemory
from kitty.utils import SSHConnectionData
from .completion import complete, ssh_options
@ -153,13 +153,15 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
yield fmt_prefix('!invalid ssh data request message')
else:
try:
with open(os.path.join(cache_dir(), 'ssh', pwfilename), 'rb') as f:
os.unlink(f.name)
st = os.stat(f.fileno())
with SharedMemory(pwfilename, readonly=True) as shm:
shm.unlink()
st = os.stat(shm.fileno())
mode = stat.S_IMODE(st.st_mode)
if st.st_uid != os.geteuid() or st.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
if mode != stat.S_IREAD:
raise ValueError('Incorrect permissions on pwfile')
env_data = json.load(f)
env_data = json.loads(bytes(shm.buf))
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
@ -222,13 +224,13 @@ def bootstrap_script(
pw = secrets.token_hex()
ddir = os.path.join(cache_dir(), 'ssh')
os.makedirs(ddir, exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', suffix='.json', dir=ddir, delete=False) as tf:
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
tf.write(json.dumps(data).encode('utf-8'))
os.fchmod(tf.fileno(), stat.S_IREAD)
atexit.register(safe_remove, tf.name)
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
db = json.dumps(data).encode('utf-8')
with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm:
shm.buf[:] = db
atexit.register(shm.unlink)
replacements = {
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
'REQUEST_ID': request_id
}
return prepare_script(ans, replacements), replacements

View File

@ -21,6 +21,8 @@ def make_filename(safe_length: int = 14, prefix: str = '/ky-') -> str:
class SharedMemory:
_buf: Optional[memoryview] = None
_fd: int = -1
def __init__(self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False, mode: int = 0o600):
if not size >= 0:
@ -30,7 +32,8 @@ class SharedMemory:
if size <= 0:
raise ValueError("'size' must be > 0")
else:
flags = os.O_RDONLY if readonly else os.O_RDWR
flags = 0
flags |= os.O_RDONLY if readonly else os.O_RDWR
if name is None and not flags & os.O_EXCL:
raise ValueError("'name' can only be None if create=True")
@ -43,19 +46,21 @@ class SharedMemory:
continue
self._name = name
break
else:
self._fd = shm_open(name, flags)
self._name = name
try:
if create and size:
os.ftruncate(self._fd, size)
stats = os.fstat(self._fd)
size = stats.st_size
self._mmap = mmap.mmap(self._fd, size)
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
except OSError:
self.unlink()
raise
self.size = size
self._buf: Optional[memoryview] = memoryview(self._mmap)
self._buf = memoryview(self._mmap)
def __del__(self) -> None:
try:
@ -73,7 +78,6 @@ class SharedMemory:
def name(self) -> str:
return self._name
@property
def fileno(self) -> int:
return self._fd
@ -93,7 +97,7 @@ class SharedMemory:
if self._buf is not None:
self._buf.release()
self._buf = None
if self._mmap is not None:
if getattr(self, '_mmap', None) is not None:
self._mmap.close()
if self._fd >= 0:
os.close(self._fd)
@ -106,5 +110,8 @@ class SharedMemory:
called once (and only once) across all processes which have access
to the shared memory block."""
if self._name:
shm_unlink(self._name)
try:
shm_unlink(self._name)
except FileNotFoundError:
pass
self._name = ''