Use POSIX shm to pass ssh data to kitty
This commit is contained in:
parent
a1e4b19486
commit
20962d989f
@ -12,7 +12,6 @@ import shlex
|
||||
import stat
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
@ -25,6 +24,7 @@ from typing import (
|
||||
|
||||
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
|
||||
from kitty.fast_data_types import get_options
|
||||
from kitty.shm import SharedMemory
|
||||
from kitty.utils import SSHConnectionData
|
||||
|
||||
from .completion import complete, ssh_options
|
||||
@ -153,13 +153,15 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
|
||||
yield fmt_prefix('!invalid ssh data request message')
|
||||
else:
|
||||
try:
|
||||
with open(os.path.join(cache_dir(), 'ssh', pwfilename), 'rb') as f:
|
||||
os.unlink(f.name)
|
||||
st = os.stat(f.fileno())
|
||||
with SharedMemory(pwfilename, readonly=True) as shm:
|
||||
shm.unlink()
|
||||
st = os.stat(shm.fileno())
|
||||
mode = stat.S_IMODE(st.st_mode)
|
||||
if st.st_uid != os.geteuid() or st.st_gid != os.getegid():
|
||||
raise ValueError('Incorrect owner on pwfile')
|
||||
if mode != stat.S_IREAD:
|
||||
raise ValueError('Incorrect permissions on pwfile')
|
||||
env_data = json.load(f)
|
||||
env_data = json.loads(bytes(shm.buf))
|
||||
if pw != env_data['pw']:
|
||||
raise ValueError('Incorrect password')
|
||||
if rq_id != request_id:
|
||||
@ -222,13 +224,13 @@ def bootstrap_script(
|
||||
pw = secrets.token_hex()
|
||||
ddir = os.path.join(cache_dir(), 'ssh')
|
||||
os.makedirs(ddir, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', suffix='.json', dir=ddir, delete=False) as tf:
|
||||
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
|
||||
tf.write(json.dumps(data).encode('utf-8'))
|
||||
os.fchmod(tf.fileno(), stat.S_IREAD)
|
||||
atexit.register(safe_remove, tf.name)
|
||||
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
|
||||
db = json.dumps(data).encode('utf-8')
|
||||
with SharedMemory(create=True, size=len(db), mode=stat.S_IREAD) as shm:
|
||||
shm.buf[:] = db
|
||||
atexit.register(shm.unlink)
|
||||
replacements = {
|
||||
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
|
||||
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
|
||||
'REQUEST_ID': request_id
|
||||
}
|
||||
return prepare_script(ans, replacements), replacements
|
||||
|
||||
19
kitty/shm.py
19
kitty/shm.py
@ -21,6 +21,8 @@ def make_filename(safe_length: int = 14, prefix: str = '/ky-') -> str:
|
||||
|
||||
|
||||
class SharedMemory:
|
||||
_buf: Optional[memoryview] = None
|
||||
_fd: int = -1
|
||||
|
||||
def __init__(self, name: Optional[str] = None, create: bool = False, size: int = 0, readonly: bool = False, mode: int = 0o600):
|
||||
if not size >= 0:
|
||||
@ -30,7 +32,8 @@ class SharedMemory:
|
||||
if size <= 0:
|
||||
raise ValueError("'size' must be > 0")
|
||||
else:
|
||||
flags = os.O_RDONLY if readonly else os.O_RDWR
|
||||
flags = 0
|
||||
flags |= os.O_RDONLY if readonly else os.O_RDWR
|
||||
if name is None and not flags & os.O_EXCL:
|
||||
raise ValueError("'name' can only be None if create=True")
|
||||
|
||||
@ -43,19 +46,21 @@ class SharedMemory:
|
||||
continue
|
||||
self._name = name
|
||||
break
|
||||
else:
|
||||
self._fd = shm_open(name, flags)
|
||||
self._name = name
|
||||
try:
|
||||
if create and size:
|
||||
os.ftruncate(self._fd, size)
|
||||
stats = os.fstat(self._fd)
|
||||
size = stats.st_size
|
||||
self._mmap = mmap.mmap(self._fd, size)
|
||||
self._mmap = mmap.mmap(self._fd, size, access=mmap.ACCESS_READ if readonly else mmap.ACCESS_WRITE)
|
||||
except OSError:
|
||||
self.unlink()
|
||||
raise
|
||||
|
||||
self.size = size
|
||||
self._buf: Optional[memoryview] = memoryview(self._mmap)
|
||||
self._buf = memoryview(self._mmap)
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
@ -73,7 +78,6 @@ class SharedMemory:
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def fileno(self) -> int:
|
||||
return self._fd
|
||||
|
||||
@ -93,7 +97,7 @@ class SharedMemory:
|
||||
if self._buf is not None:
|
||||
self._buf.release()
|
||||
self._buf = None
|
||||
if self._mmap is not None:
|
||||
if getattr(self, '_mmap', None) is not None:
|
||||
self._mmap.close()
|
||||
if self._fd >= 0:
|
||||
os.close(self._fd)
|
||||
@ -106,5 +110,8 @@ class SharedMemory:
|
||||
called once (and only once) across all processes which have access
|
||||
to the shared memory block."""
|
||||
if self._name:
|
||||
shm_unlink(self._name)
|
||||
try:
|
||||
shm_unlink(self._name)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
self._name = ''
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user