Also transfer env vars when cloning over ssh kitten

This commit is contained in:
Kovid Goyal 2022-04-13 20:08:06 +05:30
parent eb024fa40a
commit a1bfcd9fc5
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 89 additions and 47 deletions

View File

@ -24,8 +24,8 @@ from contextlib import contextmanager, suppress
from getpass import getuser
from select import select
from typing import (
Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, Tuple,
Union
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set,
Tuple, Union, cast
)
from kittens.tui.operations import restore_colors, save_colors
@ -36,7 +36,10 @@ from kitty.constants import (
from kitty.options.types import Options
from kitty.shm import SharedMemory
from kitty.types import run_once
from kitty.utils import SSHConnectionData, set_echo as turn_off_echo, expandvars, resolve_abs_or_config_path
from kitty.utils import (
SSHConnectionData, expandvars, resolve_abs_or_config_path,
set_echo as turn_off_echo
)
from .completion import complete, ssh_options
from .config import init_config
@ -58,16 +61,45 @@ def is_kitten_cmdline(q: List[str]) -> bool:
return q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']
def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None:
def patch_cmdline(key: str, val: str, argv: List[str]) -> None:
for i, arg in enumerate(tuple(argv)):
if arg.startswith('--kitten=cwd'):
argv[i] = f'--kitten=cwd={cwd}'
if arg.startswith(f'--kitten={key}='):
argv[i] = f'--kitten={key}={val}'
return
elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith('cwd=') or arg.startswith('cwd ')):
argv[i] = cwd
elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith(f'{key}=') or arg.startswith(f'{key} ')):
argv[i] = val
return
idx = argv.index('ssh')
argv.insert(idx + 1, f'--kitten=cwd={cwd}')
argv.insert(idx + 1, f'--kitten={key}={val}')
def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None:
patch_cmdline('cwd', cwd, argv)
def create_shared_memory(data: Any, prefix: str) -> str:
from kitty.shm import SharedMemory
db = json.dumps(data).encode('utf-8')
with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, mode=stat.S_IREAD, prefix=prefix) as shm:
shm.write_data_with_size(db)
shm.flush()
atexit.register(shm.unlink)
return shm.name
def read_data_from_shared_memory(shm_name: str) -> Any:
with SharedMemory(shm_name, readonly=True) as shm:
shm.unlink()
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD:
raise ValueError('Incorrect permissions on pwfile')
return json.loads(shm.read_data_with_size())
def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None:
patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
# See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
@ -105,7 +137,7 @@ def kitty_opts() -> Options:
return create_default_opts()
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz') -> bytes:
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz', literal_env: Dict[str, str] = {}) -> bytes:
def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo:
tarinfo.uname = tarinfo.gname = ''
@ -146,11 +178,12 @@ def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: st
'TERM': os.environ.get('TERM') or kitty_opts().term,
'COLORTERM': 'truecolor',
}
env.update(literal_env)
env.update(ssh_opts.env)
for q in ('KITTY_WINDOW_ID', 'WINDOWID'):
val = os.environ.get(q)
if val is not None:
env[q] = val
env.update(ssh_opts.env)
env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR
env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir
if ssh_opts.login_shell:
@ -196,18 +229,11 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
yield b'invalid ssh data request message\n'
else:
try:
with SharedMemory(pwfilename, readonly=True) as shm:
shm.unlink()
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD:
raise ValueError('Incorrect permissions on pwfile')
env_data = json.loads(shm.read_data_with_size())
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
raise ValueError('Incorrect request id')
env_data = read_data_from_shared_memory(pwfilename)
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
raise ValueError('Incorrect request id')
except Exception as e:
traceback.print_exc()
yield f'{e}\n'.encode('utf-8')
@ -267,8 +293,8 @@ def prepare_export_home_cmd(ssh_opts: SSHOptions, is_python: bool) -> str:
def bootstrap_script(
ssh_opts: SSHOptions, script_type: str = 'sh', remote_args: Sequence[str] = (),
test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = '',
request_data: bool = False, echo_on: bool = True
) -> Tuple[str, Dict[str, str], SharedMemory]:
request_data: bool = False, echo_on: bool = True, literal_env: Dict[str, str] = {}
) -> Tuple[str, Dict[str, str], str]:
if request_id is None:
request_id = os.environ['KITTY_PID'] + '-' + os.environ['KITTY_WINDOW_ID']
is_python = script_type == 'py'
@ -277,14 +303,10 @@ def bootstrap_script(
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
ans = f.read()
pw = secrets.token_hex()
tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2')).decode('ascii')
tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2', literal_env=literal_env)).decode('ascii')
data = {'pw': pw, 'opts': ssh_opts._asdict(), 'hostname': cli_hostname, 'uname': cli_uname, 'tarfile': tfd}
db = json.dumps(data)
with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, mode=stat.S_IREAD, prefix=f'kssh-{os.getpid()}-') as shm:
shm.write_data_with_size(db)
shm.flush()
atexit.register(shm.unlink)
sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm.name}
shm_name = create_shared_memory(data, prefix=f'kssh-{os.getpid()}-')
sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm_name}
replacements = {
'EXPORT_HOME_CMD': export_home_cmd,
'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
@ -294,7 +316,7 @@ def bootstrap_script(
if request_data:
sd.update(sensitive_data)
replacements.update(sensitive_data)
return prepare_script(ans, sd, script_type), replacements, shm
return prepare_script(ans, sd, script_type), replacements, shm_name
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
@ -499,15 +521,15 @@ def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]:
def get_remote_command(
remote_args: List[str], ssh_opts: SSHOptions, cli_hostname: str = '', cli_uname: str = '',
echo_on: bool = True, request_data: bool = False
echo_on: bool = True, request_data: bool = False, literal_env: Dict[str, str] = {}
) -> Tuple[List[str], Dict[str, str], str]:
interpreter = ssh_opts.interpreter
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
sh_script, replacements, shm = bootstrap_script(
ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args,
sh_script, replacements, shm_name = bootstrap_script(
ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args, literal_env=literal_env,
cli_hostname=cli_hostname, cli_uname=cli_uname, echo_on=echo_on, request_data=request_data)
return wrap_bootstrap_script(sh_script, interpreter), replacements, shm.name
return wrap_bootstrap_script(sh_script, interpreter), replacements, shm_name
def connection_sharing_args(opts: SSHOptions, kitty_pid: int) -> List[str]:
@ -602,7 +624,9 @@ def drain_potential_tty_garbage(p: 'subprocess.Popen[bytes]', data_request: str)
def change_colors(color_scheme: str) -> bool:
if not color_scheme:
return False
from kittens.themes.collection import load_themes, NoCacheFound, text_as_opts
from kittens.themes.collection import (
NoCacheFound, load_themes, text_as_opts
)
from kittens.themes.main import colors_as_escape_codes
if color_scheme.endswith('.conf'):
conf_file = resolve_abs_or_config_path(color_scheme)
@ -627,6 +651,14 @@ def change_colors(color_scheme: str) -> bool:
return True
def add_cloned_env(shm_name: str) -> Dict[str, str]:
try:
return cast(Dict[str, str], read_data_from_shared_memory(shm_name))
except FileNotFoundError:
pass
return {}
def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple[str, ...]) -> NoReturn:
cmd = [ssh_exe()] + ssh_args
hostname, remote_args = server_args[0], server_args[1:]
@ -646,12 +678,16 @@ def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple
else:
hostname_for_match = hostname
hostname_for_match = hostname_for_match.split('@', 1)[-1].split(':', 1)[0]
overrides = []
overrides: List[str] = []
literal_env: Dict[str, str] = {}
pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=')
for i, a in enumerate(found_extra_args):
if i % 2 == 1:
aq = pat.sub(r'\1 ', a.lstrip())
if aq.split(maxsplit=1)[0] != 'hostname':
key = aq.split(maxsplit=1)[0]
if key == 'clone_env':
literal_env = add_cloned_env(aq.split(maxsplit=1)[1])
elif key != 'hostname':
overrides.append(aq)
if overrides:
overrides.insert(0, f'hostname {uname}@{hostname_for_match}')
@ -677,7 +713,7 @@ def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple
need_to_request_data = False
with restore_terminal_state() as echo_on:
rcmd, replacements, shm_name = get_remote_command(
remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data)
remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data, literal_env=literal_env)
cmd += rcmd
colors_changed = change_colors(host_opts.color_scheme)
try:

