diff --git a/gen-go-code.py b/gen-go-code.py index f51194a61..01f7fc0cf 100755 --- a/gen-go-code.py +++ b/gen-go-code.py @@ -452,6 +452,7 @@ type VersionType struct {{ const VersionString string = "{kc.str_version}" const WebsiteBaseURL string = "{kc.website_base_url}" const VCSRevision string = "" +const SSHControlMasterTemplate = "{kc.ssh_control_master_template}" const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}" const IsFrozenBuild bool = false const IsStandaloneBuild bool = false diff --git a/kitty_tests/check_build.py b/kitty_tests/check_build.py index d2463e6a3..89745f3d4 100644 --- a/kitty_tests/check_build.py +++ b/kitty_tests/check_build.py @@ -67,7 +67,7 @@ class TestBuild(BaseTest): q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH return mode & q == q - for x in ('kitty', 'kitten', 'askpass.py'): + for x in ('kitty', 'kitten'): x = os.path.join(shell_integration_dir, 'ssh', x) self.assertTrue(is_executable(x), f'{x} is not executable') if getattr(sys, 'frozen', False): diff --git a/pyproject.toml b/pyproject.toml index 4c1505aef..90f20d628 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.mypy] -files = 'kitty,kittens,glfw,*.py,docs/conf.py,shell-integration/ssh/askpass.py' +files = 'kitty,kittens,glfw,*.py,docs/conf.py' no_implicit_optional = true sqlite_cache = true cache_fine_grained = true diff --git a/setup.py b/setup.py index 0a4250221..bdc74876b 100755 --- a/setup.py +++ b/setup.py @@ -1459,7 +1459,7 @@ def package(args: Options, bundle_type: str) -> None: if path.endswith('.so'): return True q = path.split(os.sep)[-2:] - if len(q) == 2 and q[0] == 'ssh' and q[1] in ('askpass.py', 'kitty', 'kitten'): + if len(q) == 2 and q[0] == 'ssh' and q[1] in ('kitty', 'kitten'): return True return False diff --git a/shell-integration/ssh/askpass.py b/shell-integration/ssh/askpass.py deleted file mode 100755 index 868d79b16..000000000 --- a/shell-integration/ssh/askpass.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env -S kitty +launch -# License: GPLv3 Copyright: 2022, Kovid Goyal - -import json -import os -import sys -import time - -from kitty.shm import SharedMemory - -msg = sys.argv[-1] -prompt = os.environ.get('SSH_ASKPASS_PROMPT', '') -is_confirm = prompt == 'confirm' -is_fingerprint_check = '(yes/no/[fingerprint])' in msg -q = { - 'message': msg, - 'type': 'confirm' if is_confirm else 'get_line', - 'is_password': not is_fingerprint_check, -} - -data = json.dumps(q) -with SharedMemory( - size=len(data) + 1 + SharedMemory.num_bytes_for_size, unlink_on_exit=True, prefix=f'askpass-{os.getpid()}-') as shm, \ - open(os.ctermid(), 'wb') as tty: - shm.write(b'\0') - shm.write_data_with_size(data) - shm.flush() - with open(os.ctermid(), 'wb') as f: - f.write(f'\x1bP@kitty-ask|{shm.name}\x1b\\'.encode('ascii')) - f.flush() - while True: - # TODO: Replace sleep() with a mutex and condition variable created in the shared memory - time.sleep(0.05) - shm.seek(0) - if shm.read(1) == b'\x01': - break - response = json.loads(shm.read_data_with_size()) -if is_confirm: - response = 'yes' if response else 'no' -elif is_fingerprint_check: - if response.lower() in ('y', 'yes'): - response = 'yes' - if response.lower() in ('n', 'no'): - response = 'no' -if response: - print(response, flush=True) diff --git a/tools/cmd/main.go b/tools/cmd/main.go index 08c697b7d..23c12705c 100644 --- a/tools/cmd/main.go +++ b/tools/cmd/main.go @@ -3,12 +3,22 @@ package main import ( + "os" + "kitty/tools/cli" "kitty/tools/cmd/completion" + "kitty/tools/cmd/ssh" "kitty/tools/cmd/tool" ) func main() { + krm := os.Getenv("KITTY_KITTEN_RUN_MODULE") + os.Unsetenv("KITTY_KITTEN_RUN_MODULE") + switch krm { + case "ssh_askpass": + ssh.RunSSHAskpass() + return + } root := cli.NewRootCommand() root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)" root.Usage = "command [command options] [command args]" diff --git a/tools/cmd/ssh/askpass.go b/tools/cmd/ssh/askpass.go new file mode 100644 index 000000000..edc7bbf31 --- /dev/null +++ b/tools/cmd/ssh/askpass.go @@ -0,0 +1,118 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + "kitty/tools/cli" + "kitty/tools/tty" + "kitty/tools/utils/shm" +) + +var _ = fmt.Print + +func fatal(err error) { + cli.ShowError(err) + os.Exit(1) +} + +func trigger_ask(name string) { + term, err := tty.OpenControllingTerm() + if err != nil { + fatal(err) + } + defer term.Close() + _, err = term.WriteString("\x1bP@kitty-ask|" + name + "\x1b\\") + if err != nil { + fatal(err) + } + +} + +func RunSSHAskpass() { + msg := os.Args[len(os.Args)-1] + prompt := os.Getenv("SSH_ASKPASS_PROMPT") + is_confirm := prompt == "confirm" + q_type := "get_line" + if is_confirm { + q_type = "confirm" + } + is_fingerprint_check := strings.Contains(msg, "(yes/no/[fingerprint])") + q := map[string]any{ + "message": msg, + "type": q_type, + "is_password": !is_fingerprint_check, + } + data, err := json.Marshal(q) + if err != nil { + fatal(err) + } + shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32)) + if err != nil { + fatal(fmt.Errorf("Failed to create SHM file with error: %w", err)) + } + defer shm.Close() + defer shm.Unlink() + + shm.Slice()[0] = 0 + binary.BigEndian.PutUint32(shm.Slice()[1:], uint32(len(data))) + copy(shm.Slice()[5:], data) + err = shm.Flush() + if err != nil { + fatal(fmt.Errorf("Failed to flush SHM file with error: %w", err)) + } + trigger_ask(shm.Name()) + buf := []byte{0} + for { + time.Sleep(50 * time.Millisecond) + _, err = shm.Seek(0, os.SEEK_SET) + if err != nil { + fatal(fmt.Errorf("Failed to seek into SHM file while waiting for response with error: %w", err)) + } + _, err = shm.Read(buf) + if err != nil { + fatal(fmt.Errorf("Failed to read from SHM file while waiting for response with error: %w", err)) + } + if buf[0] == 1 { + break + } + } + data, err = shm.ReadWithSize() + if err != nil { + fatal(fmt.Errorf("Failed to read response data from SHM file with error: %w", err)) + } + response := "" + if is_confirm { + var ok bool + err = json.Unmarshal(data, &ok) + if err != nil { + fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err)) + } + response = "no" + if ok { + response = "yes" + } + } else { + err = json.Unmarshal(data, &response) + if err != nil { + fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err)) + } + if is_fingerprint_check { + response = strings.ToLower(response) + if response == "y" { + response = "yes" + } else if response == "n" { + response = "no" + } + } + } + if response != "" { + fmt.Println(response) + } +} diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index 019c6c396..551f04f9e 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -7,16 +7,22 @@ import ( "errors" "fmt" "io/fs" + "kitty" "net/url" "os" + "os/exec" "os/user" + "path/filepath" + "strconv" "strings" "kitty/tools/cli" "kitty/tools/tty" + "kitty/tools/utils" "kitty/tools/utils/shm" "golang.org/x/exp/maps" + "golang.org/x/exp/slices" "golang.org/x/sys/unix" ) @@ -102,8 +108,58 @@ func parse_kitten_args(found_extra_args []string, username, hostname_for_match s return } +func connection_sharing_args(kitty_pid int) ([]string, error) { + rd := utils.RuntimeDir() + // Bloody OpenSSH generates a 40 char hash and in creating the socket + // appends a 27 char temp suffix to it. Socket max path length is approx + // ~104 chars. And on idiotic Apple the path length to the runtime dir + // (technically the cache dir since Apple has no runtime dir and thinks it's + // a great idea to delete files in /tmp) is ~48 chars. + if len(rd) > 35 { + idiotic_design := fmt.Sprintf("/tmp/kssh-rdir-%d", os.Geteuid()) + if err := utils.AtomicCreateSymlink(rd, idiotic_design); err != nil { + return nil, err + } + rd = idiotic_design + } + cp := strings.Replace(kitty.SSHControlMasterTemplate, "{kitty_pid}", strconv.Itoa(kitty_pid), 1) + cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1) + return []string{ + "-o", "ControlMaster=auto", + "-o", "ControlPath=" + cp, + "-o", "ControlPersist=yes", + "-o", "ServerAliveInterval=60", + "-o", "ServerAliveCountMax=5", + "-o", "TCPKeepAlive=no", + }, nil +} + +func set_askpass() (need_to_request_data bool) { + need_to_request_data = true + sentinel := filepath.Join(utils.CacheDir(), "openssh-is-new-enough-for-askpass") + _, err := os.Stat(sentinel) + sentinel_exists := err == nil + if sentinel_exists || GetSSHVersion().SupportsAskpassRequire() { + if !sentinel_exists { + os.WriteFile(sentinel, []byte{0}, 0o644) + } + need_to_request_data = false + } + exe, err := os.Executable() + if err == nil { + os.Setenv("SSH_ASKPASS", exe) + os.Setenv("KITTY_KITTEN_RUN_MODULE", "ssh_askpass") + if !need_to_request_data { + os.Setenv("SSH_ASKPASS_REQUIRE", "force") + } + } else { + need_to_request_data = true + } + return +} + func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { - cmd := append([]string{ssh_exe()}, ssh_args...) + cmd := append([]string{SSHExe()}, ssh_args...) hostname, remote_args := server_args[0], server_args[1:] if len(remote_args) == 0 { cmd = append(cmd, "-t") @@ -115,10 +171,35 @@ func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err erro if err != nil { return 1, err } - if insertion_point > 0 && overrides != nil && literal_env != nil { + host_opts, err := load_config(hostname_for_match, uname, overrides) + if err != nil { + return 1, err } - // TODO: Implement me - return + if host_opts.Share_connections { + kpid, err := strconv.Atoi(os.Getenv("KITTY_PID")) + if err != nil { + return 1, fmt.Errorf("Invalid KITTY_PID env var not an integer: %#v", os.Getenv("KITTY_PID")) + } + cpargs, err := connection_sharing_args(kpid) + if err != nil { + return 1, err + } + cmd = slices.Insert(cmd, insertion_point, cpargs...) + } + use_kitty_askpass := host_opts.Askpass == Askpass_native || (host_opts.Askpass == Askpass_unless_set && os.Getenv("SSH_ASKPASS") == "") + need_to_request_data := true + if use_kitty_askpass { + need_to_request_data = set_askpass() + } + if need_to_request_data && host_opts.Share_connections { + check_cmd := slices.Insert(cmd, 1, "-O", "check") + err = exec.Command(check_cmd[0], check_cmd[1:]...).Run() + if err == nil { + need_to_request_data = false + } + } + _ = literal_env + return 0, nil } func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { @@ -139,7 +220,7 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { if invargs.Msg != "" { fmt.Fprintln(os.Stderr, invargs.Msg) } - return 1, unix.Exec(ssh_exe(), []string{"ssh"}, os.Environ()) + return 1, unix.Exec(SSHExe(), []string{"ssh"}, os.Environ()) } return 1, err } @@ -147,7 +228,7 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { if len(found_extra_args) > 0 { return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " ")) } - return 1, unix.Exec(ssh_exe(), append([]string{"ssh"}, args...), os.Environ()) + return 1, unix.Exec(SSHExe(), append([]string{"ssh"}, args...), os.Environ()) } if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" { return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window") diff --git a/tools/cmd/ssh/utils.go b/tools/cmd/ssh/utils.go index 24d7bec68..1f1af9c4c 100644 --- a/tools/cmd/ssh/utils.go +++ b/tools/cmd/ssh/utils.go @@ -7,28 +7,26 @@ import ( "io" "kitty/tools/utils" "os/exec" + "regexp" + "strconv" "strings" - "sync" ) var _ = fmt.Print -var ssh_options map[string]string -var query_ssh_for_options_once sync.Once - -func ssh_exe() string { +var SSHExe = (&utils.Once[string]{Run: func() string { ans := utils.Which("ssh") if ans != "" { return ans } - ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin") + ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin") if ans == "" { ans = "ssh" } return ans -} +}}).Get -func get_ssh_options() { +var SSHOptions = (&utils.Once[map[string]string]{Run: func() (ssh_options map[string]string) { defer func() { if ssh_options == nil { ssh_options = map[string]string{ @@ -42,7 +40,7 @@ func get_ssh_options() { } } }() - cmd := exec.Command(ssh_exe()) + cmd := exec.Command(SSHExe()) stderr, err := cmd.StderrPipe() if err != nil { return @@ -86,12 +84,8 @@ func get_ssh_options() { } } } -} - -func SSHOptions() map[string]string { - query_ssh_for_options_once.Do(get_ssh_options) - return ssh_options -} + return +}}).Get func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) { other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32) @@ -205,3 +199,23 @@ func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, serve } return } + +type SSHVersion struct{ Major, Minor int } + +func (self SSHVersion) SupportsAskpassRequire() bool { + return self.Major > 8 || (self.Major == 8 && self.Minor >= 4) +} + +var GetSSHVersion = (&utils.Once[SSHVersion]{Run: func() SSHVersion { + b, err := exec.Command(SSHExe(), "-V").CombinedOutput() + if err != nil { + return SSHVersion{} + } + m := regexp.MustCompile(`OpenSSH_(\d+).(\d+)`).FindSubmatch(b) + if len(m) == 3 { + maj, _ := strconv.Atoi(utils.UnsafeBytesToString(m[1])) + min, _ := strconv.Atoi(utils.UnsafeBytesToString(m[2])) + return SSHVersion{Major: maj, Minor: min} + } + return SSHVersion{} +}}).Get diff --git a/tools/cmd/ssh/utils_test.go b/tools/cmd/ssh/utils_test.go index 1b6bb3653..b4852e9d5 100644 --- a/tools/cmd/ssh/utils_test.go +++ b/tools/cmd/ssh/utils_test.go @@ -50,5 +50,4 @@ func TestParseSSHArgs(t *testing.T) { p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false) p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false) p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true) - } diff --git a/tools/utils/atomic-write.go b/tools/utils/atomic-write.go index 04681e011..41f83896d 100644 --- a/tools/utils/atomic-write.go +++ b/tools/utils/atomic-write.go @@ -20,6 +20,9 @@ func AtomicCreateSymlink(oldname, newname string) (err error) { if !errors.Is(err, fs.ErrExist) { return err } + if et, err := os.Readlink(newname); err == nil && et == oldname { + return nil + } for { tempname := newname + RandomFilename() err = os.Symlink(oldname, tempname) diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index 389b21f4e..65c44bb56 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -52,6 +52,10 @@ type MMap interface { IsFileSystemBacked() bool FileSystemName() string Stat() (fs.FileInfo, error) + Flush() error + Seek(offset int64, whence int) (int64, error) + Read(b []byte) (int, error) + ReadWithSize() ([]byte, error) } type AccessFlags int diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index 5722665f5..3a79beeab 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -13,6 +13,8 @@ import ( "runtime" "kitty/tools/utils" + + "golang.org/x/sys/unix" ) var _ = fmt.Print @@ -51,6 +53,22 @@ func (self *file_based_mmap) Name() string { return filepath.Base(self.f.Name()) } +func (self *file_based_mmap) Flush() error { + return unix.Msync(self.region, unix.MS_SYNC) +} + +func (self *file_based_mmap) Seek(offset int64, whence int) (int64, error) { + return self.f.Seek(offset, whence) +} + +func (self *file_based_mmap) Read(b []byte) (int, error) { + return self.f.Read(b) +} + +func (self *file_based_mmap) ReadWithSize() ([]byte, error) { + return read_with_size(self.f) +} + func (self *file_based_mmap) FileSystemName() string { return self.f.Name() } diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index f4a8f2956..caddbd002 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -92,10 +92,26 @@ func (self *syscall_based_mmap) Stat() (fs.FileInfo, error) { return self.f.Stat() } +func (self *syscall_based_mmap) Flush() error { + return unix.Msync(self.region, unix.MS_SYNC) +} + func (self *syscall_based_mmap) Slice() []byte { return self.region } +func (self *syscall_based_mmap) Seek(offset int64, whence int) (int64, error) { + return self.f.Seek(offset, whence) +} + +func (self *syscall_based_mmap) Read(b []byte) (int, error) { + return self.f.Read(b) +} + +func (self *syscall_based_mmap) ReadWithSize() ([]byte, error) { + return read_with_size(self.f) +} + func (self *syscall_based_mmap) Close() (err error) { if self.region != nil { self.f.Close() diff --git a/tools/utils/shm/shm_test.go b/tools/utils/shm/shm_test.go index 6d6c2817e..086731f8e 100644 --- a/tools/utils/shm/shm_test.go +++ b/tools/utils/shm/shm_test.go @@ -23,6 +23,10 @@ func TestSHM(t *testing.T) { } copy(mm.Slice(), data) + err = mm.Flush() + if err != nil { + t.Fatalf("Failed to msync() with error: %v", err) + } err = mm.Close() if err != nil { t.Fatalf("Failed to close with error: %v", err)