kitty/kittens/ssh/main.py
2022-03-10 06:11:58 +05:30

511 lines
20 KiB
Python

#!/usr/bin/env python3
# License: GPL v3 Copyright: 2018, Kovid Goyal <kovid at kovidgoyal.net>
import atexit
import fnmatch
import io
import json
import os
import re
import secrets
import shlex
import stat
import sys
import tarfile
import tempfile
import time
import traceback
from base64 import standard_b64decode, standard_b64encode
from contextlib import suppress
from getpass import getuser
from typing import (
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set,
Tuple, Union
)
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
from kitty.fast_data_types import get_options
from kitty.utils import SSHConnectionData
from .completion import complete, ssh_options
from .config import init_config, options_for_host
from .copy import CopyInstruction
from .options.types import Options as SSHOptions
from .options.utils import DELETE_ENV_VAR
# See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
quote_pat = re.compile('([\\`"\n])')
def quote_env_val(x: str) -> str:
x = quote_pat.sub(r'\\\1', x)
x = x.replace('$(', r'\$(') # prevent execution with $()
return f'"{x}"'
def serialize_env(env: Dict[str, str], base_env: Dict[str, str]) -> bytes:
    """Render *env* as UTF-8 encoded shell ``export``/``unset`` lines.

    Variables are emitted sorted by name. A value equal to DELETE_ENV_VAR
    becomes an ``unset`` line. The value ``_kitty_copy_env_var_`` means the
    value is looked up in *base_env* and the variable is skipped entirely
    when *base_env* does not define it.
    """
    script: List[str] = []
    for name in sorted(env):
        value: Optional[str] = env[name]
        if value == DELETE_ENV_VAR:
            script.append(f'unset {shlex.quote(name)}')
            continue
        if value == '_kitty_copy_env_var_':
            value = base_env.get(name)
            if value is None:
                continue
        script.append(f'export {shlex.quote(name)}={quote_env_val(value)}')
    return '\n'.join(script).encode('utf-8')
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz') -> bytes:
    """Build the tar archive described by *ssh_opts* and return its bytes.

    The archive contains, in order: every entry of ``ssh_opts.copy``, a
    ``data.sh`` member holding the serialized environment, the shell
    integration directory (only when shell integration is enabled) and the
    terminfo directory. *compression* is passed to ``tarfile.open`` as the
    ``w:<compression>`` mode suffix.
    """
    from kitty.shell_integration import get_effective_ksi_env_var

    # Patterns that are always excluded, whatever the caller's own globs are
    junk_patterns = tuple(
        pat for junk_dir in ('.DS_Store', '__pycache__') for pat in (f'*/{junk_dir}', f'*/{junk_dir}/*')
    )

    def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo:
        # Blank out the owner names and zero the owner ids of the member
        tarinfo.uname = tarinfo.gname = ''
        tarinfo.uid = tarinfo.gid = 0
        return tarinfo

    def add_data_as_file(tf: tarfile.TarFile, arcname: str, data: Union[str, bytes]) -> tarfile.TarInfo:
        # Add in-memory data to the archive as a regular file named arcname
        payload = data.encode('utf-8') if isinstance(data, str) else data
        info = tarfile.TarInfo(arcname)
        info.mtime = int(time.time())
        info.type = tarfile.REGTYPE
        info.size = len(payload)
        tf.addfile(normalize_tarinfo(info), io.BytesIO(payload))
        return info

    def filter_from_globs(*pats: str) -> Callable[[tarfile.TarInfo], Optional[tarfile.TarInfo]]:
        rejected = junk_patterns + pats

        def tar_filter(tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]:
            # Returning None tells TarFile.add() to skip the member
            if any(fnmatch.fnmatch(tarinfo.name, pat) for pat in rejected):
                return None
            return normalize_tarinfo(tarinfo)

        return tar_filter

    if ssh_opts.shell_integration == 'inherited':
        ksi = get_effective_ksi_env_var()
    else:
        from kitty.options.types import Options
        from kitty.options.utils import shell_integration
        ksi = get_effective_ksi_env_var(Options({'shell_integration': shell_integration(ssh_opts.shell_integration)}))

    env = {'TERM': get_options().term, 'COLORTERM': 'truecolor'}
    env.update((name, os.environ[name]) for name in ('KITTY_WINDOW_ID', 'WINDOWID') if name in os.environ)
    # User supplied values come next, the kitten's own variables always win
    env.update(ssh_opts.env)
    env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR
    env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir
    if ssh_opts.login_shell:
        env['KITTY_LOGIN_SHELL'] = ssh_opts.login_shell
    if ssh_opts.cwd:
        env['KITTY_LOGIN_CWD'] = ssh_opts.cwd
    env_script = serialize_env(env, base_env)

    archive = io.BytesIO()
    with tarfile.open(mode=f'w:{compression}', fileobj=archive, encoding='utf-8') as tf:
        remote_dir = ssh_opts.remote_dir.rstrip('/')
        for instruction in ssh_opts.copy.values():
            tf.add(
                instruction.local_path, arcname=instruction.arcname,
                filter=filter_from_globs(*instruction.exclude_patterns))
        add_data_as_file(tf, 'data.sh', env_script)
        if ksi:
            arcname = 'home/' + remote_dir + '/shell-integration'
            tf.add(shell_integration_dir, arcname=arcname, filter=filter_from_globs(
                f'{arcname}/ssh/bootstrap.*',  # bootstrap files are sent as command line args
                f'{arcname}/zsh/kitty.zsh',  # present for legacy compat not needed by ssh kitten
            ))
        tf.add(terminfo_dir, arcname='home/.terminfo', filter=normalize_tarinfo)
    return archive.getvalue()
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
    """Yield the response to an ssh data request, piece by piece.

    *msg* is a base64 encoded string of colon separated ``key=value`` pairs
    that must contain ``hostname``, ``pw``, ``pwfile``, ``user``, ``id`` and
    ``compression``. The first item yielded is always a lone record
    separator. After that, either a single ``!<error text>`` record is
    yielded, or a record holding the length of the base64 encoded archive
    followed by the encoded archive itself.

    The password file named by ``pwfile`` is deleted as soon as it has been
    opened, so a given file can only be used for a single request.
    """
    record_sep = b'\036'

    def fmt_prefix(msg: Any) -> bytes:
        return str(msg).encode('ascii') + record_sep

    yield record_sep  # to discard leading data
    try:
        msg = standard_b64decode(msg).decode('utf-8')
        md = dict(x.split('=', 1) for x in msg.split(':'))
        hostname = md['hostname']
        pw = md['pw']
        pwfilename = md['pwfile']
        username = md['user']
        rq_id = md['id']
        compression = md['compression']
        # pwfile comes from the request and is joined onto the cache
        # directory and unlinked below before any other check is made, so it
        # must be a plain file name: no directory components, no absolute
        # path and no "." or ".."
        if pwfilename in ('', '.', '..') or os.path.basename(pwfilename) != pwfilename:
            raise ValueError('pwfile must be a plain file name')
    except Exception:
        traceback.print_exc()
        yield fmt_prefix('!invalid ssh data request message')
        return

    try:
        with open(os.path.join(cache_dir(), 'ssh', pwfilename), 'rb') as f:
            os.unlink(f.name)
            st = os.stat(f.fileno())
            mode = stat.S_IMODE(st.st_mode)
            # The file must be readable by its owner and nothing else
            if mode != stat.S_IREAD:
                raise ValueError('Incorrect permissions on pwfile')
            env_data = json.load(f)
        if pw != env_data['pw']:
            raise ValueError('Incorrect password')
        if rq_id != request_id:
            raise ValueError('Incorrect request id')
        cli_hostname = env_data['cli_hostname']
        cli_uname = env_data['cli_uname']
    except Exception as e:
        traceback.print_exc()
        yield fmt_prefix(f'!{e}')
        return

    ssh_opts = {k: SSHOptions(v) for k, v in env_data['opts'].items()}
    resolved_ssh_opts = options_for_host(hostname, username, ssh_opts, cli_hostname, cli_uname)
    resolved_ssh_opts.copy = {k: CopyInstruction(*v) for k, v in resolved_ssh_opts.copy.items()}
    try:
        data = make_tarfile(resolved_ssh_opts, env_data['env'], compression)
    except Exception:
        traceback.print_exc()
        yield fmt_prefix('!error while gathering ssh data')
        return

    encoded_data = standard_b64encode(data)
    yield fmt_prefix(len(encoded_data))
    yield encoded_data
