// License: GPLv3 Copyright: 2023, Kovid Goyal, package ssh import ( "archive/tar" "encoding/json" "errors" "fmt" "io/fs" "os" "path/filepath" "strings" "time" "kitty/tools/config" "kitty/tools/utils" "kitty/tools/utils/paths" "kitty/tools/utils/shlex" "github.com/bmatcuk/doublestar" "golang.org/x/sys/unix" ) var _ = fmt.Print type EnvInstruction struct { key, val string delete_on_remote, copy_from_local, literal_quote bool } func quote_for_sh(val string, literal_quote bool) string { if literal_quote { return utils.QuoteStringForSH(val) } // See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html b := strings.Builder{} b.Grow(len(val) + 16) b.WriteRune('"') runes := []rune(val) for i, ch := range runes { if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') { // special chars are escaped // $( is escaped to prevent execution b.WriteRune('\\') } b.WriteRune(ch) } b.WriteRune('"') return b.String() } func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string { var unset func() string var export func(string) string if for_python { dumps := func(x ...any) string { ans, _ := json.Marshal(x) return utils.UnsafeBytesToString(ans) } export = func(val string) string { if val == "" { return fmt.Sprintf("export %s", dumps(self.key)) } return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote)) } unset = func() string { return fmt.Sprintf("unset %s", dumps(self.key)) } } else { kq := utils.QuoteStringForSH(self.key) unset = func() string { return fmt.Sprintf("unset %s", kq) } export = func(val string) string { return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote)) } } if self.delete_on_remote { return unset() } if self.copy_from_local { val, found := get_local_env(self.key) if !found { return "" } return export(val) } return export(self.val) } func final_env_instructions(for_python bool, get_local_env func(string) (string, bool), env ...*EnvInstruction) string { seen := make(map[string]int, len(env)) ans := make([]string, 0, len(env)) for _, ei := range env { q := ei.Serialize(for_python, get_local_env) if q != "" { if pos, found := seen[ei.key]; found { ans[pos] = q } else { seen[ei.key] = len(ans) ans = append(ans, q) } } } return strings.Join(ans, "\n") } type CopyInstruction struct { local_path, arcname string exclude_patterns []string } func ParseEnvInstruction(spec string) (ans []*EnvInstruction, err error) { const COPY_FROM_LOCAL string = "_kitty_copy_env_var_" ei := &EnvInstruction{} found := false ei.key, ei.val, found = strings.Cut(spec, "=") ei.key = strings.TrimSpace(ei.key) if found { ei.val = strings.TrimSpace(ei.val) if ei.val == COPY_FROM_LOCAL { ei.val = "" ei.copy_from_local = true } } else { ei.delete_on_remote = true } if ei.key == "" { err = fmt.Errorf("The env directive must not be empty") } ans = []*EnvInstruction{ei} return } var paths_ctx *paths.Ctx func resolve_file_spec(spec string, is_glob bool) ([]string, error) { if paths_ctx == nil { paths_ctx = &paths.Ctx{} } ans := os.ExpandEnv(paths_ctx.ExpandHome(spec)) if !filepath.IsAbs(ans) { ans = paths_ctx.AbspathFromHome(ans) } if is_glob { files, err := doublestar.Glob(ans) if err != nil { return nil, fmt.Errorf("%s is not a valid glob pattern with error: %w", spec, err) } if len(files) == 0 { return nil, fmt.Errorf("%s matches no files", spec) } return files, nil } err := unix.Access(ans, unix.R_OK) if err != nil { if errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("%s does not exist", spec) } return nil, fmt.Errorf("Cannot read from: %s with error: %w", spec, err) } return []string{ans}, nil } func get_arcname(loc, dest, home string) (arcname string) { if dest != "" { arcname = dest } else { arcname = filepath.Clean(loc) if filepath.HasPrefix(arcname, home) { ra, err := filepath.Rel(home, arcname) if err == nil { arcname = ra } } } prefix := "home/" if strings.HasPrefix(arcname, "/") { prefix = "root" } return prefix + arcname } func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) { args, err := shlex.Split("copy " + spec) if err != nil { return nil, err } opts, args, err := parse_copy_args(args) if err != nil { return nil, err } locations := make([]string, 0, len(args)) for _, arg := range args { locs, err := resolve_file_spec(arg, opts.Glob) if err != nil { return nil, err } locations = append(locations, locs...) } if len(locations) == 0 { return nil, fmt.Errorf("No files to copy specified") } if len(locations) > 1 && opts.Dest != "" { return nil, fmt.Errorf("Specifying a remote location with more than one file is not supported") } home := paths_ctx.HomePath() ans = make([]*CopyInstruction, 0, len(locations)) for _, loc := range locations { ci := CopyInstruction{local_path: loc, exclude_patterns: opts.Exclude} if opts.SymlinkStrategy != "preserve" { ci.local_path, err = filepath.EvalSymlinks(loc) if err != nil { return nil, fmt.Errorf("Failed to resolve symlinks in %#v with error: %w", loc, err) } } if opts.SymlinkStrategy == "resolve" { ci.arcname = get_arcname(ci.local_path, opts.Dest, home) } else { ci.arcname = get_arcname(loc, opts.Dest, home) } ans = append(ans, &ci) } return } type file_unique_id struct { dev, inode uint64 } func excluded(pattern, path string) bool { if strings.HasPrefix(pattern, "*") && !strings.HasPrefix(pattern, "**") { path = filepath.Base(path) } if matched, err := doublestar.PathMatch(pattern, path); matched && err == nil { return true } return false } func get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string, local_path, arcname string, exclude_patterns []string, recurse bool) error { s, err := os.Lstat(local_path) if err != nil { return err } u, ok := s.Sys().(unix.Stat_t) cb := func(h *tar.Header, data []byte) error { h.Name = arcname h.Size = int64(len(data)) h.Mode = int64(s.Mode()) h.ModTime = s.ModTime() h.Uid, h.Gid = 0, 0 h.Uname, h.Gname = "", "" h.Format = tar.FormatPAX if ok { h.AccessTime = time.Unix(0, u.Atim.Nano()) h.ChangeTime = time.Unix(0, u.Ctim.Nano()) } return callback(h, data) } // we only copy regular files, directories and symlinks switch s.Mode().Type() { case fs.ModeSymlink: target, err := os.Readlink(local_path) if err != nil { return err } err = cb(&tar.Header{ Typeflag: tar.TypeSymlink, Linkname: target, }, nil) if err != nil { return err } case fs.ModeDir: err = cb(&tar.Header{Typeflag: tar.TypeDir}, nil) if err != nil { return err } if recurse { local_path = filepath.Clean(local_path) return filepath.WalkDir(local_path, func(path string, d fs.DirEntry, werr error) error { clean_path := filepath.Clean(path) if clean_path == local_path { return nil } for _, pat := range exclude_patterns { if excluded(pat, clean_path) { return nil } } if werr == nil { rel, err := filepath.Rel(local_path, path) if err != nil { aname := filepath.Join(arcname, rel) return get_file_data(callback, seen, path, aname, nil, false) } } return nil }) } case 0: // Regular file fid := file_unique_id{dev: u.Dev, inode: u.Ino} if prev, ok := seen[fid]; ok { // Hard link err = cb(&tar.Header{Typeflag: tar.TypeLink, Linkname: prev}, nil) if err != nil { return err } } seen[fid] = arcname data, err := os.ReadFile(local_path) if err != nil { return err } err = cb(&tar.Header{Typeflag: tar.TypeReg}, data) if err != nil { return err } } return nil } func (ci *CopyInstruction) get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string) (err error) { ep := ci.exclude_patterns for _, folder_name := range []string{"__pycache__", ".DS_Store"} { ep = append(ep, "*/"+folder_name, "*/"+folder_name+"/*") } return get_file_data(callback, seen, ci.local_path, ci.arcname, ep, true) } type ConfigSet struct { all_configs []*Config } func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config { matcher := func(q *Config) bool { for _, pat := range strings.Split(q.Hostname, " ") { upat := "*" if strings.Contains(pat, "@") { upat, pat, _ = strings.Cut(pat, "@") } var host_matched, user_matched bool if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil { host_matched = true } if matched, err := filepath.Match(upat, username_to_match); matched && err == nil { user_matched = true } if host_matched && user_matched { return true } } return false } for _, c := range utils.Reversed(cs.all_configs) { if matcher(c) { return c } } return cs.all_configs[0] } func (self *ConfigSet) line_handler(key, val string) error { c := self.all_configs[len(self.all_configs)-1] if key == "hostname" { c = NewConfig() self.all_configs = append(self.all_configs, c) } return c.Parse(key, val) } func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, []config.ConfigLine, error) { ans := &ConfigSet{all_configs: []*Config{NewConfig()}} p := config.ConfigParser{LineHandler: ans.line_handler} if len(paths) == 0 { paths = []string{filepath.Join(utils.ConfigDir(), "ssh.conf")} } paths = utils.Filter(paths, func(x string) bool { return x != "" }) err := p.ParseFiles(paths...) if err != nil && !errors.Is(err, fs.ErrNotExist) { return nil, nil, err } if len(overrides) > 0 { err = p.ParseOverrides(overrides...) if err != nil { return nil, nil, err } } return config_for_hostname(hostname_to_match, username_to_match, ans), p.BadLines(), nil }