From 64cb9c95422eb213715e9c0664c4ef9fa304bdb8 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 26 Feb 2023 11:12:31 +0530 Subject: [PATCH] More work on porting ssh kitten --- tools/cmd/ssh/main.go | 78 ++++++++++++++++++++++++++++++++++++++++--- tools/tty/tty.go | 22 +++++++++--- 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index cbe257247..6fdf98968 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -26,6 +26,7 @@ import ( "kitty/tools/cli" "kitty/tools/tty" + "kitty/tools/tui" "kitty/tools/tui/loop" "kitty/tools/utils" "kitty/tools/utils/secrets" @@ -136,7 +137,7 @@ func connection_sharing_args(kitty_pid int) ([]string, error) { cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1) return []string{ "-o", "ControlMaster=auto", - "-o", "ControlPath=" + cp, + "-o", "ControlPath=" + filepath.Join(rd, cp), "-o", "ControlPersist=yes", "-o", "ServerAliveInterval=60", "-o", "ServerAliveCountMax=5", @@ -513,6 +514,41 @@ func get_remote_command(cd *connection_data) error { return nil } +func drain_potential_tty_garbage(term *tty.Term) { + err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho) + if err != nil { + return + } + canary, err := secrets.TokenBase64() + if err != nil { + return + } + dcs, err := tui.DCSToKitty("echo", canary+"\n\r") + if err != nil { + return + } + err = term.WriteAllString(dcs) + if err != nil { + return + } + q := utils.UnsafeStringToBytes(canary) + data := make([]byte, 0) + give_up_at := time.Now().Add(2 * time.Second) + buf := make([]byte, 0, 8192) + for !bytes.Contains(data, q) { + buf = buf[:cap(buf)] + timeout := give_up_at.Sub(time.Now()) + if timeout < 0 { + break + } + n, err := term.ReadWithTimeout(buf, timeout) + if err != nil { + return + } + data = append(data, buf[:n]...) + } +} + func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { go Data() go RelevantKittyOpts() @@ -575,14 +611,48 @@ func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err erro cd.host_opts, cd.literal_env = host_opts, literal_env cd.request_data = need_to_request_data cd.hostname_for_match, cd.username = hostname_for_match, uname - term.WriteString(loop.SAVE_PRIVATE_MODE_VALUES) - term.WriteString(loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet()) - defer term.WriteString(loop.RESTORE_PRIVATE_MODE_VALUES) + err = term.WriteAllString(loop.SAVE_PRIVATE_MODE_VALUES + loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet()) + if err != nil { + return 1, err + } + defer term.WriteAllString(loop.RESTORE_PRIVATE_MODE_VALUES) defer term.RestoreAndClose() err = get_remote_command(&cd) if err != nil { return 1, err } + cmd = append(cmd, cd.rcmd...) + c := exec.Command(cmd[0], cmd[1:]...) + c.Stdin, c.Stdout, c.Stderr = os.Stdin, os.Stdout, os.Stderr + err = c.Start() + if err != nil { + return 1, err + } + if !cd.request_data { + rq := fmt.Sprintf("id=%s:pwfile=%s:pw=%s", cd.replacements["REQUEST_ID"], cd.replacements["PASSWORD_FILENAME"], cd.replacements["DATA_PASSWORD"]) + err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho) + if err == nil { + var dcs string + dcs, err = tui.DCSToKitty("ssh", rq) + if err == nil { + err = term.WriteAllString(dcs) + } + } + if err != nil { + c.Process.Kill() + c.Wait() + return 1, err + } + } + err = c.Wait() + drain_potential_tty_garbage(term) + if err != nil { + var exit_err *exec.ExitError + if errors.As(err, &exit_err) { + return exit_err.ExitCode(), nil + } + return 1, err + } return 0, nil } diff --git a/tools/tty/tty.go b/tools/tty/tty.go index 478a1a576..7904a3b7a 100644 --- a/tools/tty/tty.go +++ b/tools/tty/tty.go @@ -263,13 +263,19 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error) } num_ready, err := pselect() if err != nil { - return + return 0, err } if num_ready == 0 { err = os.ErrDeadlineExceeded - return + return 0, err + } + for { + n, err = self.Read(b) + if errors.Is(err, unix.EINTR) { + continue + } + return n, err } - return self.Read(b) } func (self *Term) Read(b []byte) (int, error) { @@ -280,10 +286,14 @@ func (self *Term) Write(b []byte) (int, error) { return self.os_file.Write(b) } +func is_temporary_error(err error) bool { + return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite) +} + func (self *Term) WriteAll(b []byte) error { for len(b) > 0 { n, err := self.os_file.Write(b) - if err != nil && !errors.Is(err, io.ErrShortWrite) { + if err != nil && !is_temporary_error(err) { return err } b = b[n:] @@ -291,6 +301,10 @@ func (self *Term) WriteAll(b []byte) error { return nil } +func (self *Term) WriteAllString(s string) error { + return self.WriteAll(utils.UnsafeStringToBytes(s)) +} + func (self *Term) WriteString(b string) (int, error) { return self.os_file.WriteString(b) }