diff --git a/gen-rc-go.py b/gen-rc-go.py index 33079e680..1a9b91542 100755 --- a/gen-rc-go.py +++ b/gen-rc-go.py @@ -155,7 +155,7 @@ def build_go_code(name: str, cmd: RemoteCommand, seq: OptionSpecSeq, template: s if o.aliases: alias_map[o.long] = tuple(o.aliases) a(o.to_flag_definition()) - if o.dest == 'no_response': + if o.dest in ('no_response', 'response_timeout'): continue od.append(f'{o.go_var_name} {o.go_type}') ov.append(o.set_flag_value()) diff --git a/go.mod b/go.mod index 0815f6d29..ceb162e69 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/fatih/color v1.13.0 github.com/mattn/go-isatty v0.0.14 github.com/mattn/go-runewidth v0.0.13 + github.com/seancfoley/ipaddress-go v1.2.1 github.com/spf13/cobra v1.5.0 github.com/spf13/pflag v1.0.5 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa @@ -17,4 +18,5 @@ require ( github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/mattn/go-colorable v0.1.9 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/seancfoley/bintree v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index e07abef29..f35e817c2 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,10 @@ github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/seancfoley/bintree v1.1.0 h1:6J0rj9hLNLIcWSsfYdZ4ZHkMHokaK/PHkak8qyBO/mc= +github.com/seancfoley/bintree v1.1.0/go.mod h1:CtE6qO6/n9H3V2CAGEC0lpaYr6/OijhNaMG/dt7P70c= +github.com/seancfoley/ipaddress-go v1.2.1 h1:yEZxnyC6NQEDDPflyQm4KkWozffx1vHWsx+knKBr/n0= +github.com/seancfoley/ipaddress-go v1.2.1/go.mod h1:/UEVHyrBg1ASVap2ffdY2cq5UMYIX9f3QW3uWSVqpbo= github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU= github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/tools/cmd/at/main.go b/tools/cmd/at/main.go index 31e3f299d..4658fc71a 100644 --- a/tools/cmd/at/main.go +++ b/tools/cmd/at/main.go @@ -29,20 +29,12 @@ func add_bool_set(cmd *cobra.Command, name string, short string, usage string) * } type GlobalOptions struct { - to_address, password string - to_address_is_from_env_var bool + to_network, to_address, password string + to_address_is_from_env_var bool } var global_options GlobalOptions -func cut(a string, sep string) (string, string, bool) { - idx := strings.Index(a, sep) - if idx < 0 { - return "", "", false - } - return a[:idx], a[idx+len(sep):], true -} - func get_pubkey(encoded_key string) (encryption_version string, pubkey []byte, err error) { if encoded_key == "" { encoded_key = os.Getenv("KITTY_PUBLIC_KEY") @@ -51,7 +43,7 @@ func get_pubkey(encoded_key string) (encryption_version string, pubkey []byte, e return } } - encryption_version, encoded_key, found := cut(encoded_key, ":") + encryption_version, encoded_key, found := utils.Cut(encoded_key, ":") if !found { err = fmt.Errorf("KITTY_PUBLIC_KEY environment variable does not have a : in it") return @@ -77,23 +69,28 @@ type serializer_func func(rc *utils.RemoteControlCmd) ([]byte, error) var serializer serializer_func = simple_serializer -func create_serializer(password string, encoded_pubkey string) (ans serializer_func, err error) { +func create_serializer(password string, encoded_pubkey string, response_timeout float64) (ans serializer_func, timeout float64, err error) { + timeout = response_timeout if password != "" { encryption_version, pubkey, err := get_pubkey(encoded_pubkey) if err != nil { - return nil, err + return nil, timeout, err } ans = func(rc *utils.RemoteControlCmd) (ans []byte, err error) { ec, err := crypto.Encrypt_cmd(rc, global_options.password, pubkey, encryption_version) ans, err = json.Marshal(ec) return } + if timeout < 120 { + timeout = 120 + } + return ans, timeout, nil } - return simple_serializer, nil + return simple_serializer, timeout, nil } func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) { - serializer, err = create_serializer(global_options.password, "") + serializer, timeout, err = create_serializer(global_options.password, "", timeout) if err != nil { return } @@ -165,7 +162,14 @@ func EntryPoint(tool_root *cobra.Command) *cobra.Command { *to = os.Getenv("KITTY_LISTEN_ON") global_options.to_address_is_from_env_var = true } - global_options.to_address = *to + if *to != "" { + network, address, err := utils.ParseSocketAddress(*to) + if err != nil { + return err + } + global_options.to_network = network + global_options.to_address = address + } q, err := get_password(*password, *password_file, *password_env, use_password.Choice) global_options.password = q return err diff --git a/tools/cmd/at/main_test.go b/tools/cmd/at/main_test.go index 587e94f95..4de4d87a3 100644 --- a/tools/cmd/at/main_test.go +++ b/tools/cmd/at/main_test.go @@ -35,7 +35,7 @@ func TestCommandToJSON(t *testing.T) { } func TestRCSerialization(t *testing.T) { - serializer, err := create_serializer("", "") + serializer, _, err := create_serializer("", "", 0) if err != nil { t.Fatal(err) } @@ -62,7 +62,7 @@ func TestRCSerialization(t *testing.T) { if err != nil { t.Fatal(err) } - serializer, err = create_serializer("tpw", pubkey) + serializer, _, err = create_serializer("tpw", pubkey, 0) if err != nil { t.Fatal(err) } diff --git a/tools/cmd/at/template.go b/tools/cmd/at/template.go index d5d3fa658..409f61032 100644 --- a/tools/cmd/at/template.go +++ b/tools/cmd/at/template.go @@ -54,7 +54,12 @@ func run_CMD_NAME(cmd *cobra.Command, args []string) (err error) { if err == nil { rc.NoResponse = nrv } - err = send_rc_command(rc, WAIT_TIMEOUT) + var timeout float64 = WAIT_TIMEOUT + rt, err := cmd.Flags().GetFloat64("response-timeout") + if err == nil { + timeout = rt + } + err = send_rc_command(rc, timeout) return } diff --git a/tools/utils/sockets.go b/tools/utils/sockets.go new file mode 100644 index 000000000..245e9a2fd --- /dev/null +++ b/tools/utils/sockets.go @@ -0,0 +1,46 @@ +package utils + +import ( + "fmt" + "github.com/seancfoley/ipaddress-go/ipaddr" + "runtime" + "strings" +) + +func Cut(s string, sep string) (string, string, bool) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, "", false +} + +func ParseSocketAddress(spec string) (network string, addr string, err error) { + network, addr, found := Cut(spec, ":") + if !found { + err = fmt.Errorf("Invalid socket address: %s", spec) + return + } + if network == "unix" { + if strings.HasSuffix(addr, "@") && runtime.GOOS != "linux" { + err = fmt.Errorf("Abstract UNIX sockets are only supported on Linux. Cannot use: %s", spec) + } + return + } + + if network == "tcp" || network == "tcp6" || network == "tcp4" { + host := ipaddr.NewHostName(addr) + if host.IsAddress() { + network = "ip" + } + return + } + if network == "ip" || network == "ip6" || network == "ip4" { + host := ipaddr.NewHostName(addr) + if !host.IsAddress() { + err = fmt.Errorf("Not a valid IP address: %#v. Cannot use: %s", addr, spec) + } + return + } + err = fmt.Errorf("Unknown network type: %#v in socket address: %s", network, spec) + return +} diff --git a/tools/utils/sockets_test.go b/tools/utils/sockets_test.go new file mode 100644 index 000000000..071cde079 --- /dev/null +++ b/tools/utils/sockets_test.go @@ -0,0 +1,57 @@ +package utils + +import ( + "fmt" + "runtime" + "testing" +) + +func TestParseSocketAddress(t *testing.T) { + en := "unix" + ea := "/tmp/test" + var eerr error = nil + + test := func(spec string) { + n, a, err := ParseSocketAddress(spec) + if err != eerr { + if eerr == nil { + t.Fatalf("Parsing of %s failed with unexpected error: %s", spec, err) + } + if err == nil { + t.Fatalf("Parsing of %s did not fail, unexpectedly", spec) + } + return + } + if a != ea { + t.Fatalf("actual != expected, %s != %s, when parsing %s", a, ea, spec) + } + if n != en { + t.Fatalf("actual != expected, %s != %s, when parsing %s", n, en, spec) + } + } + + testf := func(spec string, netw string, addr string) { + eerr = nil + en = netw + ea = addr + test(spec) + } + teste := func(spec string, e string) { + eerr = fmt.Errorf(e) + test(spec) + } + + test("unix:/tmp/test") + if runtime.GOOS == "linux" { + ea = "@test" + } else { + eerr = fmt.Errorf("bad kitty") + } + test("unix:@test") + testf("tcp:localhost:123", "tcp", "localhost:123") + testf("tcp:1.1.1.1:123", "ip", "1.1.1.1:123") + testf("tcp:fe80::1", "ip", "fe80::1") + teste("xxx", "bad kitty") + teste("xxx:yyy", "bad kitty") + teste(":yyy", "bad kitty") +}