bootstrap.py is now tested the same as bootsstrap.sh
This commit is contained in:
parent
ec782d3296
commit
7f9fec061a
@ -17,7 +17,8 @@ from base64 import standard_b64decode, standard_b64encode
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from getpass import getuser
|
from getpass import getuser
|
||||||
from typing import (
|
from typing import (
|
||||||
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Set, Tuple, Union
|
Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set,
|
||||||
|
Tuple, Union
|
||||||
)
|
)
|
||||||
|
|
||||||
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
|
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
|
||||||
@ -156,7 +157,6 @@ def get_ssh_data(msg: str) -> Iterator[bytes]:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield fmt_prefix('!error while gathering ssh data')
|
yield fmt_prefix('!error while gathering ssh data')
|
||||||
else:
|
else:
|
||||||
from base64 import standard_b64encode
|
|
||||||
encoded_data = standard_b64encode(data)
|
encoded_data = standard_b64encode(data)
|
||||||
yield fmt_prefix(len(encoded_data))
|
yield fmt_prefix(len(encoded_data))
|
||||||
yield encoded_data
|
yield encoded_data
|
||||||
@ -177,7 +177,18 @@ def prepare_script(ans: str, replacements: Dict[str, str]) -> str:
|
|||||||
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
|
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
|
||||||
|
|
||||||
|
|
||||||
def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
|
def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
|
||||||
|
# ssh simply concatenates multiple commands using a space see
|
||||||
|
# line 1129 of ssh.c and on the remote side sshd.c runs the
|
||||||
|
# concatenated command as shell -c cmd
|
||||||
|
if is_python:
|
||||||
|
return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
|
||||||
|
args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args)
|
||||||
|
return f"""exec "$login_shell" -c '{args}'"""
|
||||||
|
|
||||||
|
|
||||||
|
def bootstrap_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}, test_script: str = '') -> str:
|
||||||
|
exec_cmd = prepare_exec_cmd(remote_args, script_type == 'py') if remote_args else ''
|
||||||
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
|
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
|
||||||
ans = f.read()
|
ans = f.read()
|
||||||
pw = uuid4()
|
pw = uuid4()
|
||||||
@ -185,12 +196,12 @@ def bootstrap_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict:
|
|||||||
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict}
|
data = {'pw': pw, 'env': dict(os.environ), 'opts': ssh_opts_dict}
|
||||||
tf.write(json.dumps(data).encode('utf-8'))
|
tf.write(json.dumps(data).encode('utf-8'))
|
||||||
atexit.register(safe_remove, tf.name)
|
atexit.register(safe_remove, tf.name)
|
||||||
replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd}
|
replacements = {'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': os.path.basename(tf.name), 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script}
|
||||||
return prepare_script(ans, replacements)
|
return prepare_script(ans, replacements)
|
||||||
|
|
||||||
|
|
||||||
def load_script(script_type: str = 'sh', exec_cmd: str = '', ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
|
def load_script(script_type: str = 'sh', remote_args: Sequence[str] = (), ssh_opts_dict: Dict[str, Dict[str, Any]] = {}) -> str:
|
||||||
return bootstrap_script(script_type, exec_cmd, ssh_opts_dict=ssh_opts_dict)
|
return bootstrap_script(script_type, remote_args, ssh_opts_dict=ssh_opts_dict)
|
||||||
|
|
||||||
|
|
||||||
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
|
def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
|
||||||
@ -361,18 +372,8 @@ def get_remote_command(
|
|||||||
remote_args: List[str], hostname: str = 'localhost', interpreter: str = 'sh',
|
remote_args: List[str], hostname: str = 'localhost', interpreter: str = 'sh',
|
||||||
ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
|
ssh_opts_dict: Dict[str, Dict[str, Any]] = {}
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
command_to_execute = ''
|
|
||||||
is_python = 'python' in interpreter.lower()
|
is_python = 'python' in interpreter.lower()
|
||||||
if remote_args:
|
sh_script = load_script(script_type='py' if is_python else 'sh', remote_args=remote_args, ssh_opts_dict=ssh_opts_dict)
|
||||||
# ssh simply concatenates multiple commands using a space see
|
|
||||||
# line 1129 of ssh.c and on the remote side sshd.c runs the
|
|
||||||
# concatenated command as shell -c cmd
|
|
||||||
if is_python:
|
|
||||||
command_to_execute = standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
|
|
||||||
else:
|
|
||||||
args = [c.replace("'", """'"'"'""") for c in remote_args]
|
|
||||||
command_to_execute = "exec \"$login_shell\" -c '{}'".format(' '.join(args))
|
|
||||||
sh_script = load_script(script_type='py' if is_python else 'sh', exec_cmd=command_to_execute, ssh_opts_dict=ssh_opts_dict)
|
|
||||||
return [f'{interpreter} -c {shlex.quote(sh_script)}']
|
return [f'{interpreter} -c {shlex.quote(sh_script)}']
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -86,7 +86,8 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
|
|||||||
@property
|
@property
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def all_possible_sh(self):
|
def all_possible_sh(self):
|
||||||
return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh')))
|
python = 'python3' if shutil.which('python3') else 'python'
|
||||||
|
return tuple(filter(shutil.which, ('dash', 'zsh', 'bash', 'posh', 'sh', python)))
|
||||||
|
|
||||||
def test_ssh_copy(self):
|
def test_ssh_copy(self):
|
||||||
simple_data = 'rkjlhfwf9whoaa'
|
simple_data = 'rkjlhfwf9whoaa'
|
||||||
@ -111,7 +112,7 @@ copy --exclude */w.* d1
|
|||||||
'''
|
'''
|
||||||
copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy
|
copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy
|
||||||
self.check_bootstrap(
|
self.check_bootstrap(
|
||||||
sh, remote_home, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
||||||
ssh_opts={'copy': copy}
|
ssh_opts={'copy': copy}
|
||||||
)
|
)
|
||||||
tname = '.terminfo'
|
tname = '.terminfo'
|
||||||
@ -141,7 +142,7 @@ copy --exclude */w.* d1
|
|||||||
for sh in self.all_possible_sh:
|
for sh in self.all_possible_sh:
|
||||||
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
|
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
|
||||||
pty = self.check_bootstrap(
|
pty = self.check_bootstrap(
|
||||||
sh, tdir, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
sh, tdir, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
|
||||||
ssh_opts={'env': {
|
ssh_opts={'env': {
|
||||||
'TSET': 'set-works',
|
'TSET': 'set-works',
|
||||||
'COLORTERM': DELETE_ENV_VAR,
|
'COLORTERM': DELETE_ENV_VAR,
|
||||||
@ -151,10 +152,13 @@ copy --exclude */w.* d1
|
|||||||
self.assertNotIn('COLORTERM', pty.screen_contents())
|
self.assertNotIn('COLORTERM', pty.screen_contents())
|
||||||
|
|
||||||
def test_ssh_leading_data(self):
|
def test_ssh_leading_data(self):
|
||||||
|
script = 'echo "ld:$leading_data"; exit 0'
|
||||||
for sh in self.all_possible_sh:
|
for sh in self.all_possible_sh:
|
||||||
|
if 'python' in sh:
|
||||||
|
script = 'print("ld:" + leading_data.decode("ascii")); raise SystemExit(0);'
|
||||||
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
|
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
|
||||||
pty = self.check_bootstrap(
|
pty = self.check_bootstrap(
|
||||||
sh, tdir, extra_exec='echo "ld:$leading_data"; exit 0',
|
sh, tdir, test_script=script,
|
||||||
SHELL_INTEGRATION_VALUE='', pre_data='before_tarfile')
|
SHELL_INTEGRATION_VALUE='', pre_data='before_tarfile')
|
||||||
self.ae(pty.screen_contents(), 'UNTAR_DONE\nld:before_tarfile')
|
self.ae(pty.screen_contents(), 'UNTAR_DONE\nld:before_tarfile')
|
||||||
|
|
||||||
@ -174,8 +178,10 @@ copy --exclude */w.* d1
|
|||||||
expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell
|
expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell
|
||||||
for m in methods:
|
for m in methods:
|
||||||
for sh in self.all_possible_sh:
|
for sh in self.all_possible_sh:
|
||||||
|
if 'python' in sh:
|
||||||
|
continue
|
||||||
with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir:
|
with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir:
|
||||||
pty = self.check_bootstrap(sh, tdir, extra_exec=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='')
|
pty = self.check_bootstrap(sh, tdir, test_script=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='')
|
||||||
self.assertIn(expected_login_shell, pty.screen_contents())
|
self.assertIn(expected_login_shell, pty.screen_contents())
|
||||||
|
|
||||||
def test_ssh_shell_integration(self):
|
def test_ssh_shell_integration(self):
|
||||||
@ -205,12 +211,19 @@ copy --exclude */w.* d1
|
|||||||
pty.wait_till(lambda: len(pty.screen_contents().splitlines()) >= num_lines + 2)
|
pty.wait_till(lambda: len(pty.screen_contents().splitlines()) >= num_lines + 2)
|
||||||
self.assertEqual(pty.screen.cursor.shape, 0)
|
self.assertEqual(pty.screen.cursor.shape, 0)
|
||||||
|
|
||||||
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', extra_exec='', pre_data='', ssh_opts=None):
|
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None):
|
||||||
ssh_opts = ssh_opts or {}
|
ssh_opts = ssh_opts or {}
|
||||||
if login_shell:
|
if login_shell:
|
||||||
ssh_opts['login_shell'] = login_shell
|
ssh_opts['login_shell'] = login_shell
|
||||||
|
if 'python' in sh:
|
||||||
|
if test_script.startswith('env;'):
|
||||||
|
test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})'
|
||||||
|
test_script = f'print("UNTAR_DONE", flush=True); {test_script}'
|
||||||
|
else:
|
||||||
|
test_script = f'echo "UNTAR_DONE"; {test_script}'
|
||||||
script = bootstrap_script(
|
script = bootstrap_script(
|
||||||
exec_cmd=f'echo "UNTAR_DONE"; {extra_exec}', ssh_opts_dict={'*': ssh_opts},
|
script_type='py' if 'python' in sh else 'sh',
|
||||||
|
test_script=test_script, ssh_opts_dict={'*': ssh_opts},
|
||||||
)
|
)
|
||||||
env = basic_shell_env(home_dir)
|
env = basic_shell_env(home_dir)
|
||||||
# Avoid generating unneeded completion scripts
|
# Avoid generating unneeded completion scripts
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import tty
|
|||||||
tty_fd = -1
|
tty_fd = -1
|
||||||
original_termios_state = None
|
original_termios_state = None
|
||||||
data_dir = shell_integration_dir = ''
|
data_dir = shell_integration_dir = ''
|
||||||
|
leading_data = b''
|
||||||
HOME = os.path.expanduser('~')
|
HOME = os.path.expanduser('~')
|
||||||
login_shell = pwd.getpwuid(os.geteuid()).pw_shell or 'sh'
|
login_shell = pwd.getpwuid(os.geteuid()).pw_shell or 'sh'
|
||||||
|
|
||||||
@ -108,7 +109,7 @@ def compile_terminfo(base):
|
|||||||
|
|
||||||
|
|
||||||
def get_data():
|
def get_data():
|
||||||
global data_dir, shell_integration_dir
|
global data_dir, shell_integration_dir, leading_data
|
||||||
data = b''
|
data = b''
|
||||||
|
|
||||||
while data.count(b'\036') < 2:
|
while data.count(b'\036') < 2:
|
||||||
@ -219,6 +220,7 @@ def main():
|
|||||||
if exec_cmd:
|
if exec_cmd:
|
||||||
cmd = base64.standard_b64decode(exec_cmd).decode('utf-8')
|
cmd = base64.standard_b64decode(exec_cmd).decode('utf-8')
|
||||||
os.execlp(login_shell, os.path.basename(login_shell), '-c', cmd)
|
os.execlp(login_shell, os.path.basename(login_shell), '-c', cmd)
|
||||||
|
TEST_SCRIPT # noqa
|
||||||
if ksi and 'no-rc' not in ksi:
|
if ksi and 'no-rc' not in ksi:
|
||||||
exec_with_shell_integration()
|
exec_with_shell_integration()
|
||||||
os.environ.pop('KITTY_SHELL_INTEGRATION', None)
|
os.environ.pop('KITTY_SHELL_INTEGRATION', None)
|
||||||
|
|||||||
@ -324,6 +324,9 @@ exec_with_shell_integration() {
|
|||||||
esac
|
esac
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Used in the tests
|
||||||
|
TEST_SCRIPT
|
||||||
|
|
||||||
case "$KITTY_SHELL_INTEGRATION" in
|
case "$KITTY_SHELL_INTEGRATION" in
|
||||||
("")
|
("")
|
||||||
# only blanks or unset
|
# only blanks or unset
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user