Function to change the remote command in an ssh kitten cmdline

This commit is contained in:
Kovid Goyal 2023-02-08 16:34:33 +05:30
parent 237a5d17c0
commit 244507336b
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 88 additions and 26 deletions

View File

@ -22,7 +22,7 @@ from base64 import standard_b64decode, standard_b64encode
from contextlib import contextmanager, suppress
from getpass import getuser
from select import select
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, Tuple, Union, cast
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Tuple, Union, cast
from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_control_master_template, str_version, terminfo_dir
from kitty.shell_integration import as_str_literal
@ -37,7 +37,7 @@ from .config import init_config
from .copy import CopyInstruction
from .options.types import Options as SSHOptions
from .options.utils import DELETE_ENV_VAR
from .utils import create_shared_memory, ssh_options
from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args
@run_once
@ -291,25 +291,6 @@ def bootstrap_script(
return prepare_script(ans, sd, script_type), replacements, shm_name
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
other_ssh_args: Set[str] = set()
boolean_ssh_args: Set[str] = set()
for k, v in ssh_options().items():
k = f'-{k}'
if v:
other_ssh_args.add(k)
else:
boolean_ssh_args.add(k)
return boolean_ssh_args, other_ssh_args
def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
for x in extra_args:
if arg == x or arg.startswith(f'{x}='):
return x
return ''
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
@ -405,9 +386,6 @@ class InvalidSSHArgs(ValueError):
os.execlp(ssh_exe(), 'ssh')
passthrough_args = {f'-{x}' for x in 'NnfGT'}
def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
ssh_args = []

View File

@ -4,7 +4,7 @@
import os
import subprocess
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Sequence, Set, Tuple
from kitty.types import run_once
@ -51,9 +51,14 @@ def ssh_options() -> Dict[str, str]:
def is_kitten_cmdline(q: Sequence[str]) -> bool:
if not q:
return False
exe_name = os.path.basename(q[0]).lower()
if exe_name == 'kitten' and q[1:2] == ['ssh']:
return True
if len(q) < 4:
return False
if os.path.basename(q[0]).lower() != 'kitty':
if exe_name != 'kitty':
return False
if q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']:
return True
@ -91,3 +96,82 @@ def create_shared_memory(data: Any, prefix: str) -> str:
def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None:
patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
other_ssh_args: Set[str] = set()
boolean_ssh_args: Set[str] = set()
for k, v in ssh_options().items():
k = f'-{k}'
if v:
other_ssh_args.add(k)
else:
boolean_ssh_args.add(k)
return boolean_ssh_args, other_ssh_args
def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
for x in extra_args:
if arg == x or arg.startswith(f'{x}='):
return x
return ''
passthrough_args = {f'-{x}' for x in 'NnfGT'}
def set_server_args_in_cmdline(server_args: List[str], argv: List[str], extra_args: Tuple[str, ...] = ('--kitten',)) -> None:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
ssh_args = []
expecting_option_val = False
found_extra_args: List[str] = []
expecting_extra_val = ''
ans = list(argv)
found_ssh = False
for i, argument in enumerate(argv):
if not found_ssh:
found_ssh = argument == 'ssh'
continue
if argument.startswith('-') and not expecting_option_val:
if argument == '--':
del ans[i+2:]
break
if extra_args:
matching_ex = is_extra_arg(argument, extra_args)
if matching_ex:
if '=' in argument:
exval = argument.partition('=')[-1]
found_extra_args.extend((matching_ex, exval))
else:
expecting_extra_val = matching_ex
expecting_option_val = True
continue
# could be a multi-character option
all_args = argument[1:]
for i, arg in enumerate(all_args):
arg = f'-{arg}'
if arg in boolean_ssh_args:
ssh_args.append(arg)
continue
if arg in other_ssh_args:
ssh_args.append(arg)
rest = all_args[i+1:]
if rest:
ssh_args.append(rest)
else:
expecting_option_val = True
break
raise KeyError(f'unknown option -- {arg[1:]}')
continue
if expecting_option_val:
if expecting_extra_val:
found_extra_args.extend((expecting_extra_val, argument))
expecting_extra_val = ''
else:
ssh_args.append(argument)
expecting_option_val = False
continue
del ans[i+1:]
break
argv[:] = ans + server_args