From 43e93414eabff62f2a2adc86f810b2ea6190826f Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Thu, 18 Aug 2022 15:40:19 +0530 Subject: [PATCH] Some work on implementing TTYIO --- go.mod | 1 + go.sum | 3 + tools/cmd/at/main.go | 49 +++++++-- tools/utils/io.go | 27 ++++- tools/utils/tty.go | 237 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 307 insertions(+), 10 deletions(-) create mode 100644 tools/utils/tty.go diff --git a/go.mod b/go.mod index ceb162e69..1ca9beddb 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/fatih/color v1.13.0 github.com/mattn/go-isatty v0.0.14 github.com/mattn/go-runewidth v0.0.13 + github.com/pkg/term v1.1.0 github.com/seancfoley/ipaddress-go v1.2.1 github.com/spf13/cobra v1.5.0 github.com/spf13/pflag v1.0.5 diff --git a/go.sum b/go.sum index f35e817c2..a1b68de4d 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9 github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pkg/term v1.1.0 h1:xIAAdCMh3QIAy+5FrE8Ad8XoDhEU4ufwbaSozViP9kk= +github.com/pkg/term v1.1.0/go.mod h1:E25nymQcrSllhX42Ok8MRm1+hyBdHY0dCeiKZ9jpNGw= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -25,6 +27,7 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/tools/cmd/at/main.go b/tools/cmd/at/main.go index a09244463..17dbb1c32 100644 --- a/tools/cmd/at/main.go +++ b/tools/cmd/at/main.go @@ -3,9 +3,10 @@ package at import ( "encoding/json" "fmt" - "io/ioutil" + "io" "os" "strings" + "time" "github.com/mattn/go-isatty" "github.com/spf13/cobra" @@ -89,6 +90,36 @@ func create_serializer(password string, encoded_pubkey string, response_timeout return simple_serializer, timeout, nil } +type TTYIO interface { + WriteAllWithTimeout(b []byte, d time.Duration) (n int, err error) + WriteFromReader(r utils.Reader, read_timeout time.Duration, write_timeout time.Duration) (n int, err error) + + Restore() error + Close() error +} + +func do_tty_io(tty TTYIO, input utils.Reader, response_timeout time.Duration) (err error) { + + defer func() { + tty.Restore() + tty.Close() + }() + + _, err = tty.WriteAllWithTimeout([]byte("\x1bP@kitty-cmd"), 2*time.Second) + if err != nil { + return err + } + _, err = tty.WriteFromReader(input, 2*time.Second, 2*time.Second) + if err != nil { + return err + } + _, err = tty.WriteAllWithTimeout([]byte("\x1b\\"), 2*time.Second) + if err != nil { + return err + } + return +} + func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) { serializer, timeout, err = create_serializer(global_options.password, "", timeout) if err != nil { @@ -99,9 +130,15 @@ func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) { return } r := utils.BytesReader{Data: d} - - println(string(r.Data)) - return + if global_options.to_network == "" { + tty, err := utils.OpenControllingTerm(true) + if err != nil { + return err + } + return do_tty_io(tty, &r, time.Duration(timeout*1e9)) + } else { + return fmt.Errorf("TODO: Implement socket IO") + } } func get_password(password string, password_file string, password_env string, use_password string) (ans string, err error) { @@ -119,7 +156,7 @@ func get_password(password string, password_file string, password_env string, us ans = string(q) } } else { - q, err := ioutil.ReadAll(os.Stdin) + q, err := io.ReadAll(os.Stdin) if err != nil { ans = strings.TrimRight(string(q), " \n\t") } @@ -130,7 +167,7 @@ func get_password(password string, password_file string, password_env string, us } } } else { - q, err := ioutil.ReadFile(password_file) + q, err := os.ReadFile(password_file) if err != nil { ans = strings.TrimRight(string(q), " \n\t") } diff --git a/tools/utils/io.go b/tools/utils/io.go index 29bcc4311..02454fabf 100644 --- a/tools/utils/io.go +++ b/tools/utils/io.go @@ -2,18 +2,37 @@ package utils import ( "io" + "time" +) + +const ( + DEFAULT_IO_BUFFER_SIZE = 8192 ) type BytesReader struct { Data []byte - Pos int64 +} + +type Reader interface { + ReadWithTimeout(b []byte, timeout time.Duration) (n int, err error) + GetBuf() []byte } func (self *BytesReader) Read(b []byte) (n int, err error) { - if self.Pos >= int64(len(self.Data)) { + if len(self.Data) == 0 { return 0, io.EOF } - n = copy(b, self.Data[self.Pos:]) - self.Pos += int64(n) + n = copy(b, self.Data) + self.Data = self.Data[n:] + return +} + +func (self *BytesReader) ReadWithTimeout(b []byte, timeout time.Duration) (n int, err error) { + return self.Read(b) +} + +func (self *BytesReader) GetBuf() (ans []byte) { + ans = self.Data + self.Data = make([]byte, 0) return } diff --git a/tools/utils/tty.go b/tools/utils/tty.go new file mode 100644 index 000000000..232a82fb3 --- /dev/null +++ b/tools/utils/tty.go @@ -0,0 +1,237 @@ +package utils + +import ( + "errors" + "github.com/pkg/term/termios" + "golang.org/x/sys/unix" + "io" + "os" + "syscall" + "time" +) + +type Term struct { + name string + fd int + states []unix.Termios +} + +func OpenTerm(name string, in_raw_mode bool) (self *Term, err error) { + fd, err := unix.Open(name, unix.O_NOCTTY|unix.O_CLOEXEC|unix.O_NDELAY|unix.O_RDWR, 0666) + if err != nil { + return nil, &os.PathError{Op: "open", Path: name, Err: err} + } + + self = &Term{name: name, fd: fd} + err = unix.SetNonblock(self.fd, false) + if err != nil { + return + } + if in_raw_mode { + err = self.SetRaw() + if err != nil { + return + } + } + return +} + +func OpenControllingTerm(in_raw_mode bool) (self *Term, err error) { + return OpenTerm("/dev/tty", in_raw_mode) // go doesnt have a wrapper for ctermid() +} + +func (self *Term) Fd() int { return self.fd } + +func (self *Term) Close() error { + err := unix.Close(self.fd) + self.fd = -1 + return err +} + +func (self *Term) SetRawWhen(when uintptr) (err error) { + var state unix.Termios + if err = termios.Tcgetattr(uintptr(self.fd), &state); err != nil { + return + } + new_state := state + termios.Cfmakeraw(&new_state) + err = termios.Tcsetattr(uintptr(self.fd), when, &new_state) + if err != nil { + self.states = append(self.states, state) + } + return +} + +func (self *Term) SetRaw() error { + return self.SetRawWhen(termios.TCSANOW) +} + +func (self *Term) PopStateWhen(when uintptr) (err error) { + if len(self.states) == 0 { + return nil + } + idx := len(self.states) - 1 + err = termios.Tcsetattr(uintptr(self.fd), when, &self.states[idx]) + if err != nil { + self.states = self.states[:idx] + } + return +} + +func (self *Term) PopState() error { + return self.PopStateWhen(termios.TCIOFLUSH) +} + +func (self *Term) RestoreWhen(when uintptr) (err error) { + if len(self.states) == 0 { + return nil + } + self.states = self.states[:1] + return self.PopStateWhen(when) +} + +func (self *Term) Restore() error { + return self.RestoreWhen(termios.TCIOFLUSH) +} + +func clamp(v, lo, hi int64) int64 { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +func get_vmin_and_vtime(d time.Duration) (uint8, uint8) { + if d > 0 { + // VTIME is expressed in terms of deciseconds + vtimeDeci := d.Milliseconds() / 100 + // ensure valid range + vtime := uint8(clamp(vtimeDeci, 1, 0xff)) + return 0, vtime + } + // block indefinitely until we receive at least 1 byte + return 1, 0 +} + +func (self *Term) SetReadTimeout(d time.Duration) (err error) { + var a unix.Termios + if err := termios.Tcgetattr(uintptr(self.fd), &a); err != nil { + return err + } + b := a + b.Cc[unix.VMIN], b.Cc[unix.VTIME] = get_vmin_and_vtime(d) + err = termios.Tcsetattr(uintptr(self.fd), termios.TCSANOW, &b) + if err != nil { + self.states = append(self.states, a) + } + return +} + +func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error) { + var read, write, in_err unix.FdSet + tv := unix.Timeval(syscall.NsecToTimeval(int64(d))) + read.Set(self.fd) + num_ready, err := unix.Select(self.fd, &read, &write, &in_err, &tv) + if err != nil { + return + } + if num_ready == 0 { + err = os.ErrDeadlineExceeded + return + } + + return self.Read(b) +} + +func (t *Term) Read(b []byte) (int, error) { + n, e := unix.Read(t.fd, b) + if n < 0 { + n = 0 + } + if n == 0 && len(b) > 0 && e == nil { + return 0, io.EOF + } + if e != nil { + return n, &os.PathError{Op: "read", Path: t.name, Err: e} + } + return n, nil +} + +func (t *Term) Write(b []byte) (int, error) { + n, e := unix.Write(t.fd, b) + if n < 0 { + n = 0 + } + if n != len(b) { + return n, io.ErrShortWrite + } + if e != nil { + return n, &os.PathError{Op: "write", Path: t.name, Err: e} + } + return n, nil +} + +func (self *Term) WriteAllWithTimeout(b []byte, d time.Duration) (n int, err error) { + var read, write, in_err unix.FdSet + var num_ready int + n = len(b) + for { + if len(b) == 0 { + return + } + tv := unix.Timeval(syscall.NsecToTimeval(int64(d))) + read.Zero() + write.Zero() + in_err.Zero() + write.Set(self.fd) + num_ready, err = unix.Select(self.fd, &read, &write, &in_err, &tv) + if err != nil { + n -= len(b) + return + } + if num_ready == 0 { + err = os.ErrDeadlineExceeded + n -= len(b) + return + } + num_written, werr := self.Write(b) + if werr == nil { + n -= len(b) + return + } + if errors.Is(werr, io.ErrShortWrite) { + b = b[num_written:] + continue + } + err = werr + n -= len(b) + return + } +} + +func (self *Term) WriteFromReader(r Reader, read_timeout time.Duration, write_timeout time.Duration) (n int, err error) { + buf := r.GetBuf() + var rn, wn int + var rerr error + for { + if len(buf) == 0 { + rn, rerr = r.ReadWithTimeout(buf, read_timeout) + if rerr != nil && !errors.Is(rerr, io.EOF) { + err = rerr + return + } + if rn == 0 { + return n, nil + } + } + wn, err = self.WriteAllWithTimeout(buf, write_timeout) + n += wn + if err != nil { + return + } + + } +}