When testing ssh kitten launch the bootscrapt script the same way sshd does it

This commit is contained in:
Kovid Goyal 2022-03-09 11:25:02 +05:30
parent 2341a27f63
commit 53b1607c4d
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 22 additions and 14 deletions

View File

@ -391,16 +391,7 @@ def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[L
return ssh_args, server_args, passthrough, tuple(found_extra_args) return ssh_args, server_args, passthrough, tuple(found_extra_args)
def get_remote_command( def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]:
remote_args: List[str], hostname: str = 'localhost', cli_hostname: str = '', cli_uname: str = '',
interpreter: str = 'sh',
ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
) -> List[str]:
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
sh_script = bootstrap_script(
script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict,
cli_hostname=cli_hostname, cli_uname=cli_uname)
# sshd will execute the command we pass it by join all command line # sshd will execute the command we pass it by join all command line
# arguments with a space and passing it as a single argument to the users # arguments with a space and passing it as a single argument to the users
# login shell with -c. If the user has a non POSIX login shell it might # login shell with -c. If the user has a non POSIX login shell it might
@ -409,6 +400,8 @@ def get_remote_command(
# interpreter -c unwrap_script escaped_bootstrap_script # interpreter -c unwrap_script escaped_bootstrap_script
# The unwrap_script is responsible for unescaping the bootstrap script and # The unwrap_script is responsible for unescaping the bootstrap script and
# executing it. # executing it.
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
if is_python: if is_python:
es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii') es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii')
unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"''' unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"'''
@ -421,6 +414,19 @@ def get_remote_command(
return [interpreter, '-c', unwrap_script, es] return [interpreter, '-c', unwrap_script, es]
def get_remote_command(
remote_args: List[str], hostname: str = 'localhost', cli_hostname: str = '', cli_uname: str = '',
interpreter: str = 'sh',
ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
) -> List[str]:
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
sh_script = bootstrap_script(
script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict,
cli_hostname=cli_hostname, cli_uname=cli_uname)
return wrap_bootstrap_script(sh_script, interpreter)
def main(args: List[str]) -> NoReturn: def main(args: List[str]) -> NoReturn:
args = args[1:] args = args[1:]
if args and args[0] == 'use-python': if args and args[0] == 'use-python':

View File

@ -4,13 +4,14 @@
import glob import glob
import os import os
import shlex
import shutil import shutil
import tempfile import tempfile
from functools import lru_cache from functools import lru_cache
from kittens.ssh.config import load_config, options_for_host from kittens.ssh.config import load_config, options_for_host
from kittens.ssh.main import bootstrap_script, get_connection_data from kittens.ssh.main import (
bootstrap_script, get_connection_data, wrap_bootstrap_script
)
from kittens.ssh.options.utils import DELETE_ENV_VAR from kittens.ssh.options.utils import DELETE_ENV_VAR
from kittens.transfer.utils import set_paths from kittens.transfer.utils import set_paths
from kitty.constants import is_macos from kitty.constants import is_macos
@ -217,7 +218,7 @@ copy --exclude */w.* d1
self.assertEqual(pty.screen.cursor.shape, 0) self.assertEqual(pty.screen.cursor.shape, 0)
self.assertNotIn(b'\x1b]133;', pty.received_bytes) self.assertNotIn(b'\x1b]133;', pty.received_bytes)
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None): def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None, launcher='sh'):
ssh_opts = ssh_opts or {} ssh_opts = ssh_opts or {}
if login_shell: if login_shell:
ssh_opts['login_shell'] = login_shell ssh_opts['login_shell'] = login_shell
@ -237,7 +238,8 @@ copy --exclude */w.* d1
# prevent newuser-install from running # prevent newuser-install from running
open(os.path.join(home_dir, '.zshrc'), 'w').close() open(os.path.join(home_dir, '.zshrc'), 'w').close()
options = {'shell_integration': shell_integration(SHELL_INTEGRATION_VALUE or 'disabled')} options = {'shell_integration': shell_integration(SHELL_INTEGRATION_VALUE or 'disabled')}
pty = self.create_pty(f'{sh} -c {shlex.quote(script)}', cwd=home_dir, env=env, options=options) cmd = wrap_bootstrap_script(script, sh)
pty = self.create_pty([launcher, '-c', ' '.join(cmd)], cwd=home_dir, env=env, options=options)
if pre_data: if pre_data:
pty.write_buf = pre_data.encode('utf-8') pty.write_buf = pre_data.encode('utf-8')
del script del script