diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index c4435c9ab..7dcc5ef96 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -24,8 +24,8 @@ from contextlib import contextmanager, suppress from getpass import getuser from select import select from typing import ( - Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, Tuple, - Union + Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, + Tuple, Union, cast ) from kittens.tui.operations import restore_colors, save_colors @@ -36,7 +36,10 @@ from kitty.constants import ( from kitty.options.types import Options from kitty.shm import SharedMemory from kitty.types import run_once -from kitty.utils import SSHConnectionData, set_echo as turn_off_echo, expandvars, resolve_abs_or_config_path +from kitty.utils import ( + SSHConnectionData, expandvars, resolve_abs_or_config_path, + set_echo as turn_off_echo +) from .completion import complete, ssh_options from .config import init_config @@ -58,16 +61,45 @@ def is_kitten_cmdline(q: List[str]) -> bool: return q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh'] -def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None: +def patch_cmdline(key: str, val: str, argv: List[str]) -> None: for i, arg in enumerate(tuple(argv)): - if arg.startswith('--kitten=cwd'): - argv[i] = f'--kitten=cwd={cwd}' + if arg.startswith(f'--kitten={key}='): + argv[i] = f'--kitten={key}={val}' return - elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith('cwd=') or arg.startswith('cwd ')): - argv[i] = cwd + elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith(f'{key}=') or arg.startswith(f'{key} ')): + argv[i] = val return idx = argv.index('ssh') - argv.insert(idx + 1, f'--kitten=cwd={cwd}') + argv.insert(idx + 1, f'--kitten={key}={val}') + + +def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None: + patch_cmdline('cwd', cwd, argv) + + +def create_shared_memory(data: Any, prefix: str) -> str: + from kitty.shm import SharedMemory + db = json.dumps(data).encode('utf-8') + with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, mode=stat.S_IREAD, prefix=prefix) as shm: + shm.write_data_with_size(db) + shm.flush() + atexit.register(shm.unlink) + return shm.name + + +def read_data_from_shared_memory(shm_name: str) -> Any: + with SharedMemory(shm_name, readonly=True) as shm: + shm.unlink() + if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid(): + raise ValueError('Incorrect owner on pwfile') + mode = stat.S_IMODE(shm.stats.st_mode) + if mode != stat.S_IREAD: + raise ValueError('Incorrect permissions on pwfile') + return json.loads(shm.read_data_with_size()) + + +def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None: + patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv) # See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html @@ -105,7 +137,7 @@ def kitty_opts() -> Options: return create_default_opts() -def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz') -> bytes: +def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz', literal_env: Dict[str, str] = {}) -> bytes: def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo: tarinfo.uname = tarinfo.gname = '' @@ -146,11 +178,12 @@ def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: st 'TERM': os.environ.get('TERM') or kitty_opts().term, 'COLORTERM': 'truecolor', } + env.update(literal_env) + env.update(ssh_opts.env) for q in ('KITTY_WINDOW_ID', 'WINDOWID'): val = os.environ.get(q) if val is not None: env[q] = val - env.update(ssh_opts.env) env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir if ssh_opts.login_shell: @@ -196,18 +229,11 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: yield b'invalid ssh data request message\n' else: try: - with SharedMemory(pwfilename, readonly=True) as shm: - shm.unlink() - if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid(): - raise ValueError('Incorrect owner on pwfile') - mode = stat.S_IMODE(shm.stats.st_mode) - if mode != stat.S_IREAD: - raise ValueError('Incorrect permissions on pwfile') - env_data = json.loads(shm.read_data_with_size()) - if pw != env_data['pw']: - raise ValueError('Incorrect password') - if rq_id != request_id: - raise ValueError('Incorrect request id') + env_data = read_data_from_shared_memory(pwfilename) + if pw != env_data['pw']: + raise ValueError('Incorrect password') + if rq_id != request_id: + raise ValueError('Incorrect request id') except Exception as e: traceback.print_exc() yield f'{e}\n'.encode('utf-8') @@ -267,8 +293,8 @@ def prepare_export_home_cmd(ssh_opts: SSHOptions, is_python: bool) -> str: def bootstrap_script( ssh_opts: SSHOptions, script_type: str = 'sh', remote_args: Sequence[str] = (), test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = '', - request_data: bool = False, echo_on: bool = True -) -> Tuple[str, Dict[str, str], SharedMemory]: + request_data: bool = False, echo_on: bool = True, literal_env: Dict[str, str] = {} +) -> Tuple[str, Dict[str, str], str]: if request_id is None: request_id = os.environ['KITTY_PID'] + '-' + os.environ['KITTY_WINDOW_ID'] is_python = script_type == 'py' @@ -277,14 +303,10 @@ def bootstrap_script( with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f: ans = f.read() pw = secrets.token_hex() - tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2')).decode('ascii') + tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2', literal_env=literal_env)).decode('ascii') data = {'pw': pw, 'opts': ssh_opts._asdict(), 'hostname': cli_hostname, 'uname': cli_uname, 'tarfile': tfd} - db = json.dumps(data) - with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm: - shm.write_data_with_size(db) - shm.flush() - atexit.register(shm.unlink) - sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name} + shm_name = create_shared_memory(data, prefix=f'kssh-{os.getpid()}-') + sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm_name} replacements = { 'EXPORT_HOME_CMD': export_home_cmd, 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script, @@ -294,7 +316,7 @@ def bootstrap_script( if request_data: sd.update(sensitive_data) replacements.update(sensitive_data) - return prepare_script(ans, sd, script_type), replacements, shm + return prepare_script(ans, sd, script_type), replacements, shm_name def get_ssh_cli() -> Tuple[Set[str], Set[str]]: @@ -499,15 +521,15 @@ def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]: def get_remote_command( remote_args: List[str], ssh_opts: SSHOptions, cli_hostname: str = '', cli_uname: str = '', - echo_on: bool = True, request_data: bool = False + echo_on: bool = True, request_data: bool = False, literal_env: Dict[str, str] = {} ) -> Tuple[List[str], Dict[str, str], str]: interpreter = ssh_opts.interpreter q = os.path.basename(interpreter).lower() is_python = 'python' in q - sh_script, replacements, shm = bootstrap_script( - ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args, + sh_script, replacements, shm_name = bootstrap_script( + ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args, literal_env=literal_env, cli_hostname=cli_hostname, cli_uname=cli_uname, echo_on=echo_on, request_data=request_data) - return wrap_bootstrap_script(sh_script, interpreter), replacements, shm.name + return wrap_bootstrap_script(sh_script, interpreter), replacements, shm_name def connection_sharing_args(opts: SSHOptions, kitty_pid: int) -> List[str]: @@ -602,7 +624,9 @@ def drain_potential_tty_garbage(p: 'subprocess.Popen[bytes]', data_request: str) def change_colors(color_scheme: str) -> bool: if not color_scheme: return False - from kittens.themes.collection import load_themes, NoCacheFound, text_as_opts + from kittens.themes.collection import ( + NoCacheFound, load_themes, text_as_opts + ) from kittens.themes.main import colors_as_escape_codes if color_scheme.endswith('.conf'): conf_file = resolve_abs_or_config_path(color_scheme) @@ -627,6 +651,14 @@ def change_colors(color_scheme: str) -> bool: return True +def add_cloned_env(shm_name: str) -> Dict[str, str]: + try: + return cast(Dict[str, str], read_data_from_shared_memory(shm_name)) + except FileNotFoundError: + pass + return {} + + def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple[str, ...]) -> NoReturn: cmd = [ssh_exe()] + ssh_args hostname, remote_args = server_args[0], server_args[1:] @@ -646,12 +678,16 @@ def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple else: hostname_for_match = hostname hostname_for_match = hostname_for_match.split('@', 1)[-1].split(':', 1)[0] - overrides = [] + overrides: List[str] = [] + literal_env: Dict[str, str] = {} pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=') for i, a in enumerate(found_extra_args): if i % 2 == 1: aq = pat.sub(r'\1 ', a.lstrip()) - if aq.split(maxsplit=1)[0] != 'hostname': + key = aq.split(maxsplit=1)[0] + if key == 'clone_env': + literal_env = add_cloned_env(aq.split(maxsplit=1)[1]) + elif key != 'hostname': overrides.append(aq) if overrides: overrides.insert(0, f'hostname {uname}@{hostname_for_match}') @@ -677,7 +713,7 @@ def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple need_to_request_data = False with restore_terminal_state() as echo_on: rcmd, replacements, shm_name = get_remote_command( - remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data) + remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data, literal_env=literal_env) cmd += rcmd colors_changed = change_colors(host_opts.color_scheme) try: diff --git a/kitty/launch.py b/kitty/launch.py index 0b01e73c5..6bd6d3a5d 100644 --- a/kitty/launch.py +++ b/kitty/launch.py @@ -361,7 +361,7 @@ def launch( if opts.os_window_title == 'current': tm = boss.active_tab_manager opts.os_window_title = get_os_window_title(tm.os_window_id) if tm else None - if base_env: + if base_env is not None: env = base_env.copy() env.update(get_env(opts)) else: @@ -501,7 +501,7 @@ def clone_and_launch(msg: str, window: Window) -> None: from .child import cmdline_of_process args = [] - env: Dict[str, str] = {} + env: Optional[Dict[str, str]] = None cwd = '' pid = -1 @@ -514,6 +514,7 @@ def clone_and_launch(msg: str, window: Window) -> None: if k == 'a': args.append(v) elif k == 'env': + env = {} for line in v.split('\0'): if line: try: @@ -535,9 +536,12 @@ def clone_and_launch(msg: str, window: Window) -> None: cmdline = list(window.child.argv) ssh_kitten_cmdline = window.ssh_kitten_cmdline() if ssh_kitten_cmdline: - from kittens.ssh.main import set_cwd_in_cmdline + from kittens.ssh.main import set_cwd_in_cmdline, set_env_in_cmdline cmdline[:] = ssh_kitten_cmdline if opts.cwd: set_cwd_in_cmdline(opts.cwd, cmdline) opts.cwd = None + if env: + set_env_in_cmdline(env, cmdline) + env = None launch(get_boss(), opts, cmdline, base_env=env) diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index 1f4241466..c46641a37 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -7,6 +7,7 @@ import os import shutil import tempfile from functools import lru_cache +from contextlib import suppress from kittens.ssh.config import load_config from kittens.ssh.main import ( @@ -16,7 +17,7 @@ from kittens.ssh.options.types import Options as SSHOptions from kittens.ssh.options.utils import DELETE_ENV_VAR from kittens.transfer.utils import set_paths from kitty.constants import is_macos, runtime_dir -from kitty.fast_data_types import CURSOR_BEAM +from kitty.fast_data_types import CURSOR_BEAM, shm_unlink from kitty.utils import SSHConnectionData from . import BaseTest @@ -243,7 +244,7 @@ copy --exclude */w.* d1 else: test_script = f'echo "UNTAR_DONE"; {test_script}' ssh_opts['shell_integration'] = SHELL_INTEGRATION_VALUE or 'disabled' - script, replacements, shm = bootstrap_script( + script, replacements, shm_name = bootstrap_script( SSHOptions(ssh_opts), script_type='py' if 'python' in sh else 'sh', request_id="testing", test_script=test_script, request_data=True ) @@ -275,4 +276,5 @@ copy --exclude */w.* d1 pty.wait_till(lambda: pty.screen.cursor.shape == CURSOR_BEAM) return pty finally: - shm.unlink() + with suppress(FileNotFoundError): + shm_unlink(shm_name)