def safe_remove(x: str) -> None:
    """Delete the file at path *x*, ignoring any OSError raised by the removal."""
    try:
        os.remove(x)
    except OSError:
        pass
def prepare_script(ans: str, replacements: Dict[str, str]) -> str:
    """Substitute placeholder words in the script text *ans*.

    Every key of *replacements* that occurs in *ans* as a whole word is
    replaced by its value. ``EXEC_CMD`` is always substituted: when the
    caller did not supply it, it is added to *replacements* (which is
    modified in place) with an empty value.
    """
    replacements.setdefault('EXEC_CMD', '')
    pattern = '|'.join(fr'\b{name}\b' for name in replacements)
    return re.sub(pattern, lambda match: replacements[match.group()], ans)
def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
    """Turn the remote command line into the EXEC_CMD value for the bootstrap script.

    ssh simply concatenates multiple command words using a space (see line
    1129 of ssh.c) and on the remote side sshd.c runs the concatenated
    command as ``shell -c cmd``, so the words are joined with single spaces
    here as well. For the python bootstrap the joined command is returned
    base64 encoded. Otherwise it is returned as an ``exec "$login_shell" -c``
    line with the command inside single quotes.
    """
    if is_python:
        joined = ' '.join(remote_args)
        return standard_b64encode(joined.encode('utf-8')).decode('ascii')
    # Close the single quoted string, add a double quoted ', then reopen it
    escaped = [word.replace("'", "'\"'\"'") for word in remote_args]
    return 'exec "$login_shell" -c \'' + ' '.join(escaped) + "'"
