diff --git a/docs/kittens/ssh.rst b/docs/kittens/ssh.rst index e4f425e4d..6b89d9c66 100644 --- a/docs/kittens/ssh.rst +++ b/docs/kittens/ssh.rst @@ -46,7 +46,7 @@ quick example: copy env-files env SOMETHING=else - hostname somehost + hostname someuser@somehost copy --dest=foo/bar some-file copy --glob some/files.* diff --git a/kittens/ssh/config.py b/kittens/ssh/config.py index ffbf9e758..4379f4315 100644 --- a/kittens/ssh/config.py +++ b/kittens/ssh/config.py @@ -2,6 +2,7 @@ # License: GPLv3 Copyright: 2022, Kovid Goyal +import fnmatch import os from typing import Any, Dict, Iterable, Optional @@ -16,12 +17,18 @@ SYSTEM_CONF = '/etc/xdg/kitty/ssh.conf' defconf = os.path.join(config_dir, 'ssh.conf') -def options_for_host(hostname: str, per_host_opts: Dict[str, SSHOptions]) -> SSHOptions: - import fnmatch +def host_matches(pat: str, hostname: str, username: str) -> bool: + upat = '*' + if '@' in pat: + upat, pat = pat.split('@', 1) + return fnmatch.fnmatchcase(hostname, pat) and fnmatch.fnmatchcase(username, upat) + + +def options_for_host(hostname: str, username: str, per_host_opts: Dict[str, SSHOptions]) -> SSHOptions: matches = [] for spat, opts in per_host_opts.items(): for pat in spat.split(): - if fnmatch.fnmatchcase(hostname, pat): + if host_matches(pat, hostname, username): matches.append(opts) if not matches: return SSHOptions({}) @@ -45,7 +52,9 @@ def load_config(*paths: str, overrides: Optional[Iterable[str]] = None) -> Dict[ from .options.parse import ( create_result_dict, merge_result_dicts, parse_conf_item ) - from .options.utils import get_per_hosts_dict, init_results_dict, first_seen_positions + from .options.utils import ( + first_seen_positions, get_per_hosts_dict, init_results_dict + ) def merge_dicts(base: Dict[str, Any], vals: Dict[str, Any]) -> Dict[str, Any]: base_phd = get_per_hosts_dict(base) diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 7bebbe8d7..da2cb3658 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -15,6 +15,7 @@ import time import traceback from base64 import standard_b64decode from contextlib import suppress +from getpass import getuser from typing import ( Any, Callable, Dict, Iterator, List, NoReturn, Optional, Set, Tuple, Union ) @@ -131,6 +132,7 @@ def get_ssh_data(msg: str, ssh_opts: Optional[Dict[str, SSHOptions]] = None) -> hostname = md['hostname'] pw = md['pw'] pwfilename = md['pwfile'] + username = md['user'] except Exception: traceback.print_exc() yield fmt_prefix('!invalid ssh data request message') @@ -145,7 +147,7 @@ def get_ssh_data(msg: str, ssh_opts: Optional[Dict[str, SSHOptions]] = None) -> traceback.print_exc() yield fmt_prefix('!incorrect ssh data password') else: - resolved_ssh_opts = options_for_host(hostname, ssh_opts) + resolved_ssh_opts = options_for_host(hostname, username, ssh_opts) try: data = make_tarfile(resolved_ssh_opts, env_data['env']) except Exception: @@ -348,8 +350,18 @@ def main(args: List[str]) -> NoReturn: cmd.append('-t') cmd.append('--') cmd.append(hostname) + uname = getuser() + if hostname.startswith('ssh://'): + from urllib.parse import urlparse + purl = urlparse(hostname) + hostname_for_match = purl.hostname or hostname + uname = purl.username or uname + elif '@' in hostname and hostname[0] != '@': + uname, hostname_for_match = hostname.split('@', 1) + else: + hostname_for_match = hostname hostname_for_match = hostname.split('@', 1)[-1].split(':', 1)[0] - cmd += get_remote_command(remote_args, hostname, options_for_host(hostname_for_match, load_ssh_options()).interpreter) + cmd += get_remote_command(remote_args, hostname, options_for_host(hostname_for_match, uname, load_ssh_options()).interpreter) os.execvp('ssh', cmd) diff --git a/kittens/ssh/options/definition.py b/kittens/ssh/options/definition.py index 1bd6a924e..efdc39413 100644 --- a/kittens/ssh/options/definition.py +++ b/kittens/ssh/options/definition.py @@ -27,6 +27,7 @@ opt('hostname', '*', option_type='hostname', long_text=''' The hostname the following options apply to. A glob pattern to match multiple hosts can be used. Multiple hostnames can also be specified separated by spaces. +The hostname can include an optional username in the form :code:`user@host`. When not specified options apply to all hosts, until the first hostname specification is found. Note that the hostname this matches against is the hostname used by the remote computer, not the name you pass diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index 7498b79b9..46342e278 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -57,10 +57,10 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) def parse(conf): return load_config(overrides=conf.splitlines()) - def for_host(hostname, conf): + def for_host(hostname, conf, username='kitty'): if isinstance(conf, str): conf = parse(conf) - return options_for_host(hostname, conf) + return options_for_host(hostname, username, conf) self.ae(for_host('x', '').env, {}) self.ae(for_host('x', 'env a=b').env, {'a': 'b'}) @@ -70,6 +70,15 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) self.ae(for_host('2', pc).env, {'a': 'c', 'b': 'b'}) self.ae(for_host('x', 'env a=').env, {'a': ''}) self.ae(for_host('x', 'env a').env, {'a': '_delete_this_env_var_'}) + pc = parse('env a=b\nhostname test@2\nenv a=c\nenv b=b') + self.ae(set(pc.keys()), {'*', 'test@2'}) + self.ae(for_host('x', pc).env, {'a': 'b'}) + self.ae(for_host('2', pc).env, {'a': 'b'}) + self.ae(for_host('2', pc, 'test').env, {'a': 'c', 'b': 'b'}) + pc = parse('env a=b\nhostname 1 2\nenv a=c\nenv b=b') + self.ae(for_host('x', pc).env, {'a': 'b'}) + self.ae(for_host('1', pc).env, {'a': 'c', 'b': 'b'}) + self.ae(for_host('2', pc).env, {'a': 'c', 'b': 'b'}) @property @lru_cache()