From b2e610f9b10f8ad36d357ad7fb198f68fe6738b0 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 26 Aug 2022 20:09:03 +0530 Subject: [PATCH] Implement socket I/O --- tools/cmd/at/main.go | 5 ++- tools/cmd/at/socket_io.go | 82 +++++++++++++++++++++++++++++++++++++++ tools/utils/sockets.go | 2 +- 3 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 tools/cmd/at/socket_io.go diff --git a/tools/cmd/at/main.go b/tools/cmd/at/main.go index 63709782b..fd6756c8c 100644 --- a/tools/cmd/at/main.go +++ b/tools/cmd/at/main.go @@ -223,7 +223,10 @@ func send_rc_command(io_data *rc_io_data) (err error) { return } } else { - return fmt.Errorf("TODO: Implement socket IO") + response, err = get_response(do_socket_io, io_data) + if err != nil { + return + } } if err != nil || response == nil { return err diff --git a/tools/cmd/at/socket_io.go b/tools/cmd/at/socket_io.go new file mode 100644 index 000000000..2ed9d7359 --- /dev/null +++ b/tools/cmd/at/socket_io.go @@ -0,0 +1,82 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package at + +import ( + "bytes" + "errors" + "io" + "net" + "time" + + "kitty/tools/utils" + "kitty/tools/wcswidth" +) + +func write_all_to_conn(conn *net.Conn, data []byte) error { + for len(data) > 0 { + n, err := (*conn).Write(data) + if err != nil && errors.Is(err, io.ErrShortWrite) { + err = nil + } + if err != nil { + return err + } + data = data[n:] + } + return nil +} + +func read_response_from_conn(conn *net.Conn, timeout time.Duration) (serialized_response []byte, err error) { + p := wcswidth.EscapeCodeParser{} + keep_going := true + p.HandleDCS = func(data []byte) error { + if bytes.HasPrefix(data, []byte("@kitty-cmd")) { + serialized_response = data[len("@kitty-cmd"):] + keep_going = false + } + return nil + } + buf := make([]byte, utils.DEFAULT_IO_BUFFER_SIZE) + for keep_going { + var n int + (*conn).SetDeadline(time.Now().Add(timeout)) + n, err = (*conn).Read(buf) + if err != nil { + keep_going = false + break + } + p.Parse(buf[:n]) + } + return +} + +func simple_socket_io(conn *net.Conn, io_data *rc_io_data) (serialized_response []byte, err error) { + for { + var chunk []byte + chunk, err = io_data.next_chunk(false) + if err != nil { + return + } + if len(chunk) == 0 { + break + } + err = write_all_to_conn(conn, chunk) + if err != nil { + return + } + } + if io_data.rc.NoResponse { + return + } + return read_response_from_conn(conn, io_data.timeout) +} + +func do_socket_io(io_data *rc_io_data) (serialized_response []byte, err error) { + conn, err := net.Dial(global_options.to_network, global_options.to_address) + if err != nil { + return + } + defer conn.Close() + return simple_socket_io(&conn, io_data) +} diff --git a/tools/utils/sockets.go b/tools/utils/sockets.go index a89813b3b..c5823a2fb 100644 --- a/tools/utils/sockets.go +++ b/tools/utils/sockets.go @@ -20,7 +20,7 @@ func Cut(s string, sep string) (string, string, bool) { func ParseSocketAddress(spec string) (network string, addr string, err error) { network, addr, found := Cut(spec, ":") if !found { - err = fmt.Errorf("Invalid socket address: %s", spec) + err = fmt.Errorf("Invalid socket address: %s must be prefix by a protocol such as unix:", spec) return } if network == "unix" {