def bootstrap_script(
    script_type: str = 'sh', remote_args: Sequence[str] = (),
    ssh_opts_dict: Dict[str, Dict[str, Any]] = {},
    test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = ''
) -> Tuple[str, Dict[str, str]]:
    """Return the bootstrap script text and the replacements applied to it.

    The ``bootstrap.<script_type>`` template is read from the shell
    integration directory. A fresh random password, the current environment,
    *ssh_opts_dict* and the command line host/user names are written as JSON
    to a new owner-read-only file in the ``ssh`` cache directory; that file
    is registered for removal at interpreter exit. When *request_id* is None
    it is built from the ``KITTY_PID`` and ``KITTY_WINDOW_ID`` environment
    variables.
    """
    if request_id is None:
        request_id = f"{os.environ['KITTY_PID']}-{os.environ['KITTY_WINDOW_ID']}"
    exec_cmd = ''
    if remote_args:
        exec_cmd = prepare_exec_cmd(remote_args, script_type == 'py')
    template_path = os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')
    with open(template_path) as src:
        template = src.read()
    pw = secrets.token_hex()
    ddir = os.path.join(cache_dir(), 'ssh')
    os.makedirs(ddir, exist_ok=True)
    with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', suffix='.json', dir=ddir, delete=False) as tf:
        payload = {
            'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict,
            'cli_hostname': cli_hostname, 'cli_uname': cli_uname,
        }
        tf.write(json.dumps(payload).encode('utf-8'))
        # Leave the file readable by its owner and nothing else
        os.fchmod(tf.fileno(), stat.S_IREAD)
    atexit.register(safe_remove, tf.name)
    replacements = {
        'DATA_PASSWORD': pw,
        'PASSWORD_FILENAME': os.path.basename(tf.name),
        'EXEC_CMD': exec_cmd,
        'TEST_SCRIPT': test_script,
        'REQUEST_ID': request_id,
    }
    return prepare_script(template, replacements), replacements
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
    """Return ``(boolean_ssh_args, other_ssh_args)`` built from ``ssh_options()``.

    Every option name gets a leading ``-``. Options whose entry in
    ``ssh_options()`` is truthy go into the second set, all others into the
    first.
    """
    with_value: Set[str] = set()
    flags: Set[str] = set()
    for name, spec in ssh_options().items():
        (with_value if spec else flags).add(f'-{name}')
    return flags, with_value
