From d98504e1a66186a9628f960cf63f6af44a835c3b Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 19 Feb 2023 14:31:03 +0530 Subject: [PATCH] Finish porting SSH config file parsing --- kitty/conf/generate.py | 3 +- kitty_tests/ssh.py | 20 ----- tools/cmd/ssh/config.go | 145 ++++++++++++++++++++++++++++++++++- tools/cmd/ssh/config_test.go | 76 ++++++++++++++++++ 4 files changed, 221 insertions(+), 23 deletions(-) create mode 100644 tools/cmd/ssh/config_test.go diff --git a/kitty/conf/generate.py b/kitty/conf/generate.py index 46a5a82e6..044b8fd09 100644 --- a/kitty/conf/generate.py +++ b/kitty/conf/generate.py @@ -460,7 +460,8 @@ def go_type_data(parser_func: ParserFuncType, ctype: str) -> Tuple[str, str]: def gen_go_code(defn: Definition) -> str: - lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"', 'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi'] + lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"', + 'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi'] a = lines.append choices = {} go_types = {} diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index dc6458aee..7a006ddbb 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -56,26 +56,6 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two'))) self.assertTrue(runtime_dir()) - def test_ssh_config_parsing(self): - def parse(conf, hostname='unmatched_host', username=''): - return load_config(overrides=conf.splitlines(), hostname=hostname, username=username) - - self.ae(parse('').env, {}) - self.ae(parse('env a=b').env, {'a': 'b'}) - conf = 'env a=b\nhostname 2\nenv a=c\nenv b=b' - self.ae(parse(conf).env, {'a': 'b'}) - self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'}) - self.ae(parse('env a=').env, {'a': ''}) - self.ae(parse('env a').env, {'a': '_delete_this_env_var_'}) - conf = 'env a=b\nhostname test@2\nenv a=c\nenv b=b' - self.ae(parse(conf).env, {'a': 'b'}) - self.ae(parse(conf, '2').env, {'a': 'b'}) - self.ae(parse(conf, '2', 'test').env, {'a': 'c', 'b': 'b'}) - conf = 'env a=b\nhostname 1 2\nenv a=c\nenv b=b' - self.ae(parse(conf).env, {'a': 'b'}) - self.ae(parse(conf, '1').env, {'a': 'c', 'b': 'b'}) - self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'}) - def test_ssh_bootstrap_sh_cmd_limit(self): # dropbear has a 9000 bytes maximum command length limit sh_script, _, _ = bootstrap_script(SSHOptions({'interpreter': 'sh'}), script_type='sh', remote_args=[], request_id='123-123') diff --git a/tools/cmd/ssh/config.go b/tools/cmd/ssh/config.go index 61df05778..6e88a5d2e 100644 --- a/tools/cmd/ssh/config.go +++ b/tools/cmd/ssh/config.go @@ -3,12 +3,16 @@ package ssh import ( + "encoding/json" "errors" "fmt" + "io/fs" "os" "path/filepath" "strings" + "kitty/tools/config" + "kitty/tools/utils" "kitty/tools/utils/paths" "kitty/tools/utils/shlex" @@ -18,8 +22,85 @@ import ( var _ = fmt.Print type EnvInstruction struct { - key, val string - delete_on_remote, copy_from_local bool + key, val string + delete_on_remote, copy_from_local, literal_quote bool +} + +func quote_for_sh(val string, literal_quote bool) string { + if literal_quote { + return utils.QuoteStringForSH(val) + } + // See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html + b := strings.Builder{} + b.Grow(len(val) + 16) + b.WriteRune('"') + runes := []rune(val) + for i, ch := range runes { + if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') { + // special chars are escaped + // $( is escaped to prevent execution + b.WriteRune('\\') + } + b.WriteRune(ch) + } + b.WriteRune('"') + return b.String() +} + +func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string { + var unset func() string + var export func(string) string + if for_python { + dumps := func(x ...any) string { + ans, _ := json.Marshal(x) + return utils.UnsafeBytesToString(ans) + } + export = func(val string) string { + if val == "" { + return fmt.Sprintf("export %s", dumps(self.key)) + } + return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote)) + } + unset = func() string { + return fmt.Sprintf("unset %s", dumps(self.key)) + } + } else { + kq := utils.QuoteStringForSH(self.key) + unset = func() string { + return fmt.Sprintf("unset %s", kq) + } + export = func(val string) string { + return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote)) + } + } + if self.delete_on_remote { + return unset() + } + if self.copy_from_local { + val, found := get_local_env(self.key) + if !found { + return "" + } + return export(val) + } + return export(self.val) +} + +func (self *Config) final_env_instructions(for_python bool, get_local_env func(string) (string, bool)) string { + seen := make(map[string]int, len(self.Env)) + ans := make([]string, 0, len(self.Env)) + for _, ei := range self.Env { + q := ei.Serialize(for_python, get_local_env) + if q != "" { + if pos, found := seen[ei.key]; found { + ans[pos] = q + } else { + seen[ei.key] = len(ans) + ans = append(ans, q) + } + } + } + return strings.Join(ans, "\n") } type CopyInstruction struct { @@ -140,3 +221,63 @@ func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) { } return } + +type ConfigSet struct { + all_configs []*Config +} + +func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config { + matcher := func(q *Config) bool { + for _, pat := range strings.Split(q.Hostname, " ") { + upat := "*" + if strings.Contains(pat, "@") { + upat, pat, _ = strings.Cut(pat, "@") + } + var host_matched, user_matched bool + if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil { + host_matched = true + } + if matched, err := filepath.Match(upat, username_to_match); matched && err == nil { + user_matched = true + } + if host_matched && user_matched { + return true + } + } + return false + } + for _, c := range utils.Reversed(cs.all_configs) { + if matcher(c) { + return c + } + } + return cs.all_configs[0] +} + +func (self *ConfigSet) line_handler(key, val string) error { + c := self.all_configs[len(self.all_configs)-1] + if key == "hostname" { + c = NewConfig() + self.all_configs = append(self.all_configs, c) + } + return c.Parse(key, val) +} + +func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, error) { + ans := &ConfigSet{all_configs: []*Config{NewConfig()}} + p := config.ConfigParser{LineHandler: ans.line_handler} + if len(paths) == 0 { + paths = []string{filepath.Join(utils.ConfigDir(), "ssh.conf")} + } + err := p.ParseFiles(paths...) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, err + } + if len(overrides) > 0 { + err = p.ParseOverrides(overrides...) + if err != nil { + return nil, err + } + } + return config_for_hostname(hostname_to_match, username_to_match, ans), nil +} diff --git a/tools/cmd/ssh/config_test.go b/tools/cmd/ssh/config_test.go new file mode 100644 index 000000000..7a3f15386 --- /dev/null +++ b/tools/cmd/ssh/config_test.go @@ -0,0 +1,76 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "fmt" + "kitty/tools/utils" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestSSHConfigParsing(t *testing.T) { + tdir := t.TempDir() + hostname := "unmatched" + username := "" + conf := "" + for_python := false + rt := func(expected_env ...string) { + cf := filepath.Join(tdir, "ssh.conf") + os.WriteFile(cf, []byte(conf), 0o600) + c, err := load_config(hostname, username, nil, cf) + if err != nil { + t.Fatal(err) + } + actual := c.final_env_instructions(for_python, func(key string) (string, bool) { + if key == "LOCAL_ENV" { + return "LOCAL_VAL", true + } + return "", false + }) + if expected_env == nil { + expected_env = []string{} + } + diff := cmp.Diff(expected_env, utils.Splitlines(actual)) + if diff != "" { + t.Fatalf("Unexpected env for\nhostname: %#v\nusername: %#v\nconf: %s\n%s", hostname, username, conf, diff) + } + } + rt() + conf = "env a=b" + rt(`export 'a'="b"`) + conf = "env a=b\nhostname 2\nenv a=c\nenv b=b" + rt(`export 'a'="b"`) + hostname = "2" + rt(`export 'a'="c"`, `export 'b'="b"`) + conf = "env a=" + rt(`export 'a'=""`) + conf = "env a" + rt(`unset 'a'`) + conf = "env a=b\nhostname test@2\nenv a=c\nenv b=b" + hostname = "unmatched" + rt(`export 'a'="b"`) + hostname = "2" + rt(`export 'a'="b"`) + username = "test" + rt(`export 'a'="c"`, `export 'b'="b"`) + conf = "env a=b\nhostname 1 2\nenv a=c\nenv b=b" + username = "" + hostname = "unmatched" + rt(`export 'a'="b"`) + hostname = "1" + rt(`export 'a'="c"`, `export 'b'="b"`) + hostname = "2" + rt(`export 'a'="c"`, `export 'b'="b"`) + for_python = true + rt(`export ["a","c",false]`, `export ["b","b",false]`) + conf = "env a=" + rt(`export ["a"]`) + conf = "env a" + rt(`unset ["a"]`) +}