diff --git a/kitty/graphics.c b/kitty/graphics.c index 72a1593e5..0bd114623 100644 --- a/kitty/graphics.c +++ b/kitty/graphics.c @@ -312,6 +312,7 @@ handle_add_command(GraphicsManager *self, const GraphicsCommand *g, const uint8_ } } int fd; + static char fname[2056] = {0}; switch(tt) { case 'd': // direct if (g->payload_sz >= img->load_data.buf_capacity - img->load_data.buf_used) { @@ -324,15 +325,17 @@ handle_add_command(GraphicsManager *self, const GraphicsCommand *g, const uint8_ case 'f': // file case 't': // temporary file case 's': // POSIX shared memory - if (tt == 's') fd = shm_open((const char*)payload, O_RDONLY, 0); - else fd = open((const char*)payload, O_CLOEXEC | O_RDONLY); + if (g->payload_sz > 2048) ABRT(EINVAL, "Filename too long"); + snprintf(fname, sizeof(fname)/sizeof(fname[0]), "%.*s", (int)g->payload_sz, payload); + if (tt == 's') fd = shm_open(fname, O_RDONLY, 0); + else fd = open(fname, O_CLOEXEC | O_RDONLY); if (fd == -1) { - ABRT(EBADF, "Failed to open file for graphics transmission with error: [%d] %s", errno, strerror(errno)); + ABRT(EBADF, "Failed to open file %s for graphics transmission with error: [%d] %s", fname, errno, strerror(errno)); } img->load_data.fd = fd; img->data_loaded = mmap_img_file(self, img); - if (tt == 't') unlink((const char*)payload); - else if (tt == 's') shm_unlink((const char*)payload); + if (tt == 't') unlink(fname); + else if (tt == 's') shm_unlink(fname); break; default: ABRT(EINVAL, "Unknown transmission type: %c", g->transmission_type); diff --git a/kitty_tests/graphics.py b/kitty_tests/graphics.py index 5414168f2..bc8259433 100644 --- a/kitty_tests/graphics.py +++ b/kitty_tests/graphics.py @@ -3,6 +3,7 @@ # License: GPL v3 Copyright: 2016, Kovid Goyal import os +import tempfile import zlib from base64 import standard_b64encode @@ -45,14 +46,14 @@ class TestGraphics(BaseTest): return res.decode('ascii').partition(';')[2].partition(':')[0].partition('\033')[0] def sl(payload, **kw): - pc = kw.pop('payload_check', None) if isinstance(payload, str): payload = payload.encode('utf-8') + data = kw.pop('expecting_data', payload) cid = kw.setdefault('i', 1) self.ae('OK', l(payload, **kw)) img = g.image_for_client_id(cid) self.ae(img['client_id'], cid) - self.ae(img['data'], payload if pc is None else pc) + self.ae(img['data'], data) if 's' in kw: self.ae((kw['s'], kw['v']), (img['width'], img['height'])) self.ae(img['is_4byte_aligned'], kw.get('f') != 24) @@ -74,4 +75,14 @@ class TestGraphics(BaseTest): # Test compression random_data = os.urandom(3 * 1024) - sl(zlib.compress(random_data), s=24, v=32, o='z', payload_check=random_data) + compressed_random_data = zlib.compress(random_data) + sl(compressed_random_data, s=24, v=32, o='z', expecting_data=random_data) + + # Test loading from file + f = tempfile.NamedTemporaryFile() + f.write(random_data), f.flush() + sl(f.name, s=24, v=32, t='f', expecting_data=random_data) + self.assertTrue(os.path.exists(f.name)) + f.seek(0), f.truncate(), f.write(compressed_random_data), f.flush() + sl(f.name, s=24, v=32, t='t', o='z', expecting_data=random_data) + self.assertRaises(FileNotFoundError, f.close) # check that file was deleted