def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
    """Return the first entry of *extra_args* that *arg* names, or ``''``.

    *arg* names an entry when it is equal to it or starts with the entry
    followed by ``=``.
    """
    return next((name for name in extra_args if arg == name or arg.startswith(f'{name}=')), '')
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
    """Extract the connection details from an ssh command line.

    *args* is scanned for the first word whose basename is ``ssh`` or
    ``ssh.exe``; everything before it is ignored. From the words after it the
    host name, the port (``-p``), the identity file (``-i``) and the values
    of any options listed in *extra_args* are collected.

    Returns None when no host name was found. A relative identity file is
    made absolute, first via ``~`` expansion and then relative to *cwd*, or
    to the current directory when *cwd* is empty.
    """
    boolean_ssh_args, _ = get_ssh_cli()
    port: Optional[int] = None
    expecting_port = expecting_identity = False
    expecting_option_val = False
    expecting_hostname = False
    expecting_extra_val = ''
    host_name = identity_file = found_ssh = ''
    found_extra_args: List[Tuple[str, str]] = []

    for arg in args:
        if not found_ssh:
            if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
                found_ssh = arg
            continue
        if expecting_hostname:
            # Only the first word after -- can be the host, everything after
            # it belongs to the remote command and must not replace it
            if not host_name:
                host_name = arg
            break
        if arg.startswith('-') and not expecting_option_val:
            if arg in boolean_ssh_args:
                continue
            if arg == '--':
                expecting_hostname = True
                continue
            if arg.startswith('-p'):
                if arg[2:].isdigit():
                    # isdigit() accepts some characters int() rejects
                    with suppress(Exception):
                        port = int(arg[2:])
                    continue
                elif arg == '-p':
                    expecting_port = True
            elif arg.startswith('-i'):
                if arg == '-i':
                    expecting_identity = True
                else:
                    identity_file = arg[2:]
                    continue
            if arg.startswith('--') and extra_args:
                matching_ex = is_extra_arg(arg, extra_args)
                if matching_ex:
                    if '=' in arg:
                        exval = arg.partition('=')[-1]
                        found_extra_args.append((matching_ex, exval))
                        continue
                    expecting_extra_val = matching_ex
            # Any other option is assumed to take its value from the next word
            expecting_option_val = True
            continue

        if expecting_option_val:
            # Each expecting_* flag applies to exactly one value, so it is
            # cleared once that value is consumed. Otherwise the value of a
            # later, unrelated option would be recorded under it.
            if expecting_port:
                with suppress(Exception):
                    port = int(arg)
                expecting_port = False
            elif expecting_identity:
                identity_file = arg
                expecting_identity = False
            elif expecting_extra_val:
                found_extra_args.append((expecting_extra_val, arg))
                expecting_extra_val = ''
            expecting_option_val = False
            continue

        if not host_name:
            host_name = arg

    if not host_name:
        return None
    if identity_file:
        if not os.path.isabs(identity_file):
            identity_file = os.path.expanduser(identity_file)
        if not os.path.isabs(identity_file):
            identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))

    return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))
class InvalidSSHArgs(ValueError):
    """Raised when the ssh command line cannot be parsed."""

    def __init__(self, msg: str = ''):
        self.err_msg = msg  # text printed by system_exit(), may be empty
        super().__init__(msg)

    def system_exit(self) -> None:
        """Print the error message, if any, to stderr and then exec a bare ``ssh``."""
        if self.err_msg:
            print(self.err_msg, file=sys.stderr)
        os.execlp('ssh', 'ssh')
