diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index 092ae7784..e18059b3b 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -3,9 +3,15 @@ package ssh import ( + "errors" "fmt" + "os" + "strings" "kitty/tools/cli" + + "golang.org/x/exp/maps" + "golang.org/x/sys/unix" ) var _ = fmt.Print @@ -20,6 +26,27 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { return } } + ssh_args, server_args, passthrough, found_extra_args, err := ParseSSHArgs(args, "--kitten") + if err != nil { + var invargs *ErrInvalidSSHArgs + switch { + case errors.As(err, &invargs): + if invargs.Msg != "" { + fmt.Fprintln(os.Stderr, invargs.Msg) + } + return 1, unix.Exec(ssh_exe(), []string{"ssh"}, os.Environ()) + } + return 1, err + } + if passthrough { + if len(found_extra_args) > 0 { + return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " ")) + } + return 1, unix.Exec(ssh_exe(), append([]string{"ssh"}, args...), os.Environ()) + } + if false { + return len(ssh_args) + len(server_args), nil + } return } diff --git a/tools/cmd/ssh/utils.go b/tools/cmd/ssh/utils.go index fee3593b9..24d7bec68 100644 --- a/tools/cmd/ssh/utils.go +++ b/tools/cmd/ssh/utils.go @@ -16,6 +16,18 @@ var _ = fmt.Print var ssh_options map[string]string var query_ssh_for_options_once sync.Once +func ssh_exe() string { + ans := utils.Which("ssh") + if ans != "" { + return ans + } + ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin") + if ans == "" { + ans = "ssh" + } + return ans +} + func get_ssh_options() { defer func() { if ssh_options == nil { @@ -30,7 +42,7 @@ func get_ssh_options() { } } }() - cmd := exec.Command("ssh") + cmd := exec.Command(ssh_exe()) stderr, err := cmd.StderrPipe() if err != nil { return @@ -111,6 +123,10 @@ func (self *ErrInvalidSSHArgs) Error() string { return self.Msg } +func PassthroughArgs() map[string]bool { + return map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true} +} + func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, server_args []string, passthrough bool, found_extra_args []string, err error) { if extra_args == nil { extra_args = []string{} @@ -119,7 +135,7 @@ func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, serve passthrough = true return } - passthrough_args := map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true} + passthrough_args := PassthroughArgs() boolean_ssh_args, other_ssh_args := GetSSHCLI() ssh_args, server_args, found_extra_args = make([]string, 0, 16), make([]string, 0, 16), make([]string, 0, 16) expecting_option_val := false @@ -184,8 +200,8 @@ func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, serve } server_args = append(server_args, argument) } - if len(server_args) == 0 { - err = &ErrInvalidSSHArgs{Msg: "No server to connect to specified"} + if len(server_args) == 0 && !passthrough { + err = &ErrInvalidSSHArgs{Msg: ""} } return }