diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index 801cd54b9..94c01a21e 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -3,17 +3,17 @@ import glob +import json import os import shutil +import subprocess import tempfile from contextlib import suppress from functools import lru_cache -from kittens.ssh.config import load_config -from kittens.ssh.main import bootstrap_script, wrap_bootstrap_script from kittens.ssh.utils import get_connection_data from kittens.transfer.utils import set_paths -from kitty.constants import is_macos, runtime_dir +from kitty.constants import is_macos, kitten_exe, runtime_dir from kitty.fast_data_types import CURSOR_BEAM, shm_unlink from kitty.utils import SSHConnectionData @@ -88,10 +88,8 @@ copy --dest=a/sfa simple-file copy --glob g.* copy --exclude */w.* d1 ''' - copy = load_config(overrides=filter(None, conf.splitlines())).copy self.check_bootstrap( - sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', - ssh_opts={'copy': copy} + sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf ) tname = '.terminfo' if os.path.exists('/usr/share/misc/terminfo.cdb'): @@ -125,13 +123,14 @@ copy --exclude */w.* d1 for sh in self.all_possible_sh: with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: os.mkdir(os.path.join(tdir, 'cwd')) + conf = f''' +cwd $HOME/cwd +env A=AAA +env TSET={tset} +env COLORTERM +''' pty = self.check_bootstrap( - sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', - ssh_opts={'cwd': '$HOME/cwd', 'env': { - 'A': 'AAA', - 'TSET': tset, - 'COLORTERM': DELETE_ENV_VAR, - }} + sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf ) pty.wait_till(lambda: 'TSET={}'.format(tset.replace('$A', 'AAA')) in pty.screen_contents()) self.assertNotIn('COLORTERM', pty.screen_contents()) @@ -213,34 +212,31 @@ copy --exclude */w.* d1 self.assertEqual(pty.screen.cursor.shape, 0) self.assertNotIn(b'\x1b]133;', pty.received_bytes) - def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None, launcher='sh'): - ssh_opts = ssh_opts or {} + def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', conf='', launcher='sh'): if login_shell: - ssh_opts['login_shell'] = login_shell + conf += f'\nlogin_shell {login_shell}' if 'python' in sh: if test_script.startswith('env;'): test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})' test_script = f'print("UNTAR_DONE", flush=True); {test_script}' else: test_script = f'echo "UNTAR_DONE"; {test_script}' - ssh_opts['shell_integration'] = SHELL_INTEGRATION_VALUE or 'disabled' - script, replacements, shm_name = bootstrap_script( - SSHOptions(ssh_opts), script_type='py' if 'python' in sh else 'sh', request_id="testing", test_script=test_script, - request_data=True - ) + conf += '\nshell_integration ' + SHELL_INTEGRATION_VALUE or 'disabled' + conf += '\ninterpreter ' + sh + cp = subprocess.run([kitten_exe(), '__pytest__', 'ssh', test_script], stdout=subprocess.PIPE, input=conf.encode('utf-8')) + self.assertEqual(cp.returncode, 0) + self.rdata = json.loads(cp.stdout) + del cp try: env = basic_shell_env(home_dir) # Avoid generating unneeded completion scripts os.makedirs(os.path.join(home_dir, '.local', 'share', 'fish', 'generated_completions'), exist_ok=True) # prevent newuser-install from running open(os.path.join(home_dir, '.zshrc'), 'w').close() - cmd = wrap_bootstrap_script(script, sh) - pty = self.create_pty([launcher, '-c', ' '.join(cmd)], cwd=home_dir, env=env) + pty = self.create_pty([launcher, '-c', ' '.join(self.rdata['cmd'])], cwd=home_dir, env=env) pty.turn_off_echo() - del cmd if pre_data: pty.write_buf = pre_data.encode('utf-8') - del script def check_untar_or_fail(): q = pty.screen_contents() @@ -257,4 +253,4 @@ copy --exclude */w.* d1 return pty finally: with suppress(FileNotFoundError): - shm_unlink(shm_name) + shm_unlink(self.rdata['shm_name']) diff --git a/tools/cmd/pytest/main.go b/tools/cmd/pytest/main.go index 973d68b6b..bb929c705 100644 --- a/tools/cmd/pytest/main.go +++ b/tools/cmd/pytest/main.go @@ -6,6 +6,7 @@ import ( "fmt" "kitty/tools/cli" + "kitty/tools/cmd/ssh" "kitty/tools/utils/shm" ) @@ -17,4 +18,5 @@ func EntryPoint(root *cli.Command) { Hidden: true, }) shm.TestEntryPoint(root) + ssh.TestEntryPoint(root) } diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index 4a79e0f59..2fb6fb035 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/fs" "kitty" "net/url" @@ -236,7 +237,7 @@ func serialize_env(cd *connection_data, get_local_env func(string) (string, bool add_env("KITTY_REMOTE", cd.host_opts.Remote_kitty.String()) } add_env("KITTY_PUBLIC_KEY", os.Getenv("KITTY_PUBLIC_KEY")) - return final_env_instructions(cd.script_type == "py", get_local_env), ksi + return final_env_instructions(cd.script_type == "py", get_local_env, env...), ksi } func make_tarfile(cd *connection_data, get_local_env func(string) (string, bool)) ([]byte, error) { @@ -303,6 +304,9 @@ func make_tarfile(cd *connection_data, get_local_env func(string) (string, bool) } add_data(fe{"data.sh", utils.UnsafeStringToBytes(env_script)}) + if cd.script_type == "sh" { + add_data(fe{"bootstrap-utils.sh", Data()[path.Join("shell-integration/ssh/bootstrap-utils.sh")].data}) + } if ksi != "" { for _, fname := range Data().files_matching( "shell-integration/", @@ -616,3 +620,52 @@ func specialize_command(ssh *cli.Command) { ssh.OnlyArgsAllowed = true ssh.ArgCompleter = cli.CompletionForWrapper("ssh") } + +func test_integration_with_python(args []string) (rc int, err error) { + f, err := os.CreateTemp("", "*.conf") + if err != nil { + return 1, err + } + defer func() { + f.Close() + os.Remove(f.Name()) + }() + _, err = io.Copy(f, os.Stdin) + if err != nil { + return 1, err + } + cd := &connection_data{ + request_id: "testing", remote_args: []string{}, + username: "testuser", hostname_for_match: "host.test", request_data: true, + test_script: args[0], + } + opts, err := load_config(cd.hostname_for_match, cd.username, nil, f.Name()) + if err == nil { + cd.host_opts = opts + err = get_remote_command(cd) + } + if err != nil { + return 1, err + } + data, err := json.Marshal(map[string]any{"cmd": cd.rcmd, "shm_name": cd.shm_name}) + if err == nil { + _, err = os.Stdout.Write(data) + os.Stdout.Close() + } + if err != nil { + return 1, err + } + + return +} + +func TestEntryPoint(root *cli.Command) { + root.AddSubCommand(&cli.Command{ + Name: "ssh", + OnlyArgsAllowed: true, + Run: func(cmd *cli.Command, args []string) (rc int, err error) { + return test_integration_with_python(args) + }, + }) + +}