def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]:
    """Split an ssh command line into its parts.

    Returns ``(ssh_args, server_args, passthrough, found_extra_args)``:

    * ``ssh_args``: the ssh options, with clustered single letter options
      split up into individual entries
    * ``server_args``: the host name followed by the remote command words
    * ``passthrough``: True when one of ``-N``, ``-n``, ``-f`` or ``-G`` was given
    * ``found_extra_args``: a flat tuple of alternating name and value for
      every option from *extra_args* that was present

    Raises InvalidSSHArgs for an unknown option or when no host name is
    present.
    """
    boolean_ssh_args, other_ssh_args = get_ssh_cli()
    passthrough_args = {f'-{x}' for x in 'NnfG'}
    ssh_args = []
    server_args: List[str] = []
    expecting_option_val = False
    passthrough = False
    stop_option_processing = False
    found_extra_args: List[str] = []
    expecting_extra_val = ''
    for argument in args:
        # Once the first word of the remote command has been seen, or after
        # a --, nothing is interpreted as an ssh option any more
        if len(server_args) > 1 or stop_option_processing:
            server_args.append(argument)
            continue
        if argument.startswith('-') and not expecting_option_val:
            if argument == '--':
                stop_option_processing = True
                continue
            if extra_args:
                matching_ex = is_extra_arg(argument, extra_args)
                if matching_ex:
                    if '=' in argument:
                        exval = argument.partition('=')[-1]
                        found_extra_args.extend((matching_ex, exval))
                    else:
                        expecting_extra_val = matching_ex
                        expecting_option_val = True
                    continue
            # could be a multi-character option
            all_args = argument[1:]
            for i, arg in enumerate(all_args):
                arg = f'-{arg}'
                if arg in passthrough_args:
                    passthrough = True
                if arg in boolean_ssh_args:
                    ssh_args.append(arg)
                    continue
                if arg in other_ssh_args:
                    ssh_args.append(arg)
                    # The value is either the remainder of this word or the
                    # whole of the next one
                    rest = all_args[i+1:]
                    if rest:
                        ssh_args.append(rest)
                    else:
                        expecting_option_val = True
                    break
                raise InvalidSSHArgs(f'unknown option -- {arg[1:]}')
            continue
        if expecting_option_val:
            if expecting_extra_val:
                found_extra_args.extend((expecting_extra_val, argument))
                # The extra option has received its value. Forget it so that
                # the value of a later ssh option goes to ssh_args instead.
                expecting_extra_val = ''
            else:
                ssh_args.append(argument)
            expecting_option_val = False
            continue
        server_args.append(argument)
    if not server_args:
        raise InvalidSSHArgs()
    return ssh_args, server_args, passthrough, tuple(found_extra_args)
def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]:
    """Return the remote command words that run *sh_script* with *interpreter*.

    sshd executes the command it is given by joining all command line
    arguments with a space and passing the result as a single argument to
    the user's login shell with -c. A non POSIX login shell may have
    different escaping semantics and syntax, so the command has to be as
    simple as possible, basically of the form::

        interpreter -c unwrap_script escaped_bootstrap_script

    The unwrap_script is responsible for unescaping the bootstrap script and
    executing it. An interpreter whose base name contains ``python`` gets a
    base64 encoded script, any other interpreter gets the single quoted,
    character-substituted form described below.
    """
    if 'python' in os.path.basename(interpreter).lower():
        es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii')
        unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"'''
    else:
        # We can't rely on base64 being available on the remote system, so
        # instead the script is quoted by replacing ' and \ with \v and \f,
        # replacing \n and ! with \r and \b (for tcsh) and finally
        # surrounding the result with '
        substitutions = str.maketrans({"'": '\v', '\\': '\f', '\n': '\r', '!': '\b'})
        es = "'" + sh_script.translate(substitutions) + "'"
        unwrap_script = r"""'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' """
    # exec is supported by all sh like shells, and fish and csh
    return ['exec', interpreter, '-c', unwrap_script, es]
def get_remote_command(
    remote_args: List[str], hostname: str = 'localhost', cli_hostname: str = '', cli_uname: str = '',
    interpreter: str = 'sh',
    ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
) -> Tuple[List[str], Dict[str, str]]:
    """Return the wrapped bootstrap command for the remote side and its replacements.

    The python bootstrap script is used when the base name of *interpreter*
    contains ``python``, the sh one otherwise. *hostname* is accepted but not
    used by this function.
    """
    uses_python = 'python' in os.path.basename(interpreter).lower()
    script, replacements = bootstrap_script(
        script_type='py' if uses_python else 'sh',
        remote_args=remote_args,
        ssh_opts_dict=ssh_opts_dict,
        cli_hostname=cli_hostname,
        cli_uname=cli_uname,
    )
    return wrap_bootstrap_script(script, interpreter), replacements
