More work on porting ssh kitten

This commit is contained in:
Kovid Goyal 2023-02-26 11:12:31 +05:30
parent 4a5c6ad47f
commit 64cb9c9542
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 92 additions and 8 deletions

View File

@ -26,6 +26,7 @@ import (
"kitty/tools/cli" "kitty/tools/cli"
"kitty/tools/tty" "kitty/tools/tty"
"kitty/tools/tui"
"kitty/tools/tui/loop" "kitty/tools/tui/loop"
"kitty/tools/utils" "kitty/tools/utils"
"kitty/tools/utils/secrets" "kitty/tools/utils/secrets"
@ -136,7 +137,7 @@ func connection_sharing_args(kitty_pid int) ([]string, error) {
cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1) cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1)
return []string{ return []string{
"-o", "ControlMaster=auto", "-o", "ControlMaster=auto",
"-o", "ControlPath=" + cp, "-o", "ControlPath=" + filepath.Join(rd, cp),
"-o", "ControlPersist=yes", "-o", "ControlPersist=yes",
"-o", "ServerAliveInterval=60", "-o", "ServerAliveInterval=60",
"-o", "ServerAliveCountMax=5", "-o", "ServerAliveCountMax=5",
@ -513,6 +514,41 @@ func get_remote_command(cd *connection_data) error {
return nil return nil
} }
func drain_potential_tty_garbage(term *tty.Term) {
err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho)
if err != nil {
return
}
canary, err := secrets.TokenBase64()
if err != nil {
return
}
dcs, err := tui.DCSToKitty("echo", canary+"\n\r")
if err != nil {
return
}
err = term.WriteAllString(dcs)
if err != nil {
return
}
q := utils.UnsafeStringToBytes(canary)
data := make([]byte, 0)
give_up_at := time.Now().Add(2 * time.Second)
buf := make([]byte, 0, 8192)
for !bytes.Contains(data, q) {
buf = buf[:cap(buf)]
timeout := give_up_at.Sub(time.Now())
if timeout < 0 {
break
}
n, err := term.ReadWithTimeout(buf, timeout)
if err != nil {
return
}
data = append(data, buf[:n]...)
}
}
func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) {
go Data() go Data()
go RelevantKittyOpts() go RelevantKittyOpts()
@ -575,14 +611,48 @@ func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err erro
cd.host_opts, cd.literal_env = host_opts, literal_env cd.host_opts, cd.literal_env = host_opts, literal_env
cd.request_data = need_to_request_data cd.request_data = need_to_request_data
cd.hostname_for_match, cd.username = hostname_for_match, uname cd.hostname_for_match, cd.username = hostname_for_match, uname
term.WriteString(loop.SAVE_PRIVATE_MODE_VALUES) err = term.WriteAllString(loop.SAVE_PRIVATE_MODE_VALUES + loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet())
term.WriteString(loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet()) if err != nil {
defer term.WriteString(loop.RESTORE_PRIVATE_MODE_VALUES) return 1, err
}
defer term.WriteAllString(loop.RESTORE_PRIVATE_MODE_VALUES)
defer term.RestoreAndClose() defer term.RestoreAndClose()
err = get_remote_command(&cd) err = get_remote_command(&cd)
if err != nil { if err != nil {
return 1, err return 1, err
} }
cmd = append(cmd, cd.rcmd...)
c := exec.Command(cmd[0], cmd[1:]...)
c.Stdin, c.Stdout, c.Stderr = os.Stdin, os.Stdout, os.Stderr
err = c.Start()
if err != nil {
return 1, err
}
if !cd.request_data {
rq := fmt.Sprintf("id=%s:pwfile=%s:pw=%s", cd.replacements["REQUEST_ID"], cd.replacements["PASSWORD_FILENAME"], cd.replacements["DATA_PASSWORD"])
err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho)
if err == nil {
var dcs string
dcs, err = tui.DCSToKitty("ssh", rq)
if err == nil {
err = term.WriteAllString(dcs)
}
}
if err != nil {
c.Process.Kill()
c.Wait()
return 1, err
}
}
err = c.Wait()
drain_potential_tty_garbage(term)
if err != nil {
var exit_err *exec.ExitError
if errors.As(err, &exit_err) {
return exit_err.ExitCode(), nil
}
return 1, err
}
return 0, nil return 0, nil
} }

View File

@ -263,13 +263,19 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error)
} }
num_ready, err := pselect() num_ready, err := pselect()
if err != nil { if err != nil {
return return 0, err
} }
if num_ready == 0 { if num_ready == 0 {
err = os.ErrDeadlineExceeded err = os.ErrDeadlineExceeded
return return 0, err
}
for {
n, err = self.Read(b)
if errors.Is(err, unix.EINTR) {
continue
}
return n, err
} }
return self.Read(b)
} }
func (self *Term) Read(b []byte) (int, error) { func (self *Term) Read(b []byte) (int, error) {
@ -280,10 +286,14 @@ func (self *Term) Write(b []byte) (int, error) {
return self.os_file.Write(b) return self.os_file.Write(b)
} }
func is_temporary_error(err error) bool {
return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite)
}
func (self *Term) WriteAll(b []byte) error { func (self *Term) WriteAll(b []byte) error {
for len(b) > 0 { for len(b) > 0 {
n, err := self.os_file.Write(b) n, err := self.os_file.Write(b)
if err != nil && !errors.Is(err, io.ErrShortWrite) { if err != nil && !is_temporary_error(err) {
return err return err
} }
b = b[n:] b = b[n:]
@ -291,6 +301,10 @@ func (self *Term) WriteAll(b []byte) error {
return nil return nil
} }
func (self *Term) WriteAllString(s string) error {
return self.WriteAll(utils.UnsafeStringToBytes(s))
}
func (self *Term) WriteString(b string) (int, error) { func (self *Term) WriteString(b string) (int, error) {
return self.os_file.WriteString(b) return self.os_file.WriteString(b)
} }