Finish porting SSH config file parsing

This commit is contained in:
Kovid Goyal 2023-02-19 14:31:03 +05:30
parent 07f4adbab5
commit d98504e1a6
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 221 additions and 23 deletions

View File

@ -460,7 +460,8 @@ def go_type_data(parser_func: ParserFuncType, ctype: str) -> Tuple[str, str]:
def gen_go_code(defn: Definition) -> str: def gen_go_code(defn: Definition) -> str:
lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"', 'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi'] lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"',
'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi']
a = lines.append a = lines.append
choices = {} choices = {}
go_types = {} go_types = {}

View File

@ -56,26 +56,6 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two'))) t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two')))
self.assertTrue(runtime_dir()) self.assertTrue(runtime_dir())
def test_ssh_config_parsing(self):
def parse(conf, hostname='unmatched_host', username=''):
return load_config(overrides=conf.splitlines(), hostname=hostname, username=username)
self.ae(parse('').env, {})
self.ae(parse('env a=b').env, {'a': 'b'})
conf = 'env a=b\nhostname 2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'})
self.ae(parse('env a=').env, {'a': ''})
self.ae(parse('env a').env, {'a': '_delete_this_env_var_'})
conf = 'env a=b\nhostname test@2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '2').env, {'a': 'b'})
self.ae(parse(conf, '2', 'test').env, {'a': 'c', 'b': 'b'})
conf = 'env a=b\nhostname 1 2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '1').env, {'a': 'c', 'b': 'b'})
self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'})
def test_ssh_bootstrap_sh_cmd_limit(self): def test_ssh_bootstrap_sh_cmd_limit(self):
# dropbear has a 9000 bytes maximum command length limit # dropbear has a 9000 bytes maximum command length limit
sh_script, _, _ = bootstrap_script(SSHOptions({'interpreter': 'sh'}), script_type='sh', remote_args=[], request_id='123-123') sh_script, _, _ = bootstrap_script(SSHOptions({'interpreter': 'sh'}), script_type='sh', remote_args=[], request_id='123-123')

View File

@ -3,12 +3,16 @@
package ssh package ssh
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/fs"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"kitty/tools/config"
"kitty/tools/utils"
"kitty/tools/utils/paths" "kitty/tools/utils/paths"
"kitty/tools/utils/shlex" "kitty/tools/utils/shlex"
@ -19,7 +23,84 @@ var _ = fmt.Print
type EnvInstruction struct { type EnvInstruction struct {
key, val string key, val string
delete_on_remote, copy_from_local bool delete_on_remote, copy_from_local, literal_quote bool
}
func quote_for_sh(val string, literal_quote bool) string {
if literal_quote {
return utils.QuoteStringForSH(val)
}
// See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
b := strings.Builder{}
b.Grow(len(val) + 16)
b.WriteRune('"')
runes := []rune(val)
for i, ch := range runes {
if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') {
// special chars are escaped
// $( is escaped to prevent execution
b.WriteRune('\\')
}
b.WriteRune(ch)
}
b.WriteRune('"')
return b.String()
}
func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string {
var unset func() string
var export func(string) string
if for_python {
dumps := func(x ...any) string {
ans, _ := json.Marshal(x)
return utils.UnsafeBytesToString(ans)
}
export = func(val string) string {
if val == "" {
return fmt.Sprintf("export %s", dumps(self.key))
}
return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote))
}
unset = func() string {
return fmt.Sprintf("unset %s", dumps(self.key))
}
} else {
kq := utils.QuoteStringForSH(self.key)
unset = func() string {
return fmt.Sprintf("unset %s", kq)
}
export = func(val string) string {
return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote))
}
}
if self.delete_on_remote {
return unset()
}
if self.copy_from_local {
val, found := get_local_env(self.key)
if !found {
return ""
}
return export(val)
}
return export(self.val)
}
func (self *Config) final_env_instructions(for_python bool, get_local_env func(string) (string, bool)) string {
seen := make(map[string]int, len(self.Env))
ans := make([]string, 0, len(self.Env))
for _, ei := range self.Env {
q := ei.Serialize(for_python, get_local_env)
if q != "" {
if pos, found := seen[ei.key]; found {
ans[pos] = q
} else {
seen[ei.key] = len(ans)
ans = append(ans, q)
}
}
}
return strings.Join(ans, "\n")
} }
type CopyInstruction struct { type CopyInstruction struct {
@ -140,3 +221,63 @@ func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) {
} }
return return
} }
type ConfigSet struct {
all_configs []*Config
}
func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config {
matcher := func(q *Config) bool {
for _, pat := range strings.Split(q.Hostname, " ") {
upat := "*"
if strings.Contains(pat, "@") {
upat, pat, _ = strings.Cut(pat, "@")
}
var host_matched, user_matched bool
if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil {
host_matched = true
}
if matched, err := filepath.Match(upat, username_to_match); matched && err == nil {
user_matched = true
}
if host_matched && user_matched {
return true
}
}
return false
}
for _, c := range utils.Reversed(cs.all_configs) {
if matcher(c) {
return c
}
}
return cs.all_configs[0]
}
func (self *ConfigSet) line_handler(key, val string) error {
c := self.all_configs[len(self.all_configs)-1]
if key == "hostname" {
c = NewConfig()
self.all_configs = append(self.all_configs, c)
}
return c.Parse(key, val)
}
func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, error) {
ans := &ConfigSet{all_configs: []*Config{NewConfig()}}
p := config.ConfigParser{LineHandler: ans.line_handler}
if len(paths) == 0 {
paths = []string{filepath.Join(utils.ConfigDir(), "ssh.conf")}
}
err := p.ParseFiles(paths...)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, err
}
if len(overrides) > 0 {
err = p.ParseOverrides(overrides...)
if err != nil {
return nil, err
}
}
return config_for_hostname(hostname_to_match, username_to_match, ans), nil
}