def connection_sharing_args(opts: SSHOptions, kitty_pid: int) -> List[str]:
    """Return the ``-o`` arguments that turn on ssh connection sharing.

    The control socket path lives in the ``ssh`` cache directory and has
    *kitty_pid* in its name. *opts* is accepted but not used.
    """
    control_path = os.path.join(cache_dir(), 'ssh', f'{kitty_pid}-master-%C')
    settings = ('ControlMaster=auto', f'ControlPath={control_path}', 'ControlPersist=yes')
    ans: List[str] = []
    for setting in settings:
        ans += ['-o', setting]
    return ans
def main(args: List[str]) -> NoReturn:
    """Entry point of the ssh kitten: build the real ssh command line and run it.

    *args* is the kitten's argv, whose first element is dropped. When a
    passthrough option was given, or ``KITTY_WINDOW_ID`` is not set in the
    environment, ssh is run with the arguments as parsed. Otherwise the
    remote command is replaced by the bootstrap script and, if enabled for
    the host, connection sharing options are added.

    Always exits via SystemExit, with the return code of ssh, with 1 on
    KeyboardInterrupt, or with an error message when ssh cannot be found.
    """
    args = args[1:]
    if args and args[0] == 'use-python':
        args = args[1:]  # backwards compat from when we had a python implementation
    try:
        ssh_args, server_args, passthrough, found_extra_args = parse_ssh_args(args, extra_args=('--kitten',))
    except InvalidSSHArgs as e:
        e.system_exit()
    if not os.environ.get('KITTY_WINDOW_ID'):
        passthrough = True
    cmd = ['ssh'] + ssh_args
    if passthrough:
        cmd += server_args
    else:
        hostname, remote_args = server_args[0], server_args[1:]
        if not remote_args:
            cmd.append('-t')
        # Connection sharing options, if any, are spliced in just before --
        insertion_point = len(cmd)
        cmd.append('--')
        cmd.append(hostname)
        uname = getuser()
        if hostname.startswith('ssh://'):
            from urllib.parse import urlparse
            purl = urlparse(hostname)
            # Use the host parsed out of the URL. Splitting the raw URL on
            # @ and : would give "ssh" for a URL without a user name.
            hostname_for_match = purl.hostname or hostname
            uname = purl.username or uname
        elif '@' in hostname and hostname[0] != '@':
            uname, hostname_for_match = hostname.split('@', 1)
            hostname_for_match = hostname_for_match.split(':', 1)[0]
        else:
            hostname_for_match = hostname.split('@', 1)[-1].split(':', 1)[0]
        # found_extra_args alternates option name and value, only the values
        # are config overrides. "key = val" and "key=val" become "key val".
        pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=')
        overrides = [pat.sub(r'\1 ', val.lstrip()) for val in found_extra_args[1::2]]
        if overrides:
            overrides.insert(0, f'hostname {uname}@{hostname_for_match}')
        so = init_config(overrides)
        sod = {k: v._asdict() for k, v in so.items()}
        host_opts = options_for_host(hostname_for_match, uname, so)
        use_control_master = 'KITTY_PID' in os.environ and host_opts.share_connections
        rcmd, replacements = get_remote_command(remote_args, hostname, hostname_for_match, uname, host_opts.interpreter, sod)
        cmd += rcmd
        if use_control_master:
            cmd[insertion_point:insertion_point] = connection_sharing_args(host_opts, int(os.environ['KITTY_PID']))
    import subprocess
    with suppress(FileNotFoundError):
        try:
            raise SystemExit(subprocess.run(cmd).returncode)
        except KeyboardInterrupt:
            raise SystemExit(1)
    # Only reached when running ssh raised FileNotFoundError
    raise SystemExit('Could not find the ssh executable, is it in your PATH?')
# What happens at import time depends on the name this module is loaded under
if __name__ == '__main__':
    # Run as a program
    main(sys.argv)
elif __name__ == '__completer__':
    # Publish the completion function as an attribute of sys
    setattr(sys, 'kitten_completer', complete)
elif __name__ == '__conf__':
    # Publish the options definition as an attribute of sys
    from .options.definition import definition
    setattr(sys, 'options_definition', definition)