diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 7ef2118ad..9378611d9 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -17,7 +17,8 @@ from base64 import standard_b64decode, standard_b64encode from contextlib import suppress from getpass import getuser from typing import ( - Any, Callable, Dict, Iterator, List, NoReturn, Optional, Set, Tuple, Union + Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, + Tuple, Union ) from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir @@ -156,7 +157,6 @@ def get_ssh_data(msg: str) -> Iterator[bytes]: traceback.print_exc() yield fmt_prefix('!error while gathering ssh data') else: - from base64 import standard_b64encode encoded_data = standard_b64encode(data) yield fmt_prefix(len(encoded_data)) yield encoded_data @@ -177,7 +177,18 @@ def prepare_script(ans: str, replacements: Dict[str, str]) -> str: return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans) -def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str: +def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str: + # ssh simply concatenates multiple commands using a space see + # line 1129 of ssh.c and on the remote side sshd.c runs the + # concatenated command as shell -c cmd + if is_python: + return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii') + args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args) + return f"""exec "$login_shell" -c '{args}'""" + + +def bootstrap_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}, test_script: str = '') -> str: + exec_cmd = prepare_exec_cmd(remote_args, script_type == 'py') if remote_args else '' with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f: ans = f.read() pw = uuid4() @@ -185,12 +196,12 @@ def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict} tf.write(json.dumps(data).encode('utf-8')) atexit.register(safe_remove, tf.name) - replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd} + replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script} return prepare_script(ans, replacements) -def load_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str: - return bootstrap_script(script_type, exec_cmd, ssh_opts_dict=ssh_opts_dict) +def load_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str: + return bootstrap_script(script_type, remote_args, ssh_opts_dict=ssh_opts_dict) def get_ssh_cli() -> Tuple[Set[str], Set[str]]: @@ -361,18 +372,8 @@ def get_remote_command( remote_args: List[str], hostname: str = 'localhost', interpreter: str = 'sh', ssh_opts_dict: Dict[str, Dict[str, Any]] = {} ) -> List[str]: - command_to_execute = '' is_python = 'python' in interpreter.lower() - if remote_args: - # ssh simply concatenates multiple commands using a space see - # line 1129 of ssh.c and on the remote side sshd.c runs the - # concatenated command as shell -c cmd - if is_python: - command_to_execute = standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii') - else: - args = [c.replace("'", """'"'"'""") for c in remote_args] - command_to_execute = "exec \"$login_shell\" -c '{}'".format(' '.join(args)) - sh_script = load_script(script_type='py' if is_python else 'sh', exec_cmd=command_to_execute, ssh_opts_dict=ssh_opts_dict) + sh_script = load_script(script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict) return [f'{interpreter} -c {shlex.quote(sh_script)}'] diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index c53f34047..4a5dacba8 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -86,7 +86,8 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) @property @lru_cache() def all_possible_sh(self): - return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh'))) + python = 'python3' if shutil.which('python3') else 'python' + return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh', python))) def test_ssh_copy(self): simple_data = 'rkjlhfwf9whoaa' @@ -111,7 +112,7 @@ copy --exclude */w.* d1 ''' copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy self.check_bootstrap( - sh, remote_home, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='', + sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', ssh_opts={'copy': copy} ) tname = '.terminfo' @@ -141,7 +142,7 @@ copy --exclude */w.* d1 for sh in self.all_possible_sh: with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: pty = self.check_bootstrap( - sh, tdir, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='', + sh, tdir, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', ssh_opts={'env': { 'TSET': 'set-works', 'COLORTERM': DELETE_ENV_VAR, @@ -151,10 +152,13 @@ copy --exclude */w.* d1 self.assertNotIn('COLORTERM', pty.screen_contents()) def test_ssh_leading_data(self): + script = 'echo "ld:$leading_data"; exit 0' for sh in self.all_possible_sh: + if 'python' in sh: + script = 'print("ld:" + leading_data.decode("ascii")); raise SystemExit(0);' with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: pty = self.check_bootstrap( - sh, tdir, extra_exec='echo "ld:$leading_data"; exit 0', + sh, tdir, test_script=script, SHELL_INTEGRATION_VALUE='', pre_data='before_tarfile') self.ae(pty.screen_contents(), 'UNTAR_DONE\nld:before_tarfile') @@ -174,8 +178,10 @@ copy --exclude */w.* d1 expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell for m in methods: for sh in self.all_possible_sh: + if 'python' in sh: + continue with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir: - pty = self.check_bootstrap(sh, tdir, extra_exec=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='') + pty = self.check_bootstrap(sh, tdir, test_script=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='') self.assertIn(expected_login_shell, pty.screen_contents()) def test_ssh_shell_integration(self): @@ -205,12 +211,19 @@ copy --exclude */w.* d1 pty.wait_till(lambda: len(pty.screen_contents().splitlines()) >= num_lines + 2) self.assertEqual(pty.screen.cursor.shape, 0) - def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', extra_exec='', pre_data='', ssh_opts=None): + def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None): ssh_opts = ssh_opts or {} if login_shell: ssh_opts['login_shell'] = login_shell + if 'python' in sh: + if test_script.startswith('env;'): + test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})' + test_script = f'print("UNTAR_DONE", flush=True); {test_script}' + else: + test_script = f'echo "UNTAR_DONE"; {test_script}' script = bootstrap_script( - exec_cmd=f'echo "UNTAR_DONE"; {extra_exec}', ssh_opts_dict={'*': ssh_opts}, + script_type='py' if 'python' in sh else 'sh', + test_script=test_script, ssh_opts_dict={'*': ssh_opts}, ) env = basic_shell_env(home_dir) # Avoid generating unneeded completion scripts diff --git a/shell-integration/ssh/bootstrap.py b/shell-integration/ssh/bootstrap.py index c8923e6ae..3abfef1ee 100644 --- a/shell-integration/ssh/bootstrap.py +++ b/shell-integration/ssh/bootstrap.py @@ -20,6 +20,7 @@ import tty tty_fd = -1 original_termios_state = None data_dir = shell_integration_dir = '' +leading_data = b'' HOME = os.path.expanduser('~') login_shell = pwd.getpwuid(os.geteuid()).pw_shell or 'sh' @@ -108,7 +109,7 @@ def compile_terminfo(base): def get_data(): - global data_dir, shell_integration_dir + global data_dir, shell_integration_dir, leading_data data = b'' while data.count(b'\036') < 2: @@ -219,6 +220,7 @@ def main(): if exec_cmd: cmd = base64.standard_b64decode(exec_cmd).decode('utf-8') os.execlp(login_shell, os.path.basename(login_shell), '-c', cmd) + TEST_SCRIPT # noqa if ksi and 'no-rc' not in ksi: exec_with_shell_integration() os.environ.pop('KITTY_SHELL_INTEGRATION', None) diff --git a/shell-integration/ssh/bootstrap.sh b/shell-integration/ssh/bootstrap.sh index 9f58f859f..e6500be0b 100644 --- a/shell-integration/ssh/bootstrap.sh +++ b/shell-integration/ssh/bootstrap.sh @@ -324,6 +324,9 @@ exec_with_shell_integration() { esac } +# Used in the tests +TEST_SCRIPT + case "$KITTY_SHELL_INTEGRATION" in ("") # only blanks or unset