diff --git a/kitty_tests/shm.py b/kitty_tests/shm.py new file mode 100644 index 000000000..b2cc70406 --- /dev/null +++ b/kitty_tests/shm.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# License: GPLv3 Copyright: 2023, Kovid Goyal + + +import os +import subprocess + +from kitty.constants import kitten_exe +from kitty.fast_data_types import shm_unlink +from kitty.shm import SharedMemory + +from . import BaseTest + + +class SHMTest(BaseTest): + + def test_shm_with_kitten(self): + data = os.urandom(333) + with SharedMemory(size=363) as shm: + shm.write_data_with_size(data) + cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'read', shm.name], stdout=subprocess.PIPE) + self.assertEqual(cp.returncode, 0) + self.assertEqual(cp.stdout, data) + self.assertRaises(FileNotFoundError, shm_unlink, shm.name) + cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'write'], input=data, stdout=subprocess.PIPE) + self.assertEqual(cp.returncode, 0) + name = cp.stdout.decode().strip() + with SharedMemory(name=name, unlink_on_exit=True) as shm: + q = shm.read_data_with_size() + self.assertEqual(data, q) diff --git a/tools/cmd/pytest/main.go b/tools/cmd/pytest/main.go new file mode 100644 index 000000000..973d68b6b --- /dev/null +++ b/tools/cmd/pytest/main.go @@ -0,0 +1,20 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package pytest + +import ( + "fmt" + + "kitty/tools/cli" + "kitty/tools/utils/shm" +) + +var _ = fmt.Print + +func EntryPoint(root *cli.Command) { + root = root.AddSubCommand(&cli.Command{ + Name: "__pytest__", + Hidden: true, + }) + shm.TestEntryPoint(root) +} diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index e18059b3b..8697d7ac0 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -5,10 +5,13 @@ package ssh import ( "errors" "fmt" + "net/url" "os" + "os/user" "strings" "kitty/tools/cli" + "kitty/tools/tty" "golang.org/x/exp/maps" "golang.org/x/sys/unix" @@ -16,6 +19,74 @@ import ( var _ = fmt.Print +func get_destination(hostname string) (username, hostname_for_match string) { + u, err := user.Current() + if err == nil { + username = u.Username + } + hostname_for_match = hostname + if strings.HasPrefix(hostname, "ssh://") { + p, err := url.Parse(hostname) + if err == nil { + hostname_for_match = p.Hostname() + if p.User.Username() != "" { + username = p.User.Username() + } + } + } else if strings.Contains(hostname, "@") && hostname[0] != '@' { + username, hostname_for_match, _ = strings.Cut(hostname, "@") + } + if strings.Contains(hostname, "@") && hostname[0] != '@' { + _, hostname_for_match, _ = strings.Cut(hostname_for_match, "@") + } + hostname_for_match, _, _ = strings.Cut(hostname_for_match, ":") + return +} + +func add_cloned_env(val string) map[string]string { + return nil // TODO: Implement me +} + +func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string) { + literal_env = make(map[string]string) + overrides = make([]string, 0, 4) + for i, a := range found_extra_args { + if i%2 == 0 { + continue + } + if key, val, found := strings.Cut(a, "="); found { + if key == "clone_env" { + le := add_cloned_env(val) + if le != nil { + literal_env = le + } + } else if key != "hostname" { + overrides = append(overrides, key+" "+val) + } + } + } + if len(overrides) > 0 { + overrides = append([]string{"hostname " + username + "@" + hostname_for_match}, overrides...) + } + return +} + +func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { + cmd := append([]string{ssh_exe()}, ssh_args...) + hostname, remote_args := server_args[0], server_args[1:] + if len(remote_args) == 0 { + cmd = append(cmd, "-t") + } + insertion_point := len(cmd) + cmd = append(cmd, "--", hostname) + uname, hostname_for_match := get_destination(hostname) + overrides, literal_env := parse_kitten_args(found_extra_args, uname, hostname_for_match) + if insertion_point > 0 && overrides != nil && literal_env != nil { + } + // TODO: Implement me + return +} + func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { if len(args) > 0 { switch args[0] { @@ -44,10 +115,13 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { } return 1, unix.Exec(ssh_exe(), append([]string{"ssh"}, args...), os.Environ()) } - if false { - return len(ssh_args) + len(server_args), nil + if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" { + return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window") } - return + if !tty.IsTerminal(os.Stdin.Fd()) { + return 1, fmt.Errorf("The SSH kitten is meant for interactive use only, STDIN must be a terminal") + } + return run_ssh(ssh_args, server_args, found_extra_args) } func EntryPoint(parent *cli.Command) { diff --git a/tools/cmd/tool/main.go b/tools/cmd/tool/main.go index 51798ccb9..a83ff5b56 100644 --- a/tools/cmd/tool/main.go +++ b/tools/cmd/tool/main.go @@ -10,6 +10,7 @@ import ( "kitty/tools/cmd/clipboard" "kitty/tools/cmd/edit_in_kitty" "kitty/tools/cmd/icat" + "kitty/tools/cmd/pytest" "kitty/tools/cmd/ssh" "kitty/tools/cmd/unicode_input" "kitty/tools/cmd/update_self" @@ -35,6 +36,8 @@ func KittyToolEntryPoints(root *cli.Command) { ssh.EntryPoint(root) // unicode_input unicode_input.EntryPoint(root) + // __pytest__ + pytest.EntryPoint(root) // __hold_till_enter__ root.AddSubCommand(&cli.Command{ Name: "__hold_till_enter__", diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index 79a04a716..90a7de63f 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -5,13 +5,17 @@ package shm import ( "crypto/rand" "encoding/base32" + "encoding/binary" "errors" "fmt" + "io" not_rand "math/rand" "os" "strconv" "strings" + "kitty/tools/cli" + "golang.org/x/sys/unix" ) @@ -109,3 +113,69 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) { } return } + +func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) { + p := buf + for len(p) > 0 { + n, err := f.Read(p) + p = p[n:] + if err != nil { + if len(p) == 0 && errors.Is(err, io.EOF) { + err = nil + } + return buf[:len(buf)-len(p)], err + } + } + return buf, nil +} + +func read_with_size(f *os.File) ([]byte, error) { + szbuf := []byte{0, 0, 0, 0} + szbuf, err := read_till_buf_full(f, szbuf) + if err != nil { + return nil, err + } + size := int(binary.BigEndian.Uint32(szbuf)) + return read_till_buf_full(f, make([]byte, size)) +} + +func test_integration_with_python(args []string) (rc int, err error) { + switch args[0] { + default: + return 1, fmt.Errorf("Unknown test type: %s", args[0]) + case "read": + data, err := ReadWithSizeAndUnlink(args[1]) + if err != nil { + return 1, err + } + _, err = os.Stdout.Write(data) + if err != nil { + return 1, err + } + case "write": + data, err := io.ReadAll(os.Stdin) + if err != nil { + return 1, err + } + mmap, err := CreateTemp("shmtest-", uint64(len(data)+4)) + if err != nil { + return 1, err + } + defer mmap.Close() + binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data))) + copy(mmap.Slice()[4:], data) + fmt.Println(mmap.Name()) + } + return 0, nil +} + +func TestEntryPoint(root *cli.Command) { + root.AddSubCommand(&cli.Command{ + Name: "shm", + OnlyArgsAllowed: true, + Run: func(cmd *cli.Command, args []string) (rc int, err error) { + return test_integration_with_python(args) + }, + }) + +} diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index d5866f3b9..1b4b0c298 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -113,7 +113,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) { return file_mmap(f, size, WRITE, true, special_name) } -func Open(name string, size uint64) (MMap, error) { +func open(name string) (*os.File, error) { ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0) if err != nil { if errors.Is(err, fs.ErrNotExist) { @@ -123,5 +123,23 @@ func Open(name string, size uint64) (MMap, error) { } return nil, err } + return ans, nil +} + +func Open(name string, size uint64) (MMap, error) { + ans, err := open(name) + if err != nil { + return nil, err + } return file_mmap(ans, size, READ, false, name) } + +func ReadWithSizeAndUnlink(name string) ([]byte, error) { + f, err := open(name) + if err != nil { + return nil, err + } + defer f.Close() + defer os.Remove(f.Name()) + return read_with_size(f) +} diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index 8974af3f5..a48ab0e2d 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -151,3 +151,13 @@ func Open(name string, size uint64) (MMap, error) { } return syscall_mmap(ans, size, READ, false) } + +func ReadWithSizeAndUnlink(name string) ([]byte, error) { + f, err := shm_open(name, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + defer f.Close() + defer shm_unlink(f.Name()) + return read_with_size(f) +}