From 4a49c3940a5c731ac024a285dc69e8f3e7ae24bf Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 26 Aug 2022 08:37:14 +0530 Subject: [PATCH] Switch to using goroutines rather than a select() More complex code since now we have to synchronize between threads, but a good way to teach myself more about goroutines. --- tools/cmd/at/main.go | 4 +- tools/cmd/at/tty_io.go | 21 +- tools/tui/loop.go | 564 +++++++++++++++++++++++++++++------------ tools/tui/signal.go | 183 ------------- tools/utils/io.go | 33 --- 5 files changed, 418 insertions(+), 387 deletions(-) delete mode 100644 tools/tui/signal.go diff --git a/tools/cmd/at/main.go b/tools/cmd/at/main.go index 127ecc5a3..54a39845a 100644 --- a/tools/cmd/at/main.go +++ b/tools/cmd/at/main.go @@ -70,9 +70,9 @@ func wrap_in_escape_code(data []byte) []byte { const prefix = "\x1bP@kitty-cmd" const suffix = "\x1b\\" ans := make([]byte, len(prefix)+len(data)+len(suffix)) - n := copy(ans, []byte(prefix)) + n := copy(ans, prefix) n += copy(ans[n:], data) - copy(ans[n:], []byte(suffix)) + copy(ans[n:], suffix) return ans } diff --git a/tools/cmd/at/tty_io.go b/tools/cmd/at/tty_io.go index 9e82abcc4..f313f1a09 100644 --- a/tools/cmd/at/tty_io.go +++ b/tools/cmd/at/tty_io.go @@ -17,9 +17,10 @@ func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error) } var last_received_data_at time.Time - var check_for_timeout func(loop *tui.Loop, timer_id tui.TimerId) error + var final_write_id tui.IdType + var check_for_timeout func(loop *tui.Loop, timer_id tui.IdType) error - check_for_timeout = func(loop *tui.Loop, timer_id tui.TimerId) error { + check_for_timeout = func(loop *tui.Loop, timer_id tui.IdType) error { time_since_last_received_data := time.Now().Sub(last_received_data_at) if time_since_last_received_data >= io_data.timeout { return os.ErrDeadlineExceeded @@ -46,23 +47,25 @@ func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error) if err != nil { return "", err } + write_id := loop.QueueWriteBytesDangerous(chunk) if len(chunk) == 0 { - transition_to_read() - } else { - loop.QueueWriteBytes(chunk) + final_write_id = write_id } return "", nil } - loop.OnWriteComplete = func(loop *tui.Loop) error { + loop.OnWriteComplete = func(loop *tui.Loop, completed_write_id tui.IdType) error { + if completed_write_id == final_write_id { + transition_to_read() + return nil + } chunk, err := io_data.next_chunk(true) if err != nil { return err } + write_id := loop.QueueWriteBytesDangerous(chunk) if len(chunk) == 0 { - transition_to_read() - } else { - loop.QueueWriteBytes(chunk) + final_write_id = write_id } return nil } diff --git a/tools/tui/loop.go b/tools/tui/loop.go index 2534e04b0..77f5657ac 100644 --- a/tools/tui/loop.go +++ b/tools/tui/loop.go @@ -4,10 +4,13 @@ package tui import ( "bytes" + "errors" "fmt" "io" "kitty/tools/tty" "os" + "os/signal" + "runtime/debug" "sort" "time" @@ -17,8 +20,8 @@ import ( "kitty/tools/wcswidth" ) -func read_ignoring_temporary_errors(fd int, buf []byte) (int, error) { - n, err := unix.Read(fd, buf) +func read_ignoring_temporary_errors(f *tty.Term, buf []byte) (int, error) { + n, err := f.Read(buf) if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { return 0, nil } @@ -28,10 +31,31 @@ func read_ignoring_temporary_errors(fd int, buf []byte) (int, error) { return n, err } -func write_ignoring_temporary_errors(fd int, buf []byte) (int, error) { - n, err := unix.Write(fd, buf) - if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { - return 0, nil +func is_temporary_error(err error) bool { + return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite) +} + +func write_ignoring_temporary_errors(f *tty.Term, buf []byte) (int, error) { + n, err := f.Write(buf) + if err != nil { + if is_temporary_error(err) { + err = nil + } + return n, err + } + if n == 0 { + return 0, io.EOF + } + return n, err +} + +func writestring_ignoring_temporary_errors(f *tty.Term, buf string) (int, error) { + n, err := f.WriteString(buf) + if err != nil { + if is_temporary_error(err) { + err = nil + } + return n, err } if n == 0 { return 0, io.EOF @@ -44,14 +68,14 @@ type ScreenSize struct { updated bool } -type TimerId uint64 -type TimerCallback func(loop *Loop, timer_id TimerId) error +type IdType uint64 +type TimerCallback func(loop *Loop, timer_id IdType) error type timer struct { interval time.Duration deadline time.Time repeats bool - id TimerId + id IdType callback TimerCallback } @@ -59,18 +83,34 @@ func (self *timer) update_deadline(now time.Time) { self.deadline = now.Add(self.interval) } +var SIGNULL unix.Signal + +type write_msg struct { + id IdType + bytes []byte + str string +} + +func (self *write_msg) String() string { + return fmt.Sprintf("write_msg{%v %#v %#v}", self.id, string(self.bytes), self.str) +} + type Loop struct { - controlling_term *tty.Term - terminal_options TerminalStateOptions - screen_size ScreenSize - escape_code_parser wcswidth.EscapeCodeParser - keep_going bool - flush_write_buf bool - death_signal Signal - exit_code int - write_buf []byte - timers []*timer - timer_id_counter TimerId + controlling_term *tty.Term + terminal_options TerminalStateOptions + screen_size ScreenSize + escape_code_parser wcswidth.EscapeCodeParser + keep_going bool + death_signal unix.Signal + exit_code int + timers []*timer + timer_id_counter, write_msg_id_counter IdType + tty_read_channel chan []byte + tty_write_channel chan *write_msg + write_done_channel chan IdType + err_channel chan error + tty_writing_done_channel, tty_reading_done_channel, wakeup_channel chan byte + pending_writes []*write_msg // Callbacks @@ -88,7 +128,7 @@ type Loop struct { OnResize func(loop *Loop, old_size ScreenSize, new_size ScreenSize) error // Called when writing is done - OnWriteComplete func(loop *Loop) error + OnWriteComplete func(loop *Loop, msg_id IdType) error // Called when a response to an rc command is received OnRCResponse func(loop *Loop, data []byte) error @@ -178,8 +218,27 @@ func (self *Loop) handle_rune(raw rune) error { return nil } +func (self *Loop) on_signal(s unix.Signal) error { + switch s { + case unix.SIGINT: + return self.on_SIGINT() + case unix.SIGPIPE: + return self.on_SIGPIPE() + case unix.SIGWINCH: + return self.on_SIGWINCH() + case unix.SIGTERM: + return self.on_SIGTERM() + case unix.SIGTSTP: + return self.on_SIGTSTP() + case unix.SIGHUP: + return self.on_SIGHUP() + default: + return nil + } +} + func (self *Loop) on_SIGINT() error { - self.death_signal = SIGINT + self.death_signal = unix.SIGINT self.keep_going = false return nil } @@ -202,7 +261,7 @@ func (self *Loop) on_SIGWINCH() error { } func (self *Loop) on_SIGTERM() error { - self.death_signal = SIGTERM + self.death_signal = unix.SIGTERM self.keep_going = false return nil } @@ -212,8 +271,7 @@ func (self *Loop) on_SIGTSTP() error { } func (self *Loop) on_SIGHUP() error { - self.flush_write_buf = false - self.death_signal = SIGHUP + self.death_signal = unix.SIGHUP self.keep_going = false return nil } @@ -231,7 +289,7 @@ func CreateLoop() (*Loop, error) { return &l, nil } -func (self *Loop) AddTimer(interval time.Duration, repeats bool, callback TimerCallback) TimerId { +func (self *Loop) AddTimer(interval time.Duration, repeats bool, callback TimerCallback) IdType { self.timer_id_counter++ t := timer{interval: interval, repeats: repeats, callback: callback, id: self.timer_id_counter} t.update_deadline(time.Now()) @@ -240,7 +298,7 @@ func (self *Loop) AddTimer(interval time.Duration, repeats bool, callback TimerC return t.id } -func (self *Loop) RemoveTimer(id TimerId) bool { +func (self *Loop) RemoveTimer(id IdType) bool { for i := 0; i < len(self.timers); i++ { if self.timers[i].id == id { self.timers = append(self.timers[:i], self.timers[i+1:]...) @@ -280,13 +338,8 @@ func kill_self(sig unix.Signal) { } func (self *Loop) KillIfSignalled() { - switch self.death_signal { - case SIGINT: - kill_self(unix.SIGINT) - case SIGTERM: - kill_self(unix.SIGTERM) - case SIGHUP: - kill_self(unix.SIGHUP) + if self.death_signal != SIGNULL { + kill_self(self.death_signal) } } @@ -297,33 +350,15 @@ func (self *Loop) DebugPrintln(args ...interface{}) { } func (self *Loop) Run() (err error) { - signal_read_file, signal_write_file, err := os.Pipe() - if err != nil { - return err - } - defer func() { - signal_read_file.Close() - signal_write_file.Close() - }() - sigchnl := make(chan os.Signal, 256) - reset_signals := notify_signals(sigchnl, SIGINT, SIGTERM, SIGTSTP, SIGHUP, SIGWINCH, SIGPIPE) - defer reset_signals() - - go func() { - for { - s := <-sigchnl - if write_signal(signal_write_file, s) != nil { - break - } - } - }() + handled_signals := []os.Signal{unix.SIGINT, unix.SIGTERM, unix.SIGTSTP, unix.SIGHUP, unix.SIGWINCH, unix.SIGPIPE} + signal.Notify(sigchnl, handled_signals...) + defer signal.Reset(handled_signals...) controlling_term, err := tty.OpenControllingTerm() if err != nil { return err } - tty_fd := controlling_term.Fd() self.controlling_term = controlling_term defer func() { self.controlling_term.RestoreAndClose() @@ -334,14 +369,55 @@ func (self *Loop) Run() (err error) { return nil } - selector := CreateSelect(8) - selector.RegisterRead(int(signal_read_file.Fd())) - selector.RegisterRead(tty_fd) - self.keep_going = true - self.flush_write_buf = true - self.queue_write_to_tty(self.terminal_options.SetStateEscapeCodes()) + self.tty_read_channel = make(chan []byte) + self.tty_write_channel = make(chan *write_msg, 1) // buffered so there is no race between initial queueing and startup of writer thread + self.write_done_channel = make(chan IdType) + self.tty_writing_done_channel = make(chan byte) + self.tty_reading_done_channel = make(chan byte) + self.wakeup_channel = make(chan byte, 256) + self.pending_writes = make([]*write_msg, 0, 256) + self.err_channel = make(chan error, 8) + self.death_signal = SIGNULL + self.escape_code_parser.Reset() + self.exit_code = 0 + no_timeout_channel := make(<-chan time.Time) finalizer := "" + + w_r, w_w, err := os.Pipe() + var r_r, r_w *os.File + if err == nil { + r_r, r_w, err = os.Pipe() + if err != nil { + w_r.Close() + w_w.Close() + return err + } + } else { + return err + } + self.QueueWriteBytesDangerous(self.terminal_options.SetStateEscapeCodes()) + + defer func() { + // notify tty reader that we are shutting down + r_w.Close() + close(self.tty_reading_done_channel) + + if finalizer != "" { + self.QueueWriteString(finalizer) + } + self.QueueWriteBytesDangerous(self.terminal_options.ResetStateEscapeCodes()) + // flush queued data and wait for it to be written for a timeout, then wait for writer to shutdown + flush_writer(w_w, self.tty_write_channel, self.tty_writing_done_channel, self.pending_writes, 2*time.Second) + self.pending_writes = nil + // wait for tty reader to exit cleanly + for more := true; more; _, more = <-self.tty_read_channel { + } + }() + + go write_to_tty(w_r, self.controlling_term, self.tty_write_channel, self.err_channel, self.write_done_channel, self.tty_writing_done_channel) + go read_from_tty(r_r, self.controlling_term, self.tty_read_channel, self.err_channel, self.tty_reading_done_channel) + if self.OnInitialize != nil { finalizer, err = self.OnInitialize(self) if err != nil { @@ -349,30 +425,9 @@ func (self *Loop) Run() (err error) { } } - defer func() { - if self.flush_write_buf { - self.flush() - } - self.write_buf = self.write_buf[:0] - if finalizer != "" { - self.queue_write_to_tty([]byte(finalizer)) - } - self.queue_write_to_tty(self.terminal_options.ResetStateEscapeCodes()) - self.flush() - }() - - read_buf := make([]byte, utils.DEFAULT_IO_BUFFER_SIZE) - signal_buf := make([]byte, 256) - self.death_signal = SIGNULL - self.escape_code_parser.Reset() - self.exit_code = 0 - num_ready := 0 for self.keep_going { - if len(self.write_buf) > 0 { - selector.RegisterWrite(tty_fd) - } else { - selector.UnRegisterWrite(tty_fd) - } + self.queue_write_to_tty(nil) + timeout_chan := no_timeout_channel if len(self.timers) > 0 { now := time.Now() err = self.dispatch_timers(now) @@ -383,69 +438,106 @@ func (self *Loop) Run() (err error) { if timeout < 0 { timeout = 0 } - num_ready, err = selector.Wait(timeout) - } else { - num_ready, err = selector.WaitForever() - if err != nil { - return fmt.Errorf("Failed to call select() with error: %w", err) - } + timeout_chan = time.After(timeout) } - if num_ready == 0 { - continue - } - if len(self.write_buf) > 0 && selector.IsReadyToWrite(tty_fd) { - err = self.write_to_tty() - if err != nil { - return err + select { + case <-timeout_chan: + case <-self.wakeup_channel: + for len(self.wakeup_channel) > 0 { + <-self.wakeup_channel } - if self.OnWriteComplete != nil && len(self.write_buf) == 0 { - err = self.OnWriteComplete(self) + case msg_id := <-self.write_done_channel: + self.queue_write_to_tty(nil) + if self.OnWriteComplete != nil { + err = self.OnWriteComplete(self, msg_id) if err != nil { return err } } - } - if selector.IsReadyToRead(tty_fd) { - read_buf = read_buf[:cap(read_buf)] - num_read, err := read_ignoring_temporary_errors(tty_fd, read_buf) + case s := <-sigchnl: + err = self.on_signal(s.(unix.Signal)) if err != nil { return err } - if num_read > 0 { - if self.OnReceivedData != nil { - err = self.OnReceivedData(self, read_buf[:num_read]) - if err != nil { - return err - } - } - err = self.escape_code_parser.Parse(read_buf[:num_read]) - if err != nil { - return err - } + case input_data, more := <-self.tty_read_channel: + if !more { + return io.EOF } - } - if selector.IsReadyToRead(int(signal_read_file.Fd())) { - signal_buf = signal_buf[:cap(signal_buf)] - err = self.read_signals(signal_read_file, signal_buf) + err := self.dispatch_input_data(input_data) if err != nil { return err } + } } return nil } -func (self *Loop) queue_write_to_tty(data []byte) { - self.write_buf = append(self.write_buf, data...) +func (self *Loop) dispatch_input_data(data []byte) error { + if self.OnReceivedData != nil { + err := self.OnReceivedData(self, data) + if err != nil { + return err + } + } + err := self.escape_code_parser.Parse(data) + if err != nil { + return err + } + return nil } -func (self *Loop) QueueWriteString(data string) { - self.queue_write_to_tty([]byte(data)) +func (self *Loop) print_stack() { + self.DebugPrintln(string(debug.Stack())) } -func (self *Loop) QueueWriteBytes(data []byte) { - self.queue_write_to_tty(data) +func (self *Loop) queue_write_to_tty(data *write_msg) { + for len(self.pending_writes) > 0 { + select { + case self.tty_write_channel <- self.pending_writes[0]: + n := copy(self.pending_writes, self.pending_writes[1:]) + self.pending_writes = self.pending_writes[:n] + default: + if data != nil { + self.pending_writes = append(self.pending_writes, data) + } + return + } + } + if data != nil { + select { + case self.tty_write_channel <- data: + default: + self.pending_writes = append(self.pending_writes, data) + } + } +} + +func (self *Loop) WakeupMainThread() { + self.wakeup_channel <- 1 +} + +func (self *Loop) QueueWriteString(data string) IdType { + self.write_msg_id_counter++ + msg := write_msg{str: data, id: self.write_msg_id_counter} + self.queue_write_to_tty(&msg) + return msg.id +} + +// This is dangerous as it is upto the calling code +// to ensure the data in the underlying array does not change +func (self *Loop) QueueWriteBytesDangerous(data []byte) IdType { + self.write_msg_id_counter++ + msg := write_msg{bytes: data, id: self.write_msg_id_counter} + self.queue_write_to_tty(&msg) + return msg.id +} + +func (self *Loop) QueueWriteBytesCopy(data []byte) IdType { + d := make([]byte, len(data)) + copy(d, data) + return self.QueueWriteBytesDangerous(d) } func (self *Loop) ExitCode() int { @@ -461,56 +553,208 @@ func (self *Loop) Quit(exit_code int) { self.keep_going = false } -func (self *Loop) write_to_tty() error { - if len(self.write_buf) == 0 || self.controlling_term == nil { - return nil - } - n, err := write_ignoring_temporary_errors(self.controlling_term.Fd(), self.write_buf) - if err != nil { - return err - } - if n <= 0 { - return nil - } - remainder := self.write_buf[n:] - if len(remainder) > 0 { - self.write_buf = self.write_buf[:len(remainder)] - copy(self.write_buf, remainder) - } else { - self.write_buf = self.write_buf[:0] - } - return nil -} +func read_from_tty(pipe_r *os.File, term *tty.Term, results_channel chan<- []byte, err_channel chan<- error, quit_channel <-chan byte) { + keep_going := true + pipe_fd := int(pipe_r.Fd()) + tty_fd := term.Fd() + selector := CreateSelect(2) + selector.RegisterRead(pipe_fd) + selector.RegisterRead(tty_fd) -func (self *Loop) flush() error { - if self.controlling_term == nil { - return nil + defer func() { + close(results_channel) + pipe_r.Close() + }() + + const bufsize = 2 * utils.DEFAULT_IO_BUFFER_SIZE + + wait_for_read_available := func() { + _, err := selector.WaitForever() + if err != nil { + err_channel <- err + keep_going = false + return + } + if selector.IsReadyToRead(pipe_fd) { + keep_going = false + return + } + if selector.IsReadyToRead(tty_fd) { + return + } } - selector := CreateSelect(1) - selector.RegisterWrite(self.controlling_term.Fd()) - deadline := time.Now().Add(2 * time.Second) - for len(self.write_buf) > 0 { - timeout := deadline.Sub(time.Now()) - if timeout < 0 { + + buf := make([]byte, bufsize) + for keep_going { + if len(buf) == 0 { + buf = make([]byte, bufsize) + } + if wait_for_read_available(); !keep_going { break } - num_ready, err := selector.Wait(timeout) + n, err := read_ignoring_temporary_errors(term, buf) if err != nil { - return err + err_channel <- err + keep_going = false + break } - if num_ready > 0 && selector.IsReadyToWrite(self.controlling_term.Fd()) { - err = self.write_to_tty() + if n == 0 { + err_channel <- io.EOF + keep_going = false + break + } + send := buf[:n] + buf = buf[n:] + select { + case results_channel <- send: + case <-quit_channel: + keep_going = false + break + } + } +} + +type write_dispatcher struct { + str string + bytes []byte + is_string bool + is_empty bool +} + +func create_write_dispatcher(msg *write_msg) *write_dispatcher { + self := write_dispatcher{str: msg.str, bytes: msg.bytes, is_string: msg.bytes == nil} + if self.is_string { + self.is_empty = self.str == "" + } else { + self.is_empty = len(self.bytes) == 0 + } + return &self +} + +func (self *write_dispatcher) write(f *tty.Term) (int, error) { + if self.is_string { + return writestring_ignoring_temporary_errors(f, self.str) + } + return write_ignoring_temporary_errors(f, self.bytes) +} + +func (self *write_dispatcher) slice(n int) { + if self.is_string { + self.str = self.str[n:] + self.is_empty = self.str == "" + } else { + self.bytes = self.bytes[n:] + self.is_empty = len(self.bytes) == 0 + } +} + +func write_to_tty( + pipe_r *os.File, term *tty.Term, + job_channel <-chan *write_msg, err_channel chan<- error, write_done_channel chan<- IdType, completed_channel chan<- byte, +) { + keep_going := true + defer func() { + pipe_r.Close() + close(completed_channel) + }() + selector := CreateSelect(2) + pipe_fd := int(pipe_r.Fd()) + tty_fd := term.Fd() + selector.RegisterRead(pipe_fd) + selector.RegisterWrite(tty_fd) + + wait_for_write_available := func() { + _, err := selector.WaitForever() + if err != nil { + err_channel <- err + keep_going = false + return + } + if selector.IsReadyToWrite(tty_fd) { + return + } + if selector.IsReadyToRead(pipe_fd) { + keep_going = false + } + } + + write_data := func(msg *write_msg) { + data := create_write_dispatcher(msg) + for !data.is_empty { + wait_for_write_available() + if !keep_going { + return + } + n, err := data.write(term) if err != nil { - return err + err_channel <- err + keep_going = false + return + } + if n > 0 { + data.slice(n) } } } - return nil + + for { + data, more := <-job_channel + if !more { + keep_going = false + break + } + write_data(data) + if keep_going { + write_done_channel <- data.id + } else { + break + } + } +} + +func flush_writer(pipe_w *os.File, tty_write_channel chan<- *write_msg, tty_writing_done_channel <-chan byte, pending_writes []*write_msg, timeout time.Duration) { + writer_quit := false + defer func() { + if tty_write_channel != nil { + close(tty_write_channel) + tty_write_channel = nil + } + pipe_w.Close() + if !writer_quit { + <-tty_writing_done_channel + writer_quit = true + } + }() + deadline := time.Now().Add(timeout) + for len(pending_writes) > 0 { + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return + } + select { + case <-time.After(timeout): + return + case tty_write_channel <- pending_writes[0]: + pending_writes = pending_writes[1:] + } + } + close(tty_write_channel) + tty_write_channel = nil + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return + } + select { + case <-tty_writing_done_channel: + writer_quit = true + case <-time.After(timeout): + } + return } func (self *Loop) dispatch_timers(now time.Time) error { updated := false - remove := make(map[TimerId]bool, 0) + remove := make(map[IdType]bool, 0) for _, t := range self.timers { if now.After(t.deadline) { err := t.callback(self, t.id) diff --git a/tools/tui/signal.go b/tools/tui/signal.go deleted file mode 100644 index 0dc283d47..000000000 --- a/tools/tui/signal.go +++ /dev/null @@ -1,183 +0,0 @@ -// License: GPLv3 Copyright: 2022, Kovid Goyal, - -package tui - -import ( - "fmt" - "golang.org/x/sys/unix" - "os" - "os/signal" -) - -type Signal byte - -const ( - SIGNULL Signal = 0 - SIGINT Signal = 1 - SIGTERM Signal = 2 - SIGTSTP Signal = 3 - SIGHUP Signal = 4 - SIGTTIN Signal = 5 - SIGTTOU Signal = 6 - SIGUSR1 Signal = 7 - SIGUSR2 Signal = 8 - SIGALRM Signal = 9 - SIGWINCH Signal = 10 - SIGPIPE Signal = 11 -) - -func (self *Signal) String() string { - switch *self { - case SIGNULL: - return "SIGNULL" - case SIGINT: - return "SIGINT" - case SIGTERM: - return "SIGTERM" - case SIGTSTP: - return "SIGTSTP" - case SIGHUP: - return "SIGHUP" - case SIGTTIN: - return "SIGTTIN" - case SIGTTOU: - return "SIGTTOU" - case SIGUSR1: - return "SIGUSR1" - case SIGUSR2: - return "SIGUSR2" - case SIGALRM: - return "SIGALRM" - case SIGWINCH: - return "SIGWINCH" - case SIGPIPE: - return "SIGPIPE" - default: - return fmt.Sprintf("SIG#%d", *self) - } -} - -func as_signal(which os.Signal) Signal { - switch which { - case os.Interrupt: - return SIGINT - case unix.SIGTERM: - return SIGTERM - case unix.SIGTSTP: - return SIGTSTP - case unix.SIGHUP: - return SIGHUP - case unix.SIGTTIN: - return SIGTTIN - case unix.SIGTTOU: - return SIGTTOU - case unix.SIGUSR1: - return SIGUSR1 - case unix.SIGUSR2: - return SIGUSR2 - case unix.SIGALRM: - return SIGALRM - case unix.SIGWINCH: - return SIGWINCH - case unix.SIGPIPE: - return SIGPIPE - default: - return SIGNULL - } -} - -const zero_go_signal = unix.Signal(0) - -func as_go_signal(which Signal) os.Signal { - switch which { - case SIGINT: - return os.Interrupt - case SIGTERM: - return unix.SIGTERM - case SIGTSTP: - return unix.SIGTSTP - case SIGHUP: - return unix.SIGHUP - case SIGTTIN: - return unix.SIGTTIN - case SIGTTOU: - return unix.SIGTTOU - case SIGUSR1: - return unix.SIGUSR1 - case SIGUSR2: - return unix.SIGUSR2 - case SIGALRM: - return unix.SIGALRM - case SIGWINCH: - return unix.SIGWINCH - case SIGPIPE: - return unix.SIGPIPE - default: - return zero_go_signal - } -} - -func write_signal(dest *os.File, which os.Signal) error { - b := make([]byte, 1) - b[0] = byte(as_signal(which)) - if b[0] == 0 { - return nil - } - _, err := dest.Write(b) - return err -} - -func notify_signals(c chan os.Signal, signals ...Signal) func() { - s := make([]os.Signal, len(signals)) - for i, x := range signals { - g := as_go_signal(x) - if g != zero_go_signal { - s[i] = g - } - } - signal.Notify(c, s...) - return func() { signal.Reset(s...) } -} - -func (self *Loop) read_signals(f *os.File, buf []byte) error { - n, err := f.Read(buf) - if err != nil { - return err - } - buf = buf[:n] - for _, s := range buf { - switch Signal(s) { - case SIGINT: - err := self.on_SIGINT() - if err != nil { - return err - } - case SIGTERM: - err := self.on_SIGTERM() - if err != nil { - return err - } - case SIGHUP: - err := self.on_SIGHUP() - if err != nil { - return err - } - case SIGWINCH: - err := self.on_SIGWINCH() - if err != nil { - return err - } - case SIGPIPE: - err := self.on_SIGPIPE() - if err != nil { - return err - } - case SIGTSTP: - err := self.on_SIGTSTP() - if err != nil { - return err - } - } - } - return nil -} diff --git a/tools/utils/io.go b/tools/utils/io.go index e9f922d65..5d4402674 100644 --- a/tools/utils/io.go +++ b/tools/utils/io.go @@ -2,39 +2,6 @@ package utils -import ( - "io" - "time" -) - const ( DEFAULT_IO_BUFFER_SIZE = 8192 ) - -type BytesReader struct { - Data []byte -} - -type Reader interface { - ReadWithTimeout(b []byte, timeout time.Duration) (n int, err error) - GetBuf() []byte -} - -func (self *BytesReader) Read(b []byte) (n int, err error) { - if len(self.Data) == 0 { - return 0, io.EOF - } - n = copy(b, self.Data) - self.Data = self.Data[n:] - return -} - -func (self *BytesReader) ReadWithTimeout(b []byte, timeout time.Duration) (n int, err error) { - return self.Read(b) -} - -func (self *BytesReader) GetBuf() (ans []byte) { - ans = self.Data - self.Data = make([]byte, 0) - return -}