bootstrap.py is now tested the same as bootsstrap.sh

This commit is contained in:
Kovid Goyal 2022-03-06 14:26:15 +05:30
parent ec782d3296
commit 7f9fec061a
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 44 additions and 25 deletions

View File

@ -17,7 +17,8 @@ from base64 import standard_b64decode, standard_b64encode
from contextlib import suppress from contextlib import suppress
from getpass import getuser from getpass import getuser
from typing import ( from typing import (
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Set, Tuple, Union Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set,
Tuple, Union
) )
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
@ -156,7 +157,6 @@ def get_ssh_data(msg: str) -> Iterator[bytes]:
traceback.print_exc() traceback.print_exc()
yield fmt_prefix('!error while gathering ssh data') yield fmt_prefix('!error while gathering ssh data')
else: else:
from base64 import standard_b64encode
encoded_data = standard_b64encode(data) encoded_data = standard_b64encode(data)
yield fmt_prefix(len(encoded_data)) yield fmt_prefix(len(encoded_data))
yield encoded_data yield encoded_data
@ -177,7 +177,18 @@ def prepare_script(ans: str, replacements: Dict[str, str]) -> str:
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans) return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str: def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
# ssh simply concatenates multiple commands using a space see
# line 1129 of ssh.c and on the remote side sshd.c runs the
# concatenated command as shell -c cmd
if is_python:
return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args)
return f"""exec "$login_shell" -c '{args}'"""
def bootstrap_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}, test_script: str = '') -> str:
exec_cmd = prepare_exec_cmd(remote_args, script_type == 'py') if remote_args else ''
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f: with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
ans = f.read() ans = f.read()
pw = uuid4() pw = uuid4()
@ -185,12 +196,12 @@ def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict:
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict} data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict}
tf.write(json.dumps(data).encode('utf-8')) tf.write(json.dumps(data).encode('utf-8'))
atexit.register(safe_remove, tf.name) atexit.register(safe_remove, tf.name)
replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd} replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script}
return prepare_script(ans, replacements) return prepare_script(ans, replacements)
def load_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str: def load_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
return bootstrap_script(script_type, exec_cmd, ssh_opts_dict=ssh_opts_dict) return bootstrap_script(script_type, remote_args, ssh_opts_dict=ssh_opts_dict)
def get_ssh_cli() -> Tuple[Set[str], Set[str]]: def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
@ -361,18 +372,8 @@ def get_remote_command(
remote_args: List[str], hostname: str = 'localhost', interpreter: str = 'sh', remote_args: List[str], hostname: str = 'localhost', interpreter: str = 'sh',
ssh_opts_dict: Dict[str, Dict[str, Any]] = {} ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
) -> List[str]: ) -> List[str]:
command_to_execute = ''
is_python = 'python' in interpreter.lower() is_python = 'python' in interpreter.lower()
if remote_args: sh_script = load_script(script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict)
# ssh simply concatenates multiple commands using a space see
# line 1129 of ssh.c and on the remote side sshd.c runs the
# concatenated command as shell -c cmd
if is_python:
command_to_execute = standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
else:
args = [c.replace("'", """'"'"'""") for c in remote_args]
command_to_execute = "exec \"$login_shell\" -c '{}'".format(' '.join(args))
sh_script = load_script(script_type='py' if is_python else 'sh', exec_cmd=command_to_execute, ssh_opts_dict=ssh_opts_dict)
return [f'{interpreter} -c {shlex.quote(sh_script)}'] return [f'{interpreter} -c {shlex.quote(sh_script)}']

View File