View File

@ -0,0 +1,76 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"kitty/tools/utils"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestSSHConfigParsing(t *testing.T) {
tdir := t.TempDir()
hostname := "unmatched"
username := ""
conf := ""
for_python := false
rt := func(expected_env ...string) {
cf := filepath.Join(tdir, "ssh.conf")
os.WriteFile(cf, []byte(conf), 0o600)
c, err := load_config(hostname, username, nil, cf)
if err != nil {
t.Fatal(err)
}
actual := c.final_env_instructions(for_python, func(key string) (string, bool) {
if key == "LOCAL_ENV" {
return "LOCAL_VAL", true
}
return "", false
})
if expected_env == nil {
expected_env = []string{}
}
diff := cmp.Diff(expected_env, utils.Splitlines(actual))
if diff != "" {
t.Fatalf("Unexpected env for\nhostname: %#v\nusername: %#v\nconf: %s\n%s", hostname, username, conf, diff)
}
}
rt()
conf = "env a=b"
rt(`export 'a'="b"`)
conf = "env a=b\nhostname 2\nenv a=c\nenv b=b"
rt(`export 'a'="b"`)
hostname = "2"
rt(`export 'a'="c"`, `export 'b'="b"`)
conf = "env a="
rt(`export 'a'=""`)
conf = "env a"
rt(`unset 'a'`)
conf = "env a=b\nhostname test@2\nenv a=c\nenv b=b"
hostname = "unmatched"
rt(`export 'a'="b"`)
hostname = "2"
rt(`export 'a'="b"`)
username = "test"
rt(`export 'a'="c"`, `export 'b'="b"`)
conf = "env a=b\nhostname 1 2\nenv a=c\nenv b=b"
username = ""
hostname = "unmatched"
rt(`export 'a'="b"`)
hostname = "1"
rt(`export 'a'="c"`, `export 'b'="b"`)
hostname = "2"
rt(`export 'a'="c"`, `export 'b'="b"`)
for_python = true
rt(`export ["a","c",false]`, `export ["b","b",false]`)
conf = "env a="
rt(`export ["a"]`)
conf = "env a"
rt(`unset ["a"]`)
}