bootstrap.py is now tested the same as bootsstrap.sh
This commit is contained in:
parent
ec782d3296
commit
7f9fec061a
@ -17,7 +17,8 @@ from base64 import standard_b64decode, standard_b64encode
|
||||
from contextlib import suppress
|
||||
from getpass import getuser
|
||||
from typing import (
|
||||
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Set, Tuple, Union
|
||||
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set,
|
||||
Tuple, Union
|
||||
)
|
||||
|
||||
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
|
||||
@ -156,7 +157,6 @@ def get_ssh_data(msg: str) -> Iterator[bytes]:
|
||||
traceback.print_exc()
|
||||
yield fmt_prefix('!error while gathering ssh data')
|
||||
else:
|
||||
from base64 import standard_b64encode
|
||||
encoded_data = standard_b64encode(data)
|
||||
yield fmt_prefix(len(encoded_data))
|
||||
yield encoded_data
|
||||
@ -177,7 +177,18 @@ def prepare_script(ans: str, replacements: Dict[str, str]) -> str:
|
||||
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
|
||||
|
||||
|
||||
def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
|
||||
def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
|
||||
# ssh simply concatenates multiple commands using a space see
|
||||
# line 1129 of ssh.c and on the remote side sshd.c runs the
|
||||
# concatenated command as shell -c cmd
|
||||
if is_python:
|
||||
return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
|
||||
args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args)
|
||||
return f"""exec "$login_shell" -c '{args}'"""
|
||||
|
||||
|
||||
def bootstrap_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}, test_script: str = '') -> str:
|
||||
exec_cmd = prepare_exec_cmd(remote_args, script_type == 'py') if remote_args else ''
|
||||
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
|
||||
ans = f.read()
|
||||
pw = uuid4()
|
||||
@ -185,12 +196,12 @@ def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict:
|
||||
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict}
|
||||
tf.write(json.dumps(data).encode('utf-8'))
|
||||
atexit.register(safe_remove, tf.name)
|
||||
replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd}
|
||||
replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script}
|
||||
return prepare_script(ans, replacements)
|
||||
|
||||
|
||||
def load_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
|
||||
return bootstrap_script(script_type, exec_cmd, ssh_opts_dict=ssh_opts_dict)
|
||||
def load_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
|
||||
return bootstrap_script(script_type, remote_args, ssh_opts_dict=ssh_opts_dict)
|
||||
|
||||
|
||||
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
|
||||
@ -361,18 +372,8 @@ def get_remote_command(
|
||||
remote_args: List[str], hostname: str = 'localhost', interpreter: str = 'sh',
|
||||
ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
|
||||
) -> List[str]:
|
||||
command_to_execute = ''
|
||||
is_python = 'python' in interpreter.lower()
|
||||
if remote_args:
|
||||
# ssh simply concatenates multiple commands using a space see
|
||||
# line 1129 of ssh.c and on the remote side sshd.c runs the
|
||||
# concatenated command as shell -c cmd
|
||||
if is_python:
|
||||
command_to_execute = standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
|
||||
else:
|
||||
args = [c.replace("'", """'"'"'""") for c in remote_args]
|
||||
command_to_execute = "exec \"$login_shell\" -c '{}'".format(' '.join(args))
|
||||
sh_script = load_script(script_type='py' if is_python else 'sh', exec_cmd=command_to_execute, ssh_opts_dict=ssh_opts_dict)
|
||||
sh_script = load_script(script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict)
|
||||
return [f'{interpreter} -c {shlex.quote(sh_script)}']
|
||||
|
||||
|
||||
|
||||
@ -86,7 +86,8 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_possible_sh(self):
|
||||
return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh')))
|
||||
python = 'python3' if shutil.which('python3') else 'python'
|
||||
return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh', python)))
|
||||
|
||||
def test_ssh_copy(self):
|
||||
simple_data = 'rkjlhfwf9whoaa'
|
||||
@ -111,7 +112,7 @@ copy --exclude */w.* d1
|
||||
'''
|
||||
copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy
|
||||
self.check_bootstrap(
|
||||
sh, remote_home, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
||||
sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
||||
ssh_opts={'copy': copy}
|
||||
)
|
||||
tname = '.terminfo'
|
||||
@ -141,7 +142,7 @@ copy --exclude */w.* d1
|
||||
for sh in self.all_possible_sh:
|
||||
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
|
||||
pty = self.check_bootstrap(
|
||||
sh, tdir, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
||||
sh, tdir, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
||||
ssh_opts={'env': {
|
||||
'TSET': 'set-works',
|
||||
'COLORTERM': DELETE_ENV_VAR,
|
||||
@ -151,10 +152,13 @@ copy --exclude */w.* d1
|
||||
self.assertNotIn('COLORTERM', pty.screen_contents())
|
||||
|
||||
def test_ssh_leading_data(self):
|
||||
script = 'echo "ld:$leading_data"; exit 0'
|
||||
for sh in self.all_possible_sh:
|
||||
if 'python' in sh:
|
||||
script = 'print("ld:" + leading_data.decode("ascii")); raise SystemExit(0);'
|
||||
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
|
||||
pty = self.check_bootstrap(
|
||||
sh, tdir, extra_exec='echo "ld:$leading_data"; exit 0',
|
||||
sh, tdir, test_script=script,
|
||||
SHELL_INTEGRATION_VALUE='', pre_data='before_tarfile')
|
||||
self.ae(pty.screen_contents(), 'UNTAR_DONE\nld:before_tarfile')
|
||||
|
||||
@ -174,8 +178,10 @@ copy --exclude */w.* d1
|
||||
expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell
|
||||
for m in methods:
|
||||
for sh in self.all_possible_sh:
|
||||
if 'python' in sh:
|
||||
continue
|
||||
with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir:
|
||||
pty = self.check_bootstrap(sh, tdir, extra_exec=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='')
|
||||
pty = self.check_bootstrap(sh, tdir, test_script=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='')
|
||||
self.assertIn(expected_login_shell, pty.screen_contents())
|
||||
|
||||
def test_ssh_shell_integration(self):
|
||||
@ -205,12 +211,19 @@ copy --exclude */w.* d1
|
||||
pty.wait_till(lambda: len(pty.screen_contents().splitlines()) >= num_lines + 2)
|
||||
self.assertEqual(pty.screen.cursor.shape, 0)
|
||||
|
||||
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', extra_exec='', pre_data='', ssh_opts=None):
|
||||
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None):
|
||||
ssh_opts = ssh_opts or {}
|
||||
if login_shell:
|
||||
ssh_opts['login_shell'] = login_shell
|
||||
if 'python' in sh:
|
||||
if test_script.startswith('env;'):
|
||||
test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})'
|
||||
test_script = f'print("UNTAR_DONE", flush=True); {test_script}'
|
||||
else:
|
||||
test_script = f'echo "UNTAR_DONE"; {test_script}'
|
||||
script = bootstrap_script(
|
||||
exec_cmd=f'echo "UNTAR_DONE"; {extra_exec}', ssh_opts_dict={'*': ssh_opts},
|
||||
script_type='py' if 'python' in sh else 'sh',
|
||||
test_script=test_script, ssh_opts_dict={'*': ssh_opts},
|
||||
)
|
||||
env = basic_shell_env(home_dir)
|
||||
# Avoid generating unneeded completion scripts
|
||||
|
||||
@ -20,6 +20,7 @@ import tty
|
||||
tty_fd = -1
|
||||
original_termios_state = None
|
||||
data_dir = shell_integration_dir = ''
|
||||
leading_data = b''
|
||||
HOME = os.path.expanduser('~')
|
||||
login_shell = pwd.getpwuid(os.geteuid()).pw_shell or 'sh'
|
||||
|
||||
@ -108,7 +109,7 @@ def compile_terminfo(base):
|
||||
|
||||
|
||||
def get_data():
|
||||
global data_dir, shell_integration_dir
|
||||
global data_dir, shell_integration_dir, leading_data
|
||||
data = b''
|
||||
|
||||
while data.count(b'\036') < 2:
|
||||
@ -219,6 +220,7 @@ def main():
|
||||
if exec_cmd:
|
||||
cmd = base64.standard_b64decode(exec_cmd).decode('utf-8')
|
||||
os.execlp(login_shell, os.path.basename(login_shell), '-c', cmd)
|
||||
TEST_SCRIPT # noqa
|
||||
if ksi and 'no-rc' not in ksi:
|
||||
exec_with_shell_integration()
|
||||
os.environ.pop('KITTY_SHELL_INTEGRATION', None)
|
||||
|
||||
@ -324,6 +324,9 @@ exec_with_shell_integration() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Used in the tests
|
||||
TEST_SCRIPT
|
||||
|
||||
case "$KITTY_SHELL_INTEGRATION" in
|
||||
("")
|
||||
# only blanks or unset
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user