diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 5185b6790..20b18a2a3 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -28,13 +28,12 @@ from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_c from kitty.shell_integration import as_str_literal from kitty.shm import SharedMemory from kitty.types import run_once -from kitty.utils import SSHConnectionData, expandvars, resolve_abs_or_config_path +from kitty.utils import expandvars, resolve_abs_or_config_path from kitty.utils import set_echo as turn_off_echo from ..tui.operations import RESTORE_PRIVATE_MODE_VALUES, SAVE_PRIVATE_MODE_VALUES, Mode, restore_colors, save_colors, set_mode from ..tui.utils import kitty_opts, running_in_tmux from .config import init_config -from .copy import CopyInstruction from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args @@ -287,88 +286,6 @@ def bootstrap_script( return prepare_script(ans, sd, script_type), replacements, shm_name -def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]: - boolean_ssh_args, other_ssh_args = get_ssh_cli() - port: Optional[int] = None - expecting_port = expecting_identity = False - expecting_option_val = False - expecting_hostname = False - expecting_extra_val = '' - host_name = identity_file = found_ssh = '' - found_extra_args: List[Tuple[str, str]] = [] - - for i, arg in enumerate(args): - if not found_ssh: - if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'): - found_ssh = arg - continue - if expecting_hostname: - host_name = arg - continue - if arg.startswith('-') and not expecting_option_val: - if arg in boolean_ssh_args: - continue - if arg == '--': - expecting_hostname = True - if arg.startswith('-p'): - if arg[2:].isdigit(): - with suppress(Exception): - port = int(arg[2:]) - continue - elif arg == '-p': - expecting_port = True - elif arg.startswith('-i'): - if arg == '-i': - expecting_identity = True - else: - identity_file = arg[2:] - continue - if arg.startswith('--') and extra_args: - matching_ex = is_extra_arg(arg, extra_args) - if matching_ex: - if '=' in arg: - exval = arg.partition('=')[-1] - found_extra_args.append((matching_ex, exval)) - continue - expecting_extra_val = matching_ex - - expecting_option_val = True - continue - - if expecting_option_val: - if expecting_port: - with suppress(Exception): - port = int(arg) - expecting_port = False - elif expecting_identity: - identity_file = arg - elif expecting_extra_val: - found_extra_args.append((expecting_extra_val, arg)) - expecting_extra_val = '' - expecting_option_val = False - continue - - if not host_name: - host_name = arg - if not host_name: - return None - if host_name.startswith('ssh://'): - from urllib.parse import urlparse - purl = urlparse(host_name) - if purl.hostname: - host_name = purl.hostname - if purl.username: - host_name = f'{purl.username}@{host_name}' - if port is None and purl.port: - port = purl.port - if identity_file: - if not os.path.isabs(identity_file): - identity_file = os.path.expanduser(identity_file) - if not os.path.isabs(identity_file): - identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file)) - - return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args)) - class InvalidSSHArgs(ValueError): diff --git a/kittens/ssh/utils.py b/kittens/ssh/utils.py index 540ec5d09..fffe88949 100644 --- a/kittens/ssh/utils.py +++ b/kittens/ssh/utils.py @@ -4,9 +4,11 @@ import os import subprocess -from typing import Any, Dict, List, Sequence, Set, Tuple +from contextlib import suppress +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple from kitty.types import run_once +from kitty.utils import SSHConnectionData @run_once @@ -183,3 +185,86 @@ def set_server_args_in_cmdline( ans.insert(i, '-t') break argv[:] = ans + server_args + + +def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]: + boolean_ssh_args, other_ssh_args = get_ssh_cli() + port: Optional[int] = None + expecting_port = expecting_identity = False + expecting_option_val = False + expecting_hostname = False + expecting_extra_val = '' + host_name = identity_file = found_ssh = '' + found_extra_args: List[Tuple[str, str]] = [] + + for i, arg in enumerate(args): + if not found_ssh: + if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'): + found_ssh = arg + continue + if expecting_hostname: + host_name = arg + continue + if arg.startswith('-') and not expecting_option_val: + if arg in boolean_ssh_args: + continue + if arg == '--': + expecting_hostname = True + if arg.startswith('-p'): + if arg[2:].isdigit(): + with suppress(Exception): + port = int(arg[2:]) + continue + elif arg == '-p': + expecting_port = True + elif arg.startswith('-i'): + if arg == '-i': + expecting_identity = True + else: + identity_file = arg[2:] + continue + if arg.startswith('--') and extra_args: + matching_ex = is_extra_arg(arg, extra_args) + if matching_ex: + if '=' in arg: + exval = arg.partition('=')[-1] + found_extra_args.append((matching_ex, exval)) + continue + expecting_extra_val = matching_ex + + expecting_option_val = True + continue + + if expecting_option_val: + if expecting_port: + with suppress(Exception): + port = int(arg) + expecting_port = False + elif expecting_identity: + identity_file = arg + elif expecting_extra_val: + found_extra_args.append((expecting_extra_val, arg)) + expecting_extra_val = '' + expecting_option_val = False + continue + + if not host_name: + host_name = arg + if not host_name: + return None + if host_name.startswith('ssh://'): + from urllib.parse import urlparse + purl = urlparse(host_name) + if purl.hostname: + host_name = purl.hostname + if purl.username: + host_name = f'{purl.username}@{host_name}' + if port is None and purl.port: + port = purl.port + if identity_file: + if not os.path.isabs(identity_file): + identity_file = os.path.expanduser(identity_file) + if not os.path.isabs(identity_file): + identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file)) + + return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args)) diff --git a/kitty/window.py b/kitty/window.py index 76e4eaed6..94c2da798 100644 --- a/kitty/window.py +++ b/kitty/window.py @@ -959,7 +959,7 @@ class Window: def handle_remote_file(self, netloc: str, remote_path: str) -> None: from kittens.remote_file.main import is_ssh_kitten_sentinel - from kittens.ssh.main import get_connection_data + from kittens.ssh.utils import get_connection_data from .utils import SSHConnectionData args = self.ssh_kitten_cmdline() diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index 401b588ac..801cd54b9 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -10,7 +10,8 @@ from contextlib import suppress from functools import lru_cache from kittens.ssh.config import load_config -from kittens.ssh.main import bootstrap_script, get_connection_data, wrap_bootstrap_script +from kittens.ssh.main import bootstrap_script, wrap_bootstrap_script +from kittens.ssh.utils import get_connection_data from kittens.transfer.utils import set_paths from kitty.constants import is_macos, runtime_dir from kitty.fast_data_types import CURSOR_BEAM, shm_unlink