@ -86,7 +86,8 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
@property @property
@lru_cache() @lru_cache()
def all_possible_sh(self): def all_possible_sh(self):
return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh'))) python = 'python3' if shutil.which('python3') else 'python'
return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh', python)))
def test_ssh_copy(self): def test_ssh_copy(self):
simple_data = 'rkjlhfwf9whoaa' simple_data = 'rkjlhfwf9whoaa'
@ -111,7 +112,7 @@ copy --exclude */w.* d1
''' '''
copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy
self.check_bootstrap( self.check_bootstrap(
sh, remote_home, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='', sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
ssh_opts={'copy': copy} ssh_opts={'copy': copy}
) )
tname = '.terminfo' tname = '.terminfo'
@ -141,7 +142,7 @@ copy --exclude */w.* d1
for sh in self.all_possible_sh: for sh in self.all_possible_sh:
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
pty = self.check_bootstrap( pty = self.check_bootstrap(
sh, tdir, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='', sh, tdir, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
ssh_opts={'env': { ssh_opts={'env': {
'TSET': 'set-works', 'TSET': 'set-works',
'COLORTERM': DELETE_ENV_VAR, 'COLORTERM': DELETE_ENV_VAR,
@ -151,10 +152,13 @@ copy --exclude */w.* d1
self.assertNotIn('COLORTERM', pty.screen_contents()) self.assertNotIn('COLORTERM', pty.screen_contents())
def test_ssh_leading_data(self): def test_ssh_leading_data(self):
script = 'echo "ld:$leading_data"; exit 0'
for sh in self.all_possible_sh: for sh in self.all_possible_sh:
if 'python' in sh:
script = 'print("ld:" + leading_data.decode("ascii")); raise SystemExit(0);'
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
pty = self.check_bootstrap( pty = self.check_bootstrap(
sh, tdir, extra_exec='echo "ld:$leading_data"; exit 0', sh, tdir, test_script=script,
SHELL_INTEGRATION_VALUE='', pre_data='before_tarfile') SHELL_INTEGRATION_VALUE='', pre_data='before_tarfile')
self.ae(pty.screen_contents(), 'UNTAR_DONE\nld:before_tarfile') self.ae(pty.screen_contents(), 'UNTAR_DONE\nld:before_tarfile')
@ -174,8 +178,10 @@ copy --exclude */w.* d1
expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell
for m in methods: for m in methods:
for sh in self.all_possible_sh: for sh in self.all_possible_sh:
if 'python' in sh:
continue
with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir: with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir:
pty = self.check_bootstrap(sh, tdir, extra_exec=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='') pty = self.check_bootstrap(sh, tdir, test_script=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='')
self.assertIn(expected_login_shell, pty.screen_contents()) self.assertIn(expected_login_shell, pty.screen_contents())
def test_ssh_shell_integration(self): def test_ssh_shell_integration(self):
@ -205,12 +211,19 @@ copy --exclude */w.* d1
pty.wait_till(lambda: len(pty.screen_contents().splitlines()) >= num_lines + 2) pty.wait_till(lambda: len(pty.screen_contents().splitlines()) >= num_lines + 2)
self.assertEqual(pty.screen.cursor.shape, 0) self.assertEqual(pty.screen.cursor.shape, 0)
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', extra_exec='', pre_data='', ssh_opts=None): def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None):
ssh_opts = ssh_opts or {} ssh_opts = ssh_opts or {}
if login_shell: if login_shell:
ssh_opts['login_shell'] = login_shell ssh_opts['login_shell'] = login_shell
if 'python' in sh:
if test_script.startswith('env;'):
test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})'
test_script = f'print("UNTAR_DONE", flush=True); {test_script}'
else:
test_script = f'echo "UNTAR_DONE"; {test_script}'
script = bootstrap_script( script = bootstrap_script(
exec_cmd=f'echo "UNTAR_DONE"; {extra_exec}', ssh_opts_dict={'*': ssh_opts}, script_type='py' if 'python' in sh else 'sh',
test_script=test_script, ssh_opts_dict={'*': ssh_opts},
) )
env = basic_shell_env(home_dir) env = basic_shell_env(home_dir)
# Avoid generating unneeded completion scripts # Avoid generating unneeded completion scripts

View File

@ -20,6 +20,7 @@ import tty
tty_fd = -1 tty_fd = -1
original_termios_state = None original_termios_state = None
data_dir = shell_integration_dir = '' data_dir = shell_integration_dir = ''
leading_data = b''
HOME = os.path.expanduser('~') HOME = os.path.expanduser('~')
login_shell = pwd.getpwuid(os.geteuid()).pw_shell or 'sh' login_shell = pwd.getpwuid(os.geteuid()).pw_shell or 'sh'
@ -108,7 +109,7 @@ def compile_terminfo(base):
def get_data(): def get_data():
global data_dir, shell_integration_dir global data_dir, shell_integration_dir, leading_data
data = b'' data = b''
while data.count(b'\036') < 2: while data.count(b'\036') < 2:
@ -219,6 +220,7 @@ def main():
if exec_cmd: if exec_cmd:
cmd = base64.standard_b64decode(exec_cmd).decode('utf-8') cmd = base64.standard_b64decode(exec_cmd).decode('utf-8')
os.execlp(login_shell, os.path.basename(login_shell), '-c', cmd) os.execlp(login_shell, os.path.basename(login_shell), '-c', cmd)
TEST_SCRIPT # noqa
if ksi and 'no-rc' not in ksi: if ksi and 'no-rc' not in ksi:
exec_with_shell_integration() exec_with_shell_integration()
os.environ.pop('KITTY_SHELL_INTEGRATION', None) os.environ.pop('KITTY_SHELL_INTEGRATION', None)

View File

@ -324,6 +324,9 @@ exec_with_shell_integration() {
esac esac
} }
# Used in the tests
TEST_SCRIPT
case "$KITTY_SHELL_INTEGRATION" in case "$KITTY_SHELL_INTEGRATION" in
("") ("")
# only blanks or unset # only blanks or unset