Allow propagating errors from the escape code handlers

This commit is contained in:
Kovid Goyal 2022-08-23 19:51:31 +05:30
parent 3c3e7b7f70
commit e18b6638bb
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 64 additions and 74 deletions

View File

@ -2,7 +2,6 @@ package utils
import ( import (
"bytes" "bytes"
"fmt"
) )
type parser_state uint8 type parser_state uint8
@ -40,35 +39,19 @@ type EscapeCodeParser struct {
csi_state csi_state csi_state csi_state
current_buffer []byte current_buffer []byte
bracketed_paste_buffer []UTF8State bracketed_paste_buffer []UTF8State
current_callback func([]byte) current_callback func([]byte) error
// Whether to send escape code bytes as soon as they are received or to
// buffer and send full escape codes
streaming bool
// Callbacks // Callbacks
HandleRune func(rune) HandleRune func(rune) error
HandleCSI func([]byte) HandleCSI func([]byte) error
HandleOSC func([]byte) HandleOSC func([]byte) error
HandleDCS func([]byte) HandleDCS func([]byte) error
HandlePM func([]byte) HandlePM func([]byte) error
HandleSOS func([]byte) HandleSOS func([]byte) error
HandleAPC func([]byte) HandleAPC func([]byte) error
} }
func (self *EscapeCodeParser) SetStreaming(streaming bool) error { func (self *EscapeCodeParser) Parse(data []byte) error {
if self.state != normal || len(self.current_buffer) > 0 {
return fmt.Errorf("Cannot change streaming state when not in reset state")
}
self.streaming = streaming
return nil
}
func (self *EscapeCodeParser) IsStreaming() bool {
return self.streaming
}
func (self *EscapeCodeParser) Parse(data []byte) {
prev := UTF8_ACCEPT prev := UTF8_ACCEPT
codep := UTF8_ACCEPT codep := UTF8_ACCEPT
for i := 0; i < len(data); i++ { for i := 0; i < len(data); i++ {
@ -76,7 +59,11 @@ func (self *EscapeCodeParser) Parse(data []byte) {
case normal, bracketed_paste: case normal, bracketed_paste:
switch decode_utf8(&self.utf8_state, &codep, data[i]) { switch decode_utf8(&self.utf8_state, &codep, data[i]) {
case UTF8_ACCEPT: case UTF8_ACCEPT:
self.dispatch_char(codep) err := self.dispatch_char(codep)
if err != nil {
self.Reset()
return err
}
case UTF8_REJECT: case UTF8_REJECT:
self.utf8_state = UTF8_ACCEPT self.utf8_state = UTF8_ACCEPT
if prev != UTF8_ACCEPT && i > 0 { if prev != UTF8_ACCEPT && i > 0 {
@ -85,9 +72,14 @@ func (self *EscapeCodeParser) Parse(data []byte) {
} }
prev = self.utf8_state prev = self.utf8_state
default: default:
self.dispatch_byte(data[i]) err := self.dispatch_byte(data[i])
if err != nil {
self.Reset()
return err
} }
} }
}
return nil
} }
func (self *EscapeCodeParser) Reset() { func (self *EscapeCodeParser) Reset() {
@ -95,17 +87,7 @@ func (self *EscapeCodeParser) Reset() {
} }
func (self *EscapeCodeParser) write_ch(ch byte) { func (self *EscapeCodeParser) write_ch(ch byte) {
if self.streaming {
if self.current_callback != nil {
var data [1]byte = [1]byte{ch}
self.current_callback(data[:])
}
if self.state == csi && len(self.current_buffer) < 4 {
self.current_buffer = append(self.current_buffer, ch) self.current_buffer = append(self.current_buffer, ch)
}
} else {
self.current_buffer = append(self.current_buffer, ch)
}
} }
func csi_type(ch byte) csi_char_type { func csi_type(ch byte) csi_char_type {
@ -130,26 +112,29 @@ func (self *EscapeCodeParser) reset_state() {
self.csi_state = parameter self.csi_state = parameter
} }
func (self *EscapeCodeParser) dispatch_esc_code() { func (self *EscapeCodeParser) dispatch_esc_code() error {
if self.state == csi && bytes.Equal(self.current_buffer, bracketed_paste_start) { if self.state == csi && bytes.Equal(self.current_buffer, bracketed_paste_start) {
self.reset_state() self.reset_state()
self.state = bracketed_paste self.state = bracketed_paste
return return nil
} }
var err error
if self.current_callback != nil { if self.current_callback != nil {
self.current_callback(self.current_buffer) err = self.current_callback(self.current_buffer)
} }
self.reset_state() self.reset_state()
return err
} }
func (self *EscapeCodeParser) invalid_escape_code() { func (self *EscapeCodeParser) invalid_escape_code() {
self.reset_state() self.reset_state()
} }
func (self *EscapeCodeParser) dispatch_rune(ch UTF8State) { func (self *EscapeCodeParser) dispatch_rune(ch UTF8State) error {
if self.HandleRune != nil { if self.HandleRune != nil {
self.HandleRune(rune(ch)) return self.HandleRune(rune(ch))
} }
return nil
} }
func (self *EscapeCodeParser) bp_buffer_equals(chars []UTF8State) bool { func (self *EscapeCodeParser) bp_buffer_equals(chars []UTF8State) bool {
@ -164,44 +149,47 @@ func (self *EscapeCodeParser) bp_buffer_equals(chars []UTF8State) bool {
return true return true
} }
func (self *EscapeCodeParser) dispatch_char(ch UTF8State) { func (self *EscapeCodeParser) dispatch_char(ch UTF8State) error {
if self.state == bracketed_paste { if self.state == bracketed_paste {
dispatch := func() { dispatch := func() error {
if len(self.bracketed_paste_buffer) > 0 { if len(self.bracketed_paste_buffer) > 0 {
for _, c := range self.bracketed_paste_buffer { for _, c := range self.bracketed_paste_buffer {
self.dispatch_rune(c) err := self.dispatch_rune(c)
if err != nil {
return err
}
} }
self.bracketed_paste_buffer = self.bracketed_paste_buffer[:0] self.bracketed_paste_buffer = self.bracketed_paste_buffer[:0]
} }
self.dispatch_rune(ch) return self.dispatch_rune(ch)
} }
handle_ch := func(chars ...UTF8State) { handle_ch := func(chars ...UTF8State) error {
if self.bp_buffer_equals(chars) { if self.bp_buffer_equals(chars) {
self.bracketed_paste_buffer = append(self.bracketed_paste_buffer, ch) self.bracketed_paste_buffer = append(self.bracketed_paste_buffer, ch)
if self.bracketed_paste_buffer[len(self.bracketed_paste_buffer)-1] == '~' { if self.bracketed_paste_buffer[len(self.bracketed_paste_buffer)-1] == '~' {
self.reset_state() self.reset_state()
} }
return nil
} else { } else {
dispatch() return dispatch()
} }
} }
switch ch { switch ch {
case 0x1b: case 0x1b:
handle_ch() return handle_ch()
case '[': case '[':
handle_ch(0x1b) return handle_ch(0x1b)
case '2': case '2':
handle_ch(0x1b, '[') return handle_ch(0x1b, '[')
case '0': case '0':
handle_ch(0x1b, '[', '2') return handle_ch(0x1b, '[', '2')
case '1': case '1':
handle_ch(0x1b, '[', '2', '0') return handle_ch(0x1b, '[', '2', '0')
case '~': case '~':
handle_ch(0x1b, '[', '2', '0', '1') return handle_ch(0x1b, '[', '2', '0', '1')
default: default:
dispatch() return dispatch()
} }
return
} // end self.state == bracketed_paste } // end self.state == bracketed_paste
switch ch { switch ch {
@ -226,11 +214,12 @@ func (self *EscapeCodeParser) dispatch_char(ch UTF8State) {
self.state = st self.state = st
self.current_callback = self.HandleAPC self.current_callback = self.HandleAPC
default: default:
self.dispatch_rune(ch) return self.dispatch_rune(ch)
} }
return nil
} }
func (self *EscapeCodeParser) dispatch_byte(ch byte) { func (self *EscapeCodeParser) dispatch_byte(ch byte) error {
switch self.state { switch self.state {
case esc: case esc:
switch ch { switch ch {
@ -261,7 +250,7 @@ func (self *EscapeCodeParser) dispatch_byte(ch byte) {
case intermediate_csi_char: case intermediate_csi_char:
self.csi_state = intermediate self.csi_state = intermediate
case final_csi_char: case final_csi_char:
self.dispatch_esc_code() return self.dispatch_esc_code()
case unknown_csi_char: case unknown_csi_char:
self.invalid_escape_code() self.invalid_escape_code()
} }
@ -270,13 +259,12 @@ func (self *EscapeCodeParser) dispatch_byte(ch byte) {
case parameter_csi_char, unknown_csi_char: case parameter_csi_char, unknown_csi_char:
self.invalid_escape_code() self.invalid_escape_code()
case final_csi_char: case final_csi_char:
self.dispatch_esc_code() return self.dispatch_esc_code()
} }
} }
case st_or_bel: case st_or_bel:
if ch == 0x7 { if ch == 0x7 {
self.dispatch_esc_code() return self.dispatch_esc_code()
return
} }
fallthrough fallthrough
case st: case st:
@ -289,7 +277,7 @@ func (self *EscapeCodeParser) dispatch_byte(ch byte) {
} }
case esc_st: case esc_st:
if ch == '\\' { if ch == '\\' {
self.dispatch_esc_code() return self.dispatch_esc_code()
} else { } else {
self.state = st self.state = st
self.write_ch(0x1b) self.write_ch(0x1b)
@ -299,11 +287,12 @@ func (self *EscapeCodeParser) dispatch_byte(ch byte) {
} }
case c1_st: case c1_st:
if ch == 0x9c { if ch == 0x9c {
self.dispatch_esc_code() return self.dispatch_esc_code()
} else { } else {
self.state = st self.state = st
self.write_ch(0xc2) self.write_ch(0xc2)
self.write_ch(ch) self.write_ch(ch)
} }
} }
return nil
} }

View File

@ -10,18 +10,19 @@ func TestEscapeCodeParsing(t *testing.T) {
} }
var d test_parse_collection var d test_parse_collection
add := func(prefix string, b []byte) { add := func(prefix string, b []byte) error {
d.actual += "\n" + prefix + ": " + string(b) d.actual += "\n" + prefix + ": " + string(b)
return nil
} }
var test_parser = EscapeCodeParser{ var test_parser = EscapeCodeParser{
HandleCSI: func(b []byte) { add("CSI", b) }, HandleCSI: func(b []byte) error { return add("CSI", b) },
HandleOSC: func(b []byte) { add("OSC", b) }, HandleOSC: func(b []byte) error { return add("OSC", b) },
HandleDCS: func(b []byte) { add("DCS", b) }, HandleDCS: func(b []byte) error { return add("DCS", b) },
HandleSOS: func(b []byte) { add("SOS", b) }, HandleSOS: func(b []byte) error { return add("SOS", b) },
HandlePM: func(b []byte) { add("PM", b) }, HandlePM: func(b []byte) error { return add("PM", b) },
HandleAPC: func(b []byte) { add("APC", b) }, HandleAPC: func(b []byte) error { return add("APC", b) },
HandleRune: func(b rune) { add("CH", []byte(string(b))) }, HandleRune: func(b rune) error { return add("CH", []byte(string(b))) },
} }
reset_test_parser := func() { reset_test_parser := func() {