diff --git a/tools/cmd/clipboard/legacy.go b/tools/cmd/clipboard/legacy.go index e97362100..12bc09d9f 100644 --- a/tools/cmd/clipboard/legacy.go +++ b/tools/cmd/clipboard/legacy.go @@ -3,6 +3,7 @@ package clipboard import ( + "bytes" "encoding/base64" "errors" "fmt" @@ -28,17 +29,74 @@ func encode_read_from_clipboard(use_primary bool) string { } type base64_streaming_enc struct { - output func(string) + output func(string) loop.IdType + last_written_id loop.IdType } func (self *base64_streaming_enc) Write(p []byte) (int, error) { if len(p) > 0 { - self.output(string(p)) + self.last_written_id = self.output(string(p)) } return len(p), nil } +var ErrTooMuchPipedData = errors.New("Too much piped data") + +func read_all_with_max_size(r io.Reader, max_size int) ([]byte, error) { + b := make([]byte, 0, utils.Min(8192, max_size)) + for { + if len(b) == cap(b) { + new_size := utils.Min(2*cap(b), max_size) + if new_size <= cap(b) { + return b, ErrTooMuchPipedData + } + b = append(make([]byte, 0, new_size), b...) + } + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return b, err + } + } +} + func run_plain_text_loop(opts *Options) (err error) { + stdin_is_tty := tty.IsTerminal(os.Stdin.Fd()) + var stdin_data []byte + var data_src io.Reader + var tempfile *os.File + if !stdin_is_tty { + // we pre-read STDIN because otherwise if the output of a command is being piped in + // and that command itself transmits on the tty we will break. For example + // kitten @ ls | kitten clipboard + stdin_data, err = read_all_with_max_size(os.Stdin, 2*1024*1024) + if err == nil { + os.Stdin.Close() + } else if err != ErrTooMuchPipedData { + return fmt.Errorf("Failed to read from STDIN pipe with error: %w", err) + } + } + if err == ErrTooMuchPipedData { + tempfile, err = utils.CreateAnonymousTemp("") + if err != nil { + return fmt.Errorf("Failed to create a temporary from STDIN pipe with error: %w", err) + } + defer tempfile.Close() + tempfile.Write(stdin_data) + _, err = io.Copy(tempfile, os.Stdin) + if err != nil { + return fmt.Errorf("Failed to copy data from STDIN pipe to temp file with error: %w", err) + } + os.Stdin.Close() + tempfile.Seek(0, os.SEEK_SET) + data_src = tempfile + } else if stdin_data != nil { + data_src = bytes.NewBuffer(stdin_data) + } + lp, err := loop.New(loop.NoAlternateScreen, loop.NoRestoreColors, loop.NoMouseTracking) if err != nil { return @@ -47,13 +105,12 @@ func run_plain_text_loop(opts *Options) (err error) { if opts.UsePrimary { dest = "p" } - stdin_is_tty := tty.IsTerminal(os.Stdin.Fd()) - var buf [8192]byte - send_to_loop := func(data string) { - lp.QueueWriteString(data) + send_to_loop := func(data string) loop.IdType { + return lp.QueueWriteString(data) } - enc := base64.NewEncoder(base64.StdEncoding, &base64_streaming_enc{send_to_loop}) + enc_writer := base64_streaming_enc{output: send_to_loop} + enc := base64.NewEncoder(base64.StdEncoding, &enc_writer) transmitting := true after_read_from_stdin := func() { @@ -67,37 +124,38 @@ func run_plain_text_loop(opts *Options) (err error) { } } - read_from_stdin := func() error { - n, err := os.Stdin.Read(buf[:]) + buf := make([]byte, 8192) + write_one_chunk := func() error { + n, err := data_src.Read(buf[:cap(buf)]) + if err != nil && !errors.Is(err, io.EOF) { + send_to_loop("\x1b\\") + return err + } if n > 0 { enc.Write(buf[:n]) } - if err != nil { - if errors.Is(err, io.EOF) { - enc.Close() - send_to_loop("\x1b\\") - os.Stdin.Close() - after_read_from_stdin() - return nil - } - return fmt.Errorf("Failed to read from STDIN with error: %w", err) + if errors.Is(err, io.EOF) { + enc.Close() + send_to_loop("\x1b\\") + after_read_from_stdin() } - lp.WakeupMainThread() return nil } - lp.OnWakeup = func() error { - return read_from_stdin() + lp.OnInitialize = func() (string, error) { + if data_src != nil { + send_to_loop(fmt.Sprintf("\x1b]52;%s;", dest)) + return "", write_one_chunk() + } + after_read_from_stdin() + return "", nil } - lp.OnInitialize = func() (string, error) { - if !stdin_is_tty { - send_to_loop(fmt.Sprintf("\x1b]52;%s;", dest)) - read_from_stdin() - } else { - after_read_from_stdin() + lp.OnWriteComplete = func(id loop.IdType) error { + if id == enc_writer.last_written_id { + return write_one_chunk() } - return "", nil + return nil } var clipboard_contents []byte