kitty/kittens/ssh/main.py

506 lines
19 KiB
Python

#!/usr/bin/env python3
# License: GPL v3 Copyright: 2018, Kovid Goyal <kovid at kovidgoyal.net>
import atexit
import fnmatch
import io
import json
import os
import re
import shlex
import sys
import tarfile
import tempfile
import time
import traceback
from base64 import standard_b64decode, standard_b64encode
from contextlib import suppress
from getpass import getuser
from typing import (
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set,
Tuple, Union
)
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
from kitty.fast_data_types import get_options
from kitty.short_uuid import uuid4
from kitty.utils import SSHConnectionData
from .completion import complete, ssh_options
from .config import init_config, options_for_host
from .copy import CopyInstruction
from .options.types import Options as SSHOptions
from .options.utils import DELETE_ENV_VAR
# See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
quote_pat = re.compile('([\\`"\n])')
def quote_env_val(x: str) -> str:
x = quote_pat.sub(r'\\\1', x)
x = x.replace('$(', r'\$(') # prevent execution with $()
return f'"{x}"'
def serialize_env(env: Dict[str, str], base_env: Dict[str, str]) -> bytes:
lines = []
def a(k: str, val: str) -> None:
lines.append(f'export {shlex.quote(k)}={quote_env_val(val)}')
for k in sorted(env):
v = env[k]
if v == DELETE_ENV_VAR:
lines.append(f'unset {shlex.quote(k)}')
elif v == '_kitty_copy_env_var_':
q = base_env.get(k)
if q is not None:
a(k, q)
else:
a(k, v)
return '\n'.join(lines).encode('utf-8')
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz') -> bytes:
def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo:
tarinfo.uname = tarinfo.gname = ''
tarinfo.uid = tarinfo.gid = 0
return tarinfo
def add_data_as_file(tf: tarfile.TarFile, arcname: str, data: Union[str, bytes]) -> tarfile.TarInfo:
ans = tarfile.TarInfo(arcname)
ans.mtime = int(time.time())
ans.type = tarfile.REGTYPE
if isinstance(data, str):
data = data.encode('utf-8')
ans.size = len(data)
normalize_tarinfo(ans)
tf.addfile(ans, io.BytesIO(data))
return ans
def filter_from_globs(*pats: str) -> Callable[[tarfile.TarInfo], Optional[tarfile.TarInfo]]:
def filter(tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]:
for junk_dir in ('.DS_Store', '__pycache__'):
for pat in (f'*/{junk_dir}', f'*/{junk_dir}/*'):
if fnmatch.fnmatch(tarinfo.name, pat):
return None
for pat in pats:
if fnmatch.fnmatch(tarinfo.name, pat):
return None
return normalize_tarinfo(tarinfo)
return filter
from kitty.shell_integration import get_effective_ksi_env_var
if ssh_opts.shell_integration == 'inherited':
ksi = get_effective_ksi_env_var()
else:
from kitty.options.types import Options
from kitty.options.utils import shell_integration
ksi = get_effective_ksi_env_var(Options({'shell_integration': shell_integration(ssh_opts.shell_integration)}))
env = {
'TERM': get_options().term,
'COLORTERM': 'truecolor',
}
for q in ('KITTY_WINDOW_ID', 'WINDOWID'):
val = os.environ.get(q)
if val is not None:
env[q] = val
env.update(ssh_opts.env)
env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR
env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir
if ssh_opts.login_shell:
env['KITTY_LOGIN_SHELL'] = ssh_opts.login_shell
if ssh_opts.cwd:
env['KITTY_LOGIN_CWD'] = ssh_opts.cwd
env_script = serialize_env(env, base_env)
buf = io.BytesIO()
with tarfile.open(mode=f'w:{compression}', fileobj=buf, encoding='utf-8') as tf:
rd = ssh_opts.remote_dir.rstrip('/')
for ci in ssh_opts.copy.values():
tf.add(ci.local_path, arcname=ci.arcname, filter=filter_from_globs(*ci.exclude_patterns))
add_data_as_file(tf, 'data.sh', env_script)
if ksi:
arcname = 'home/' + rd + '/shell-integration'
tf.add(shell_integration_dir, arcname=arcname, filter=filter_from_globs(
f'{arcname}/ssh/bootstrap.*', # bootstrap files are sent as command line args
f'{arcname}/zsh/kitty.zsh', # present for legacy compat not needed by ssh kitten
))
tf.add(terminfo_dir, arcname='home/.terminfo', filter=normalize_tarinfo)
return buf.getvalue()
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
record_sep = b'\036'
def fmt_prefix(msg: Any) -> bytes:
return str(msg).encode('ascii') + record_sep
yield record_sep # to discard leading data
try:
msg = standard_b64decode(msg).decode('utf-8')
md = dict(x.split('=', 1) for x in msg.split(':'))
hostname = md['hostname']
pw = md['pw']
pwfilename = md['pwfile']
username = md['user']
rq_id = md['id']
compression = md['compression']
except Exception:
traceback.print_exc()
yield fmt_prefix('!invalid ssh data request message')
else:
try:
with open(os.path.join(cache_dir(), 'ssh', pwfilename), 'rb') as f:
os.unlink(f.name)
env_data = json.load(f)
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
raise ValueError('Incorrect request id')
cli_hostname = env_data['cli_hostname']
cli_uname = env_data['cli_uname']
except Exception as e:
traceback.print_exc()
yield fmt_prefix(f'!{e}')
else:
ssh_opts = {k: SSHOptions(v) for k, v in env_data['opts'].items()}
resolved_ssh_opts = options_for_host(hostname, username, ssh_opts, cli_hostname, cli_uname)
resolved_ssh_opts.copy = {k: CopyInstruction(*v) for k, v in resolved_ssh_opts.copy.items()}
try:
data = make_tarfile(resolved_ssh_opts, env_data['env'], compression)
except Exception:
traceback.print_exc()
yield fmt_prefix('!error while gathering ssh data')
else:
encoded_data = standard_b64encode(data)
yield fmt_prefix(len(encoded_data))
yield encoded_data
def safe_remove(x: str) -> None:
with suppress(OSError):
os.remove(x)
def prepare_script(ans: str, replacements: Dict[str, str]) -> str:
for k in ('EXEC_CMD',):
replacements[k] = replacements.get(k, '')
def sub(m: 're.Match[str]') -> str:
return replacements[m.group()]
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
# ssh simply concatenates multiple commands using a space see
# line 1129 of ssh.c and on the remote side sshd.c runs the
# concatenated command as shell -c cmd
if is_python:
return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args)
return f"""exec "$login_shell" -c '{args}'"""
def bootstrap_script(
script_type: str = 'sh', remote_args: Sequence[str] = (),
ssh_opts_dict: Dict[str, Dict[str, Any]] = {},
test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = ''
) -> Tuple[str, Dict[str, str]]:
if request_id is None:
request_id = os.environ['KITTY_PID'] + '-' + os.environ['KITTY_WINDOW_ID']
exec_cmd = prepare_exec_cmd(remote_args, script_type == 'py') if remote_args else ''
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
ans = f.read()
pw = uuid4()
ddir = os.path.join(cache_dir(), 'ssh')
os.makedirs(ddir, exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', suffix='.json', dir=ddir, delete=False) as tf:
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict, 'cli_hostname': cli_hostname, 'cli_uname': cli_uname}
tf.write(json.dumps(data).encode('utf-8'))
atexit.register(safe_remove, tf.name)
replacements = {
'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
'REQUEST_ID': request_id
}
return prepare_script(ans, replacements), replacements
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
other_ssh_args: Set[str] = set()
boolean_ssh_args: Set[str] = set()
for k, v in ssh_options().items():
k = f'-{k}'
if v:
other_ssh_args.add(k)
else:
boolean_ssh_args.add(k)
return boolean_ssh_args, other_ssh_args
def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
for x in extra_args:
if arg == x or arg.startswith(f'{x}='):
return x
return ''
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
expecting_port = expecting_identity = False
expecting_option_val = False
expecting_hostname = False
expecting_extra_val = ''
host_name = identity_file = found_ssh = ''
found_extra_args: List[Tuple[str, str]] = []
for i, arg in enumerate(args):
if not found_ssh:
if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
found_ssh = arg
continue
if expecting_hostname:
host_name = arg
continue
if arg.startswith('-') and not expecting_option_val:
if arg in boolean_ssh_args:
continue
if arg == '--':
expecting_hostname = True
if arg.startswith('-p'):
if arg[2:].isdigit():
with suppress(Exception):
port = int(arg[2:])
continue
elif arg == '-p':
expecting_port = True
elif arg.startswith('-i'):
if arg == '-i':
expecting_identity = True
else:
identity_file = arg[2:]
continue
if arg.startswith('--') and extra_args:
matching_ex = is_extra_arg(arg, extra_args)
if matching_ex:
if '=' in arg:
exval = arg.partition('=')[-1]
found_extra_args.append((matching_ex, exval))
continue
expecting_extra_val = matching_ex
expecting_option_val = True
continue
if expecting_option_val:
if expecting_port:
with suppress(Exception):
port = int(arg)
expecting_port = False
elif expecting_identity:
identity_file = arg
elif expecting_extra_val:
found_extra_args.append((expecting_extra_val, arg))
expecting_option_val = False
continue
if not host_name:
host_name = arg
if not host_name:
return None
if identity_file:
if not os.path.isabs(identity_file):
identity_file = os.path.expanduser(identity_file)
if not os.path.isabs(identity_file):
identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))
class InvalidSSHArgs(ValueError):
def __init__(self, msg: str = ''):
super().__init__(msg)
self.err_msg = msg
def system_exit(self) -> None:
if self.err_msg:
print(self.err_msg, file=sys.stderr)
os.execlp('ssh', 'ssh')
def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
passthrough_args = {f'-{x}' for x in 'NnfG'}
ssh_args = []
server_args: List[str] = []
expecting_option_val = False
passthrough = False
stop_option_processing = False
found_extra_args: List[str] = []
expecting_extra_val = ''
for argument in args:
if len(server_args) > 1 or stop_option_processing:
server_args.append(argument)
continue
if argument.startswith('-') and not expecting_option_val:
if argument == '--':
stop_option_processing = True
continue
if extra_args:
matching_ex = is_extra_arg(argument, extra_args)
if matching_ex:
if '=' in argument:
exval = argument.partition('=')[-1]
found_extra_args.extend((matching_ex, exval))
else:
expecting_extra_val = matching_ex
expecting_option_val = True
continue
# could be a multi-character option
all_args = argument[1:]
for i, arg in enumerate(all_args):
arg = f'-{arg}'
if arg in passthrough_args:
passthrough = True
if arg in boolean_ssh_args:
ssh_args.append(arg)
continue
if arg in other_ssh_args:
ssh_args.append(arg)
rest = all_args[i+1:]
if rest:
ssh_args.append(rest)
else:
expecting_option_val = True
break
raise InvalidSSHArgs(f'unknown option -- {arg[1:]}')
continue
if expecting_option_val:
if expecting_extra_val:
found_extra_args.extend((expecting_extra_val, argument))
else:
ssh_args.append(argument)
expecting_option_val = False
continue
server_args.append(argument)
if not server_args:
raise InvalidSSHArgs()
return ssh_args, server_args, passthrough, tuple(found_extra_args)
def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]:
# sshd will execute the command we pass it by join all command line
# arguments with a space and passing it as a single argument to the users
# login shell with -c. If the user has a non POSIX login shell it might
# have different escaping semantics and syntax, so the command it should
# execute has to be as simple as possible, basically of the form
# interpreter -c unwrap_script escaped_bootstrap_script
# The unwrap_script is responsible for unescaping the bootstrap script and
# executing it.
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
if is_python:
es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii')
unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"'''
else:
# We cant rely on base64 being available on the remote system, so instead
# we quote the bootstrap script by replacing ' and \ with \v and \f
# also replacing \n and ! with \r and \b for tcsh
# finally surrounding with '
es = "'" + sh_script.replace("'", '\v').replace('\\', '\f').replace('\n', '\r').replace('!', '\b') + "'"
unwrap_script = r"""'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' """
# exec is supported by all sh like shells, and fish and csh
return ['exec', interpreter, '-c', unwrap_script, es]
def get_remote_command(
remote_args: List[str], hostname: str = 'localhost', cli_hostname: str = '', cli_uname: str = '',
interpreter: str = 'sh',
ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
) -> Tuple[List[str], Dict[str, str]]:
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
sh_script, replacements = bootstrap_script(
script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict,
cli_hostname=cli_hostname, cli_uname=cli_uname)
return wrap_bootstrap_script(sh_script, interpreter), replacements
def connection_sharing_args(opts: SSHOptions, kitty_pid: int) -> List[str]:
cp = os.path.join(cache_dir(), 'ssh', f'{kitty_pid}-master-%C')
ans: List[str] = [
'-o', 'ControlMaster=auto',
'-o', f'ControlPath={cp}',
'-o', 'ControlPersist=yes',
]
return ans
def main(args: List[str]) -> NoReturn:
args = args[1:]
if args and args[0] == 'use-python':
args = args[1:] # backwards compat from when we had a python implementation
try:
ssh_args, server_args, passthrough, found_extra_args = parse_ssh_args(args, extra_args=('--kitten',))
except InvalidSSHArgs as e:
e.system_exit()
if not os.environ.get('KITTY_WINDOW_ID'):
passthrough = True
cmd = ['ssh'] + ssh_args
if passthrough:
cmd += server_args
else:
hostname, remote_args = server_args[0], server_args[1:]
if not remote_args:
cmd.append('-t')
insertion_point = len(cmd)
cmd.append('--')
cmd.append(hostname)
uname = getuser()
if hostname.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(hostname)
hostname_for_match = purl.hostname or hostname
uname = purl.username or uname
elif '@' in hostname and hostname[0] != '@':
uname, hostname_for_match = hostname.split('@', 1)
else:
hostname_for_match = hostname
hostname_for_match = hostname.split('@', 1)[-1].split(':', 1)[0]
overrides = []
pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=')
for i, a in enumerate(found_extra_args):
if i % 2 == 1:
overrides.append(pat.sub(r'\1 ', a.lstrip()))
if overrides:
overrides.insert(0, f'hostname {uname}@{hostname_for_match}')
so = init_config(overrides)
sod = {k: v._asdict() for k, v in so.items()}
host_opts = options_for_host(hostname_for_match, uname, so)
use_control_master = 'KITTY_PID' in os.environ and host_opts.share_connections
rcmd, replacements = get_remote_command(remote_args, hostname, hostname_for_match, uname, host_opts.interpreter, sod)
cmd += rcmd
if use_control_master:
cmd[insertion_point:insertion_point] = connection_sharing_args(host_opts, int(os.environ['KITTY_PID']))
import subprocess
with suppress(FileNotFoundError):
try:
raise SystemExit(subprocess.run(cmd).returncode)
except KeyboardInterrupt:
raise SystemExit(1)
raise SystemExit('Could not find the ssh executable, is it in your PATH?')
if __name__ == '__main__':
main(sys.argv)
elif __name__ == '__completer__':
setattr(sys, 'kitten_completer', complete)
elif __name__ == '__conf__':
from .options.definition import definition
sys.options_definition = definition # type: ignore