From ee12349a50b30bf6b38467cf7325f130d5d2576a Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 26 Aug 2022 05:46:43 +0530 Subject: [PATCH] Use Go's os.File this allows us to implement WriteString without using unsafe --- tools/tty/tty.go | 226 ++++++++--------------------------------------- 1 file changed, 39 insertions(+), 187 deletions(-) diff --git a/tools/tty/tty.go b/tools/tty/tty.go index ab0520c78..46773b007 100644 --- a/tools/tty/tty.go +++ b/tools/tty/tty.go @@ -4,11 +4,8 @@ package tty import ( "encoding/base64" - "errors" "fmt" - "io" "os" - "runtime" "time" "golang.org/x/sys/unix" @@ -23,9 +20,8 @@ const ( ) type Term struct { - name string - fd int - states []unix.Termios + os_file *os.File + states []unix.Termios } func eintr_retry_noret(f func() error) error { @@ -102,12 +98,16 @@ var SetReadPassword TermiosOperation = func(t *unix.Termios) { t.Cc[unix.VTIME] = 0 } -func WrapTerm(fd int, operations ...TermiosOperation) (self *Term, err error) { - self = &Term{name: fmt.Sprintf("", fd), fd: fd} - err = eintr_retry_noret(func() error { return unix.SetNonblock(self.fd, false) }) - if err == nil { - err = self.ApplyOperations(TCSANOW, operations...) +func WrapTerm(fd int, name string, operations ...TermiosOperation) (self *Term, err error) { + if name == "" { + name = fmt.Sprintf("", fd) } + os_file := os.NewFile(uintptr(fd), name) + if os_file == nil { + return nil, os.ErrInvalid + } + self = &Term{os_file: os_file} + err = self.ApplyOperations(TCSANOW, operations...) if err != nil { self.Close() self = nil @@ -122,10 +122,7 @@ func OpenTerm(name string, operations ...TermiosOperation) (self *Term, err erro if err != nil { return nil, &os.PathError{Op: "open", Path: name, Err: err} } - self, err = WrapTerm(fd, operations...) - if err != nil { - self.name = name - } + self, err = WrapTerm(fd, name, operations...) return } @@ -133,20 +130,28 @@ func OpenControllingTerm(operations ...TermiosOperation) (self *Term, err error) return OpenTerm(Ctermid(), operations...) } -func (self *Term) Fd() int { return self.fd } +func (self *Term) Fd() int { + if self.os_file == nil { + return -1 + } + return int(self.os_file.Fd()) +} func (self *Term) Close() error { - err := eintr_retry_noret(func() error { return unix.Close(self.fd) }) - self.fd = -1 + if self.os_file == nil { + return nil + } + err := eintr_retry_noret(func() error { return self.os_file.Close() }) + self.os_file = nil return err } func (self *Term) Tcgetattr(ans *unix.Termios) error { - return eintr_retry_noret(func() error { return Tcgetattr(self.fd, ans) }) + return eintr_retry_noret(func() error { return Tcgetattr(self.Fd(), ans) }) } func (self *Term) Tcsetattr(when uintptr, ans *unix.Termios) error { - return eintr_retry_noret(func() error { return Tcsetattr(self.fd, when, ans) }) + return eintr_retry_noret(func() error { return Tcsetattr(self.Fd(), when, ans) }) } func (self *Term) set_termios_attrs(when uintptr, modify func(*unix.Termios)) (err error) { @@ -218,8 +223,8 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error) read.Zero() write.Zero() in_err.Zero() - read.Set(self.fd) - return utils.Select(self.fd+1, &read, &write, &in_err, d) + read.Set(self.Fd()) + return utils.Select(self.Fd()+1, &read, &write, &in_err, d) } num_ready, err := pselect() if err != nil { @@ -233,195 +238,42 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error) } func (self *Term) Read(b []byte) (int, error) { - n, e := eintr_retry_intret(func() (int, error) { return unix.Read(self.fd, b) }) - if n < 0 { - n = 0 - } - if n == 0 && len(b) > 0 && e == nil { - return 0, io.EOF - } - if e != nil { - return n, &os.PathError{Op: "read", Path: self.name, Err: e} - } - return n, nil + return self.os_file.Read(b) } func (self *Term) Write(b []byte) (int, error) { - n, e := eintr_retry_intret(func() (int, error) { return unix.Write(self.fd, b) }) - if n < 0 { - n = 0 - } - if n != len(b) { - return n, io.ErrShortWrite - } - if e != nil { - return n, &os.PathError{Op: "write", Path: self.name, Err: e} - } - return n, nil + return self.os_file.Write(b) +} + +func (self *Term) WriteString(b string) (int, error) { + return self.os_file.WriteString(b) } func (self *Term) DebugPrintln(a ...interface{}) { msg := []byte(fmt.Sprintln(a...)) - for i := 0; i < len(msg); i += 256 { - end := i + 256 + const limit = 2048 + for i := 0; i < len(msg); i += limit { + end := i + limit if end > len(msg) { end = len(msg) } chunk := msg[i:end] encoded := make([]byte, base64.StdEncoding.EncodedLen(len(chunk))) base64.StdEncoding.Encode(encoded, chunk) - self.Write([]byte("\x1bP@kitty-print|")) + self.WriteString("\x1bP@kitty-print|") self.Write(encoded) - self.Write([]byte("\x1b\\")) - } -} - -func (self *Term) WriteAllWithTimeout(b []byte, d time.Duration) (n int, err error) { - var read, write, in_err unix.FdSet - var num_ready int - n = len(b) - pselect := func() (int, error) { - write.Zero() - read.Zero() - in_err.Zero() - write.Set(self.fd) - return utils.Select(self.fd+1, &read, &write, &in_err, d) - } - for { - if len(b) == 0 { - return - } - read.Zero() - in_err.Zero() - num_ready, err = pselect() - if err != nil { - n -= len(b) - return - } - if num_ready == 0 { - err = os.ErrDeadlineExceeded - n -= len(b) - return - } - num_written, werr := self.Write(b) - if werr == nil { - n -= len(b) - return - } - if errors.Is(werr, io.ErrShortWrite) { - b = b[num_written:] - continue - } - err = werr - n -= len(b) - return - } -} - -func (self *Term) WriteFromReader(r utils.Reader, read_timeout time.Duration, write_timeout time.Duration) (n int, err error) { - buf := r.GetBuf() - var rn, wn int - var rerr error - for { - if len(buf) == 0 { - rn, rerr = r.ReadWithTimeout(buf, read_timeout) - if rerr != nil && !errors.Is(rerr, io.EOF) { - err = rerr - return - } - if rn == 0 { - return n, nil - } - } - wn, err = self.WriteAllWithTimeout(buf, write_timeout) - n += wn - if err != nil { - return - } - buf = buf[:0] + self.WriteString("\x1b\\") } } func (self *Term) GetSize() (*unix.Winsize, error) { for { - sz, err := unix.IoctlGetWinsize(self.fd, unix.TIOCGWINSZ) + sz, err := unix.IoctlGetWinsize(self.Fd(), unix.TIOCGWINSZ) if err != unix.EINTR { return sz, err } } } -func (self *Term) read_line(in_password_mode bool) ([]byte, error) { - var buf [1]byte - var ret []byte - - for { - n, err := self.Read(buf[:]) - if n > 0 { - switch buf[0] { - case '\b', 0x7f: - if len(ret) > 0 { - ret = ret[:len(ret)-1] - _, err = self.WriteAllWithTimeout([]byte("\x08\x1b[P"), 5*time.Second) - if err != nil { - return nil, err - } - } - case '\n': - if runtime.GOOS != "windows" { - return ret, nil - } - // otherwise ignore \n - case '\r': - if runtime.GOOS == "windows" { - return ret, nil - } - // otherwise ignore \r - default: - ret = append(ret, buf[0]) - _, err = self.WriteAllWithTimeout([]byte("*"), 5*time.Second) - if err != nil { - return nil, err - } - } - continue - } - if err != nil { - if err == io.EOF && len(ret) > 0 { - return ret, nil - } - return nil, err - } - } -} - -func (self *Term) ReadLine() ([]byte, error) { - return self.read_line(false) -} - -func (self *Term) ReadPassword() (string, error) { - pw, err := self.read_line(false) - return string(pw), err -} - // go doesnt have a wrapper for ctermid() func Ctermid() string { return "/dev/tty" } - -func ReadPassword(prompt string) (string, error) { - term, err := OpenControllingTerm(SetReadPassword) - if err != nil { - return "", err - } - defer term.RestoreAndClose() - if len(prompt) > 0 { - _, err = term.WriteAllWithTimeout([]byte(prompt), 5*time.Second) - if err != nil { - return "", err - } - } - pw, err := term.ReadPassword() - if err != nil { - return "", err - } - return pw, nil -}