View File

@ -361,7 +361,7 @@ def launch(
if opts.os_window_title == 'current':
tm = boss.active_tab_manager
opts.os_window_title = get_os_window_title(tm.os_window_id) if tm else None
if base_env:
if base_env is not None:
env = base_env.copy()
env.update(get_env(opts))
else:
@ -501,7 +501,7 @@ def clone_and_launch(msg: str, window: Window) -> None:
from .child import cmdline_of_process
args = []
env: Dict[str, str] = {}
env: Optional[Dict[str, str]] = None
cwd = ''
pid = -1
@ -514,6 +514,7 @@ def clone_and_launch(msg: str, window: Window) -> None:
if k == 'a':
args.append(v)
elif k == 'env':
env = {}
for line in v.split('\0'):
if line:
try:
@ -535,9 +536,12 @@ def clone_and_launch(msg: str, window: Window) -> None:
cmdline = list(window.child.argv)
ssh_kitten_cmdline = window.ssh_kitten_cmdline()
if ssh_kitten_cmdline:
from kittens.ssh.main import set_cwd_in_cmdline
from kittens.ssh.main import set_cwd_in_cmdline, set_env_in_cmdline
cmdline[:] = ssh_kitten_cmdline
if opts.cwd:
set_cwd_in_cmdline(opts.cwd, cmdline)
opts.cwd = None
if env:
set_env_in_cmdline(env, cmdline)
env = None
launch(get_boss(), opts, cmdline, base_env=env)

View File

@ -7,6 +7,7 @@ import os
import shutil
import tempfile
from functools import lru_cache
from contextlib import suppress
from kittens.ssh.config import load_config
from kittens.ssh.main import (
@ -16,7 +17,7 @@ from kittens.ssh.options.types import Options as SSHOptions
from kittens.ssh.options.utils import DELETE_ENV_VAR
from kittens.transfer.utils import set_paths
from kitty.constants import is_macos, runtime_dir
from kitty.fast_data_types import CURSOR_BEAM
from kitty.fast_data_types import CURSOR_BEAM, shm_unlink
from kitty.utils import SSHConnectionData
from . import BaseTest
@ -243,7 +244,7 @@ copy --exclude */w.* d1
else:
test_script = f'echo "UNTAR_DONE"; {test_script}'
ssh_opts['shell_integration'] = SHELL_INTEGRATION_VALUE or 'disabled'
script, replacements, shm = bootstrap_script(
script, replacements, shm_name = bootstrap_script(
SSHOptions(ssh_opts), script_type='py' if 'python' in sh else 'sh', request_id="testing", test_script=test_script,
request_data=True
)
@ -275,4 +276,5 @@ copy --exclude */w.* d1
pty.wait_till(lambda: pty.screen.cursor.shape == CURSOR_BEAM)
return pty
finally:
shm.unlink()
with suppress(FileNotFoundError):
shm_unlink(shm_name)