From 97b9572bec01f810abae245b32c5ea969bae394f Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 20 Feb 2023 17:18:43 +0530 Subject: [PATCH] Port parsing of ssh args --- tools/cmd/ssh/main.go | 9 +++ tools/cmd/ssh/utils.go | 109 ++++++++++++++++++++++++++++++++++++ tools/cmd/ssh/utils_test.go | 37 ++++++++++++ 3 files changed, 155 insertions(+) diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index fa6fb7b56..092ae7784 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -11,6 +11,15 @@ import ( var _ = fmt.Print func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { + if len(args) > 0 { + switch args[0] { + case "use-python": + args = args[1:] // backwards compat from when we had a python implementation + case "-h", "--help": + cmd.ShowHelp() + return + } + } return } diff --git a/tools/cmd/ssh/utils.go b/tools/cmd/ssh/utils.go index 3e5765124..fee3593b9 100644 --- a/tools/cmd/ssh/utils.go +++ b/tools/cmd/ssh/utils.go @@ -80,3 +80,112 @@ func SSHOptions() map[string]string { query_ssh_for_options_once.Do(get_ssh_options) return ssh_options } + +func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) { + other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32) + for k, v := range SSHOptions() { + k = "-" + k + if v == "" { + boolean_ssh_args.Add(k) + } else { + other_ssh_args.Add(k) + } + } + return +} + +func is_extra_arg(arg string, extra_args []string) string { + for _, x := range extra_args { + if arg == x || strings.HasPrefix(arg, x+"=") { + return x + } + } + return "" +} + +type ErrInvalidSSHArgs struct { + Msg string +} + +func (self *ErrInvalidSSHArgs) Error() string { + return self.Msg +} + +func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, server_args []string, passthrough bool, found_extra_args []string, err error) { + if extra_args == nil { + extra_args = []string{} + } + if len(args) == 0 { + passthrough = true + return + } + passthrough_args := map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true} + boolean_ssh_args, other_ssh_args := GetSSHCLI() + ssh_args, server_args, found_extra_args = make([]string, 0, 16), make([]string, 0, 16), make([]string, 0, 16) + expecting_option_val := false + stop_option_processing := false + expecting_extra_val := "" + for _, argument := range args { + if len(server_args) > 1 || stop_option_processing { + server_args = append(server_args, argument) + continue + } + if strings.HasPrefix(argument, "-") && !expecting_option_val { + if argument == "--" { + stop_option_processing = true + continue + } + if len(extra_args) > 0 { + matching_ex := is_extra_arg(argument, extra_args) + if matching_ex != "" { + _, exval, found := strings.Cut(argument, "=") + if found { + found_extra_args = append(found_extra_args, matching_ex, exval) + } else { + expecting_extra_val = matching_ex + expecting_option_val = true + } + continue + } + } + // could be a multi-character option + all_args := []rune(argument[1:]) + for i, ch := range all_args { + arg := "-" + string(ch) + if passthrough_args[arg] { + passthrough = true + } + if boolean_ssh_args.Has(arg) { + ssh_args = append(ssh_args, arg) + continue + } + if other_ssh_args.Has(arg) { + ssh_args = append(ssh_args, arg) + if i+1 < len(all_args) { + ssh_args = append(ssh_args, string(all_args[i+1:])) + } else { + expecting_option_val = true + } + break + } + err = &ErrInvalidSSHArgs{Msg: "unknown option -- " + arg[1:]} + return + } + continue + } + if expecting_option_val { + if expecting_extra_val != "" { + found_extra_args = append(found_extra_args, expecting_extra_val, argument) + } else { + ssh_args = append(ssh_args, argument) + } + expecting_option_val = false + continue + } + server_args = append(server_args, argument) + } + if len(server_args) == 0 { + err = &ErrInvalidSSHArgs{Msg: "No server to connect to specified"} + } + return +} diff --git a/tools/cmd/ssh/utils_test.go b/tools/cmd/ssh/utils_test.go index d7425a7ec..1b6bb3653 100644 --- a/tools/cmd/ssh/utils_test.go +++ b/tools/cmd/ssh/utils_test.go @@ -5,6 +5,10 @@ package ssh import ( "fmt" "testing" + + "kitty/tools/utils/shlex" + + "github.com/google/go-cmp/cmp" ) var _ = fmt.Print @@ -15,3 +19,36 @@ func TestGetSSHOptions(t *testing.T) { t.Fatalf("Unexpected set of SSH options: %#v", m) } } + +func TestParseSSHArgs(t *testing.T) { + split := func(x string) []string { + ans, err := shlex.Split(x) + if err != nil { + t.Fatal(err) + } + return ans + } + + p := func(args, expected_ssh_args, expected_server_args, expected_extra_args string, expected_passthrough bool) { + ssh_args, server_args, passthrough, extra_args, err := ParseSSHArgs(split(args), "--kitten") + if err != nil { + t.Fatal(err) + } + check := func(a, b any) { + diff := cmp.Diff(a, b) + if diff != "" { + t.Fatalf("Unexpected value for args: %s\n%s", args, diff) + } + } + check(split(expected_ssh_args), ssh_args) + check(split(expected_server_args), server_args) + check(split(expected_extra_args), extra_args) + check(expected_passthrough, passthrough) + } + p(`localhost`, ``, `localhost`, ``, false) + p(`-- localhost`, ``, `localhost`, ``, false) + p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false) + p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false) + p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true) + +}