From 244507336b3a415f1ee08ef03971630cbadda213 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Wed, 8 Feb 2023 16:34:33 +0530 Subject: [PATCH] Function to change the remote command in an ssh kitten cmdline --- kittens/ssh/main.py | 26 +------------ kittens/ssh/utils.py | 88 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 70ebd9814..0e270acf9 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -22,7 +22,7 @@ from base64 import standard_b64decode, standard_b64encode from contextlib import contextmanager, suppress from getpass import getuser from select import select -from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, Tuple, Union, cast +from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Tuple, Union, cast from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_control_master_template, str_version, terminfo_dir from kitty.shell_integration import as_str_literal @@ -37,7 +37,7 @@ from .config import init_config from .copy import CopyInstruction from .options.types import Options as SSHOptions from .options.utils import DELETE_ENV_VAR -from .utils import create_shared_memory, ssh_options +from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args @run_once @@ -291,25 +291,6 @@ def bootstrap_script( return prepare_script(ans, sd, script_type), replacements, shm_name -def get_ssh_cli() -> Tuple[Set[str], Set[str]]: - other_ssh_args: Set[str] = set() - boolean_ssh_args: Set[str] = set() - for k, v in ssh_options().items(): - k = f'-{k}' - if v: - other_ssh_args.add(k) - else: - boolean_ssh_args.add(k) - return boolean_ssh_args, other_ssh_args - - -def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str: - for x in extra_args: - if arg == x or arg.startswith(f'{x}='): - return x - return '' - - def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]: boolean_ssh_args, other_ssh_args = get_ssh_cli() port: Optional[int] = None @@ -405,9 +386,6 @@ class InvalidSSHArgs(ValueError): os.execlp(ssh_exe(), 'ssh') -passthrough_args = {f'-{x}' for x in 'NnfGT'} - - def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]: boolean_ssh_args, other_ssh_args = get_ssh_cli() ssh_args = [] diff --git a/kittens/ssh/utils.py b/kittens/ssh/utils.py index d956a7fdf..069d1f15c 100644 --- a/kittens/ssh/utils.py +++ b/kittens/ssh/utils.py @@ -4,7 +4,7 @@ import os import subprocess -from typing import Any, Dict, List, Sequence +from typing import Any, Dict, List, Sequence, Set, Tuple from kitty.types import run_once @@ -51,9 +51,14 @@ def ssh_options() -> Dict[str, str]: def is_kitten_cmdline(q: Sequence[str]) -> bool: + if not q: + return False + exe_name = os.path.basename(q[0]).lower() + if exe_name == 'kitten' and q[1:2] == ['ssh']: + return True if len(q) < 4: return False - if os.path.basename(q[0]).lower() != 'kitty': + if exe_name != 'kitty': return False if q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']: return True @@ -91,3 +96,82 @@ def create_shared_memory(data: Any, prefix: str) -> str: def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None: patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv) + + + +def get_ssh_cli() -> Tuple[Set[str], Set[str]]: + other_ssh_args: Set[str] = set() + boolean_ssh_args: Set[str] = set() + for k, v in ssh_options().items(): + k = f'-{k}' + if v: + other_ssh_args.add(k) + else: + boolean_ssh_args.add(k) + return boolean_ssh_args, other_ssh_args + + +def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str: + for x in extra_args: + if arg == x or arg.startswith(f'{x}='): + return x + return '' + + +passthrough_args = {f'-{x}' for x in 'NnfGT'} + + +def set_server_args_in_cmdline(server_args: List[str], argv: List[str], extra_args: Tuple[str, ...] = ('--kitten',)) -> None: + boolean_ssh_args, other_ssh_args = get_ssh_cli() + ssh_args = [] + expecting_option_val = False + found_extra_args: List[str] = [] + expecting_extra_val = '' + ans = list(argv) + found_ssh = False + for i, argument in enumerate(argv): + if not found_ssh: + found_ssh = argument == 'ssh' + continue + if argument.startswith('-') and not expecting_option_val: + if argument == '--': + del ans[i+2:] + break + if extra_args: + matching_ex = is_extra_arg(argument, extra_args) + if matching_ex: + if '=' in argument: + exval = argument.partition('=')[-1] + found_extra_args.extend((matching_ex, exval)) + else: + expecting_extra_val = matching_ex + expecting_option_val = True + continue + # could be a multi-character option + all_args = argument[1:] + for i, arg in enumerate(all_args): + arg = f'-{arg}' + if arg in boolean_ssh_args: + ssh_args.append(arg) + continue + if arg in other_ssh_args: + ssh_args.append(arg) + rest = all_args[i+1:] + if rest: + ssh_args.append(rest) + else: + expecting_option_val = True + break + raise KeyError(f'unknown option -- {arg[1:]}') + continue + if expecting_option_val: + if expecting_extra_val: + found_extra_args.extend((expecting_extra_val, argument)) + expecting_extra_val = '' + else: + ssh_args.append(argument) + expecting_option_val = False + continue + del ans[i+1:] + break + argv[:] = ans + server_args