diff --git a/tools/tui/graphics/command.go b/tools/tui/graphics/command.go index 9b8fab835..65918d091 100644 --- a/tools/tui/graphics/command.go +++ b/tools/tui/graphics/command.go @@ -510,6 +510,13 @@ func compress_with_zlib(data []byte) []byte { return b.Bytes() } +func (self *GraphicsCommand) AsAPC(payload []byte) string { + buf := strings.Builder{} + buf.Grow(1024) + self.WriteWithPayloadTo(&buf, payload) + return buf.String() +} + func (self *GraphicsCommand) WriteWithPayloadTo(o io.StringWriter, payload []byte) (err error) { const compression_threshold = 1024 if len(payload) == 0 { @@ -526,16 +533,23 @@ func (self *GraphicsCommand) WriteWithPayloadTo(o io.StringWriter, payload []byt } gc.SetDataSize(uint64(len(payload))) data := base64.StdEncoding.EncodeToString(payload) - for len(data) > 0 && err != nil { - chunk := data[:4096] - data = data[4096:] + for len(data) > 0 && err == nil { + chunk := data + if len(data) > 4096 { + chunk = data[:4096] + data = data[4096:] + } else { + data = "" + } if len(data) > 0 { gc.m = GRT_more_more } else { gc.m = GRT_more_nomore } err = gc.serialize_to(o, chunk) - gc = GraphicsCommand{} + if gc.DataSize() > 0 { + gc = GraphicsCommand{} + } } return } @@ -664,6 +678,7 @@ func GraphicsCommandFromAPCPayload(raw []byte) *GraphicsCommand { if state == expecting_value { add_key(pos) } + state = expecting_key payload_start_at = pos + 1 break } @@ -685,6 +700,9 @@ func GraphicsCommandFromAPCPayload(raw []byte) *GraphicsCommand { } } } + if state == expecting_value { + add_key(len(raw)) + } if payload_start_at > -1 { payload := raw[payload_start_at:] if len(payload) > 0 { diff --git a/tools/tui/graphics/command_test.go b/tools/tui/graphics/command_test.go new file mode 100644 index 000000000..90d7c3fb0 --- /dev/null +++ b/tools/tui/graphics/command_test.go @@ -0,0 +1,103 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package graphics + +import ( + "bytes" + "compress/zlib" + "encoding/base64" + "fmt" + "io" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/exp/rand" +) + +var _ = fmt.Print + +func from_full_apc_escape_code(raw string) *GraphicsCommand { + return GraphicsCommandFromAPC([]byte(raw[2 : len(raw)-2])) +} + +func TestGraphicsCommandSerialization(t *testing.T) { + gc := &GraphicsCommand{} + + test_serialize := func(payload string, vals ...string) { + expected := "\033_G" + strings.Join(vals, ",") + if payload != "" { + expected += ";" + base64.StdEncoding.EncodeToString([]byte(payload)) + } + expected += "\033\\" + if diff := cmp.Diff(expected, gc.AsAPC([]byte(payload))); diff != "" { + t.Fatalf("Failed to write vals: %#v with payload: %#v\n%s", vals, payload, diff) + } + } + + test_chunked_payload := func(payload []byte) { + c := &GraphicsCommand{} + data := c.AsAPC([]byte(payload)) + encoded := strings.Builder{} + compressed := false + var data_size uint64 + for { + idx := strings.Index(data, "\033_") + if idx < 0 { + break + } + l := strings.Index(data, "\033\\") + apc := data[idx+2 : l] + data = data[l+2:] + g := GraphicsCommandFromAPC([]byte(apc)) + if data_size == 0 { + data_size = g.DataSize() + compressed = g.Compression() != 0 + } + encoded.WriteString(g.ResponseMessage()) + if g.m == GRT_more_nomore { + break + } + } + if len(data) > 0 { + t.Fatalf("Unparsed remnant: %#v", string(data)) + } + decoded, err := base64.StdEncoding.DecodeString(encoded.String()) + if err != nil { + t.Fatalf("Encoded data not valid base-64 with error: %v", err) + } + if data_size > 0 && uint64(len(decoded)) != data_size { + t.Fatalf("Data size %d != decoded size %d", data_size, len(decoded)) + } + if compressed { + b := bytes.Buffer{} + b.Write(decoded) + r, _ := zlib.NewReader(&b) + o := bytes.Buffer{} + io.Copy(&o, r) + r.Close() + decoded = o.Bytes() + } + if diff := cmp.Diff(payload, decoded); diff != "" { + t.Fatalf("Decoded payload does not match original\nlen decoded=%d len payload=%d", len(decoded), len(payload)) + } + } + + test_serialize("") + gc.SetTransmission(GRT_transmission_sharedmem).SetAction(GRT_action_query).SetZIndex(-3).SetWidth(33).SetImageId(11) + test_serialize("abcd", "a=q", "t=s", "w=33", "i=11", "z=-3") + q := from_full_apc_escape_code(gc.AsAPC([]byte("abcd"))) + if diff := cmp.Diff(gc.AsAPC(nil), q.AsAPC(nil)); diff != "" { + t.Fatalf("Parsing failed:\n%s", diff) + } + if diff := cmp.Diff(q.response_message, base64.StdEncoding.EncodeToString([]byte("abcd"))); diff != "" { + t.Fatalf("Failed to parse payload:\n%s", diff) + } + + test_chunked_payload([]byte("abcd")) + data := make([]byte, 8111) + rand.Read(data) + test_chunked_payload(data) + test_chunked_payload([]byte(strings.Repeat("a", 8007))) + +}