Some work on implementing TTYIO

This commit is contained in:
Kovid Goyal 2022-08-18 15:40:19 +05:30
parent 6c3a439455
commit 43e93414ea
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 307 additions and 10 deletions

1
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/fatih/color v1.13.0 github.com/fatih/color v1.13.0
github.com/mattn/go-isatty v0.0.14 github.com/mattn/go-isatty v0.0.14
github.com/mattn/go-runewidth v0.0.13 github.com/mattn/go-runewidth v0.0.13
github.com/pkg/term v1.1.0
github.com/seancfoley/ipaddress-go v1.2.1 github.com/seancfoley/ipaddress-go v1.2.1
github.com/spf13/cobra v1.5.0 github.com/spf13/cobra v1.5.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5

3
go.sum
View File

@ -10,6 +10,8 @@ github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU=
github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/pkg/term v1.1.0 h1:xIAAdCMh3QIAy+5FrE8Ad8XoDhEU4ufwbaSozViP9kk=
github.com/pkg/term v1.1.0/go.mod h1:E25nymQcrSllhX42Ok8MRm1+hyBdHY0dCeiKZ9jpNGw=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@ -25,6 +27,7 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJ
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -3,9 +3,10 @@ package at
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"strings" "strings"
"time"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -89,6 +90,36 @@ func create_serializer(password string, encoded_pubkey string, response_timeout
return simple_serializer, timeout, nil return simple_serializer, timeout, nil
} }
type TTYIO interface {
WriteAllWithTimeout(b []byte, d time.Duration) (n int, err error)
WriteFromReader(r utils.Reader, read_timeout time.Duration, write_timeout time.Duration) (n int, err error)
Restore() error
Close() error
}
func do_tty_io(tty TTYIO, input utils.Reader, response_timeout time.Duration) (err error) {
defer func() {
tty.Restore()
tty.Close()
}()
_, err = tty.WriteAllWithTimeout([]byte("\x1bP@kitty-cmd"), 2*time.Second)
if err != nil {
return err
}
_, err = tty.WriteFromReader(input, 2*time.Second, 2*time.Second)
if err != nil {
return err
}
_, err = tty.WriteAllWithTimeout([]byte("\x1b\\"), 2*time.Second)
if err != nil {
return err
}
return
}
func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) { func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) {
serializer, timeout, err = create_serializer(global_options.password, "", timeout) serializer, timeout, err = create_serializer(global_options.password, "", timeout)
if err != nil { if err != nil {
@ -99,9 +130,15 @@ func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) {
return return
} }
r := utils.BytesReader{Data: d} r := utils.BytesReader{Data: d}
if global_options.to_network == "" {
println(string(r.Data)) tty, err := utils.OpenControllingTerm(true)
return if err != nil {
return err
}
return do_tty_io(tty, &r, time.Duration(timeout*1e9))
} else {
return fmt.Errorf("TODO: Implement socket IO")
}
} }
func get_password(password string, password_file string, password_env string, use_password string) (ans string, err error) { func get_password(password string, password_file string, password_env string, use_password string) (ans string, err error) {
@ -119,7 +156,7 @@ func get_password(password string, password_file string, password_env string, us
ans = string(q) ans = string(q)
} }
} else { } else {
q, err := ioutil.ReadAll(os.Stdin) q, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
ans = strings.TrimRight(string(q), " \n\t") ans = strings.TrimRight(string(q), " \n\t")
} }
@ -130,7 +167,7 @@ func get_password(password string, password_file string, password_env string, us
} }
} }
} else { } else {
q, err := ioutil.ReadFile(password_file) q, err := os.ReadFile(password_file)
if err != nil { if err != nil {
ans = strings.TrimRight(string(q), " \n\t") ans = strings.TrimRight(string(q), " \n\t")
} }

View File

@ -2,18 +2,37 @@ package utils
import ( import (
"io" "io"
"time"
)
const (
DEFAULT_IO_BUFFER_SIZE = 8192
) )
type BytesReader struct { type BytesReader struct {
Data []byte Data []byte
Pos int64 }
type Reader interface {
ReadWithTimeout(b []byte, timeout time.Duration) (n int, err error)
GetBuf() []byte
} }
func (self *BytesReader) Read(b []byte) (n int, err error) { func (self *BytesReader) Read(b []byte) (n int, err error) {
if self.Pos >= int64(len(self.Data)) { if len(self.Data) == 0 {
return 0, io.EOF return 0, io.EOF
} }
n = copy(b, self.Data[self.Pos:]) n = copy(b, self.Data)
self.Pos += int64(n) self.Data = self.Data[n:]
return
}
func (self *BytesReader) ReadWithTimeout(b []byte, timeout time.Duration) (n int, err error) {
return self.Read(b)
}
func (self *BytesReader) GetBuf() (ans []byte) {
ans = self.Data
self.Data = make([]byte, 0)
return return
} }

237
tools/utils/tty.go Normal file
View File

@ -0,0 +1,237 @@
package utils
import (
"errors"
"github.com/pkg/term/termios"
"golang.org/x/sys/unix"
"io"
"os"
"syscall"
"time"
)
type Term struct {
name string
fd int
states []unix.Termios
}
func OpenTerm(name string, in_raw_mode bool) (self *Term, err error) {
fd, err := unix.Open(name, unix.O_NOCTTY|unix.O_CLOEXEC|unix.O_NDELAY|unix.O_RDWR, 0666)
if err != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: err}
}
self = &Term{name: name, fd: fd}
err = unix.SetNonblock(self.fd, false)
if err != nil {
return
}
if in_raw_mode {
err = self.SetRaw()
if err != nil {
return
}
}
return
}
func OpenControllingTerm(in_raw_mode bool) (self *Term, err error) {
return OpenTerm("/dev/tty", in_raw_mode) // go doesnt have a wrapper for ctermid()
}
func (self *Term) Fd() int { return self.fd }
func (self *Term) Close() error {
err := unix.Close(self.fd)
self.fd = -1
return err
}
func (self *Term) SetRawWhen(when uintptr) (err error) {
var state unix.Termios
if err = termios.Tcgetattr(uintptr(self.fd), &state); err != nil {
return
}
new_state := state
termios.Cfmakeraw(&new_state)
err = termios.Tcsetattr(uintptr(self.fd), when, &new_state)
if err != nil {
self.states = append(self.states, state)
}
return
}
func (self *Term) SetRaw() error {
return self.SetRawWhen(termios.TCSANOW)
}
func (self *Term) PopStateWhen(when uintptr) (err error) {
if len(self.states) == 0 {
return nil
}
idx := len(self.states) - 1
err = termios.Tcsetattr(uintptr(self.fd), when, &self.states[idx])
if err != nil {
self.states = self.states[:idx]
}
return
}
func (self *Term) PopState() error {
return self.PopStateWhen(termios.TCIOFLUSH)
}
func (self *Term) RestoreWhen(when uintptr) (err error) {
if len(self.states) == 0 {
return nil
}
self.states = self.states[:1]
return self.PopStateWhen(when)
}
func (self *Term) Restore() error {
return self.RestoreWhen(termios.TCIOFLUSH)
}
func clamp(v, lo, hi int64) int64 {
if v < lo {
return lo
}
if v > hi {
return hi
}
return v
}
func get_vmin_and_vtime(d time.Duration) (uint8, uint8) {
if d > 0 {
// VTIME is expressed in terms of deciseconds
vtimeDeci := d.Milliseconds() / 100
// ensure valid range
vtime := uint8(clamp(vtimeDeci, 1, 0xff))
return 0, vtime
}
// block indefinitely until we receive at least 1 byte
return 1, 0
}
func (self *Term) SetReadTimeout(d time.Duration) (err error) {
var a unix.Termios
if err := termios.Tcgetattr(uintptr(self.fd), &a); err != nil {
return err
}
b := a
b.Cc[unix.VMIN], b.Cc[unix.VTIME] = get_vmin_and_vtime(d)
err = termios.Tcsetattr(uintptr(self.fd), termios.TCSANOW, &b)
if err != nil {
self.states = append(self.states, a)
}
return
}
func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error) {
var read, write, in_err unix.FdSet
tv := unix.Timeval(syscall.NsecToTimeval(int64(d)))
read.Set(self.fd)
num_ready, err := unix.Select(self.fd, &read, &write, &in_err, &tv)
if err != nil {
return
}
if num_ready == 0 {
err = os.ErrDeadlineExceeded
return
}
return self.Read(b)
}
func (t *Term) Read(b []byte) (int, error) {
n, e := unix.Read(t.fd, b)
if n < 0 {
n = 0
}
if n == 0 && len(b) > 0 && e == nil {
return 0, io.EOF
}
if e != nil {
return n, &os.PathError{Op: "read", Path: t.name, Err: e}
}
return n, nil
}
func (t *Term) Write(b []byte) (int, error) {
n, e := unix.Write(t.fd, b)
if n < 0 {
n = 0
}
if n != len(b) {
return n, io.ErrShortWrite
}
if e != nil {
return n, &os.PathError{Op: "write", Path: t.name, Err: e}
}
return n, nil
}
func (self *Term) WriteAllWithTimeout(b []byte, d time.Duration) (n int, err error) {
var read, write, in_err unix.FdSet
var num_ready int
n = len(b)
for {
if len(b) == 0 {
return
}
tv := unix.Timeval(syscall.NsecToTimeval(int64(d)))
read.Zero()
write.Zero()
in_err.Zero()
write.Set(self.fd)
num_ready, err = unix.Select(self.fd, &read, &write, &in_err, &tv)
if err != nil {
n -= len(b)
return
}
if num_ready == 0 {
err = os.ErrDeadlineExceeded
n -= len(b)
return
}
num_written, werr := self.Write(b)
if werr == nil {
n -= len(b)
return
}
if errors.Is(werr, io.ErrShortWrite) {
b = b[num_written:]
continue
}
err = werr
n -= len(b)
return
}
}
func (self *Term) WriteFromReader(r Reader, read_timeout time.Duration, write_timeout time.Duration) (n int, err error) {
buf := r.GetBuf()
var rn, wn int
var rerr error
for {
if len(buf) == 0 {
rn, rerr = r.ReadWithTimeout(buf, read_timeout)
if rerr != nil && !errors.Is(rerr, io.EOF) {
err = rerr
return
}
if rn == 0 {
return n, nil
}
}
wn, err = self.WriteAllWithTimeout(buf, write_timeout)
n += wn
if err != nil {
return
}
}
}