diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index 8697d7ac0..019c6c396 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -3,8 +3,10 @@ package ssh import ( + "encoding/json" "errors" "fmt" + "io/fs" "net/url" "os" "os/user" @@ -12,6 +14,7 @@ import ( "kitty/tools/cli" "kitty/tools/tty" + "kitty/tools/utils/shm" "golang.org/x/exp/maps" "golang.org/x/sys/unix" @@ -43,11 +46,35 @@ func get_destination(hostname string) (username, hostname_for_match string) { return } -func add_cloned_env(val string) map[string]string { - return nil // TODO: Implement me +func read_data_from_shared_memory(shm_name string) ([]byte, error) { + data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error { + s, err := f.Stat() + if err != nil { + return fmt.Errorf("Failed to stat SHM file with error: %w", err) + } + if stat, ok := s.Sys().(unix.Stat_t); ok { + if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) { + return fmt.Errorf("Incorrect owner on SHM file") + } + } + if s.Mode().Perm() != 0o600 { + return fmt.Errorf("Incorrect permissions on SHM file") + } + return nil + }) + return data, err } -func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string) { +func add_cloned_env(val string) (ans map[string]string, err error) { + data, err := read_data_from_shared_memory(val) + if err != nil { + return nil, err + } + err = json.Unmarshal(data, &ans) + return ans, err +} + +func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string, ferr error) { literal_env = make(map[string]string) overrides = make([]string, 0, 4) for i, a := range found_extra_args { @@ -56,8 +83,12 @@ func parse_kitten_args(found_extra_args []string, username, hostname_for_match s } if key, val, found := strings.Cut(a, "="); found { if key == "clone_env" { - le := add_cloned_env(val) - if le != nil { + le, err := add_cloned_env(val) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return nil, nil, ferr + } + } else if le != nil { literal_env = le } } else if key != "hostname" { @@ -80,7 +111,10 @@ func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err erro insertion_point := len(cmd) cmd = append(cmd, "--", hostname) uname, hostname_for_match := get_destination(hostname) - overrides, literal_env := parse_kitten_args(found_extra_args, uname, hostname_for_match) + overrides, literal_env, err := parse_kitten_args(found_extra_args, uname, hostname_for_match) + if err != nil { + return 1, err + } if insertion_point > 0 && overrides != nil && literal_env != nil { } // TODO: Implement me diff --git a/tools/cmd/ssh/main_test.go b/tools/cmd/ssh/main_test.go new file mode 100644 index 000000000..7d02100f7 --- /dev/null +++ b/tools/cmd/ssh/main_test.go @@ -0,0 +1,39 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "kitty/tools/utils/shm" + "testing" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestCloneEnv(t *testing.T) { + env := map[string]string{"a": "1", "b": "2"} + data, err := json.Marshal(env) + if err != nil { + t.Fatal(err) + } + mmap, err := shm.CreateTemp("", 128) + if err != nil { + t.Fatal(err) + } + defer mmap.Unlink() + copy(mmap.Slice()[4:], data) + binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data))) + mmap.Close() + x, err := add_cloned_env(mmap.Name()) + if err != nil { + t.Fatal(err) + } + diff := cmp.Diff(env, x) + if diff != "" { + t.Fatalf("Failed to deserialize env\n%s", diff) + } +}