diff --git a/tools/tty/tty.go b/tools/tty/tty.go index b065e1473..84dfeb73e 100644 --- a/tools/tty/tty.go +++ b/tools/tty/tty.go @@ -4,7 +4,9 @@ package tty import ( "encoding/base64" + "errors" "fmt" + "io" "os" "time" @@ -207,6 +209,22 @@ func (self *Term) RestoreAndClose() error { return self.Close() } +func (self *Term) SuspendAndRun(callback func() error) error { + var state unix.Termios + err := self.Tcgetattr(&state) + if err != nil { + return err + } + if len(self.states) > 0 { + err := self.Tcsetattr(TCSANOW, &self.states[0]) + if err != nil { + return err + } + } + defer self.Tcsetattr(TCSANOW, &state) + return callback() +} + func clamp(v, lo, hi int64) int64 { if v < lo { return lo @@ -245,6 +263,17 @@ func (self *Term) Write(b []byte) (int, error) { return self.os_file.Write(b) } +func (self *Term) WriteAll(b []byte) error { + for len(b) > 0 { + n, err := self.os_file.Write(b) + if err != nil && !errors.Is(err, io.ErrShortWrite) { + return err + } + b = b[n:] + } + return nil +} + func (self *Term) WriteString(b string) (int, error) { return self.os_file.WriteString(b) } diff --git a/tools/tui/loop/api.go b/tools/tui/loop/api.go index 9daeef1c2..d66a06b63 100644 --- a/tools/tui/loop/api.go +++ b/tools/tui/loop/api.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "fmt" "kitty/tools/tty" - "os" "time" "golang.org/x/sys/unix" @@ -45,8 +44,8 @@ type Loop struct { timers, timers_temp []*timer timer_id_counter, write_msg_id_counter IdType wakeup_channel chan byte - signal_channel chan os.Signal pending_writes []*write_msg + on_SIGTSTP func() error // Callbacks @@ -71,6 +70,9 @@ type Loop struct { // Called when any input from tty is received OnReceivedData func(data []byte) error + + // Called when resuming from a SIGTSTP or Ctrl-z + OnResumeFromStop func() error } func New(options ...func(self *Loop)) (*Loop, error) { diff --git a/tools/tui/loop/run.go b/tools/tui/loop/run.go index 92dd2f1d6..b6b8b837b 100644 --- a/tools/tui/loop/run.go +++ b/tools/tui/loop/run.go @@ -75,7 +75,6 @@ func (self *Loop) handle_csi(raw []byte) error { } func (self *Loop) handle_key_event(ev *KeyEvent) error { - self.DebugPrintln(ev) if self.OnKeyEvent != nil { err := self.OnKeyEvent(ev) if err != nil { @@ -177,14 +176,6 @@ func (self *Loop) on_SIGTERM() error { return nil } -func (self *Loop) on_SIGTSTP() error { - signal.Reset(unix.SIGTSTP) - unix.Kill(os.Getpid(), unix.SIGTSTP) - time.Sleep(20 * time.Millisecond) - signal.Notify(self.signal_channel, unix.SIGTSTP) - return nil -} - func (self *Loop) on_SIGHUP() error { self.death_signal = unix.SIGHUP self.keep_going = false @@ -192,9 +183,9 @@ func (self *Loop) on_SIGHUP() error { } func (self *Loop) run() (err error) { - self.signal_channel = make(chan os.Signal, 256) + signal_channel := make(chan os.Signal, 256) handled_signals := []os.Signal{unix.SIGINT, unix.SIGTERM, unix.SIGTSTP, unix.SIGHUP, unix.SIGWINCH, unix.SIGPIPE} - signal.Notify(self.signal_channel, handled_signals...) + signal.Notify(signal_channel, handled_signals...) defer signal.Reset(handled_signals...) controlling_term, err := tty.OpenControllingTerm() @@ -203,10 +194,10 @@ func (self *Loop) run() (err error) { } self.controlling_term = controlling_term defer func() { - self.controlling_term.RestoreAndClose() + controlling_term.RestoreAndClose() self.controlling_term = nil }() - err = self.controlling_term.ApplyOperations(tty.TCSANOW, tty.SetRaw) + err = controlling_term.ApplyOperations(tty.TCSANOW, tty.SetRaw) if err != nil { return nil } @@ -239,6 +230,7 @@ func (self *Loop) run() (err error) { return err } self.QueueWriteBytesDangerous(self.terminal_options.SetStateEscapeCodes()) + needs_reset_escape_codes := true defer func() { // notify tty reader that we are shutting down @@ -248,7 +240,9 @@ func (self *Loop) run() (err error) { if finalizer != "" { self.QueueWriteString(finalizer) } - self.QueueWriteBytesDangerous(self.terminal_options.ResetStateEscapeCodes()) + if needs_reset_escape_codes { + self.QueueWriteBytesDangerous(self.terminal_options.ResetStateEscapeCodes()) + } // flush queued data and wait for it to be written for a timeout, then wait for writer to shutdown flush_writer(w_w, tty_write_channel, write_done_channel, self.pending_writes, 2*time.Second) self.pending_writes = nil @@ -257,8 +251,8 @@ func (self *Loop) run() (err error) { } }() - go write_to_tty(w_r, self.controlling_term, tty_write_channel, err_channel, write_done_channel) - go read_from_tty(r_r, self.controlling_term, tty_read_channel, err_channel, tty_reading_done_channel) + go write_to_tty(w_r, controlling_term, tty_write_channel, err_channel, write_done_channel) + go read_from_tty(r_r, controlling_term, tty_read_channel, err_channel, tty_reading_done_channel) if self.OnInitialize != nil { finalizer, err = self.OnInitialize() @@ -267,6 +261,33 @@ func (self *Loop) run() (err error) { } } + self.on_SIGTSTP = func() error { + write_id := self.QueueWriteBytesDangerous(self.terminal_options.ResetStateEscapeCodes()) + needs_reset_escape_codes = false + err := self.wait_for_write_to_complete(write_id, tty_write_channel, write_done_channel, 2*time.Second) + if err != nil { + return err + } + err = controlling_term.SuspendAndRun(func() error { + unix.Kill(os.Getpid(), unix.SIGSTOP) + time.Sleep(20 * time.Millisecond) + return nil + }) + if err != nil { + return err + } + write_id = self.QueueWriteBytesDangerous(self.terminal_options.SetStateEscapeCodes()) + needs_reset_escape_codes = true + err = self.wait_for_write_to_complete(write_id, tty_write_channel, write_done_channel, 2*time.Second) + if err != nil { + return err + } + if self.OnResumeFromStop != nil { + return self.OnResumeFromStop() + } + return nil + } + for self.keep_going { self.flush_pending_writes(tty_write_channel) timeout_chan := no_timeout_channel @@ -296,7 +317,7 @@ func (self *Loop) run() (err error) { return err } } - case s := <-self.signal_channel: + case s := <-signal_channel: err = self.on_signal(s.(unix.Signal)) if err != nil { return err diff --git a/tools/tui/loop/write.go b/tools/tui/loop/write.go index 52451d5cd..fdd17e293 100644 --- a/tools/tui/loop/write.go +++ b/tools/tui/loop/write.go @@ -69,6 +69,31 @@ func (self *Loop) flush_pending_writes(tty_write_channel chan<- *write_msg) { } } +func (self *Loop) wait_for_write_to_complete(sentinel IdType, tty_write_channel chan<- *write_msg, write_done_channel <-chan IdType, timeout time.Duration) error { + for len(self.pending_writes) > 0 { + select { + case tty_write_channel <- self.pending_writes[0]: + self.pending_writes = self.pending_writes[1:] + case write_id, more := <-write_done_channel: + if write_id == sentinel { + return nil + } + if self.OnWriteComplete != nil { + err := self.OnWriteComplete(write_id) + if err != nil { + return err + } + } + if !more { + return fmt.Errorf("The write_done_channel was unexpectedly closed") + } + case <-time.After(timeout): + return os.ErrDeadlineExceeded + } + } + return nil +} + func (self *Loop) add_write_to_pending_queue(data *write_msg) { self.pending_writes = append(self.pending_writes, data) } @@ -183,7 +208,7 @@ func flush_writer(pipe_w *os.File, tty_write_channel chan<- *write_msg, write_do } }() deadline := time.Now().Add(timeout) - for len(pending_writes) > 0 { + for len(pending_writes) > 0 && !writer_quit { timeout = deadline.Sub(time.Now()) if timeout <= 0 { return @@ -213,7 +238,7 @@ func flush_writer(pipe_w *os.File, tty_write_channel chan<- *write_msg, write_do writer_quit = true } case <-time.After(timeout): - break + return } } return diff --git a/tools/tui/password.go b/tools/tui/password.go index 0371d059f..42f22e6a7 100644 --- a/tools/tui/password.go +++ b/tools/tui/password.go @@ -79,6 +79,11 @@ func ReadPassword(prompt string, kill_if_signaled bool) (password string, err er return nil } + lp.OnResumeFromStop = func() error { + lp.QueueWriteString("\r" + prompt + shadow) + return nil + } + err = lp.Run() if err != nil { return