diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index 02427183e..3c512c217 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -53,6 +53,9 @@ type MMap interface { FileSystemName() string Stat() (fs.FileInfo, error) Flush() error + Seek(offset int64, whence int) (ret int64, err error) + Read(b []byte) (n int, err error) + Write(b []byte) (n int, err error) } type AccessFlags int @@ -106,9 +109,11 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) { const NUM_BYTES_FOR_SIZE = 4 +var ErrRegionTooSmall = errors.New("mmaped region too small") + func WriteWithSize(self MMap, b []byte, at int) error { if len(self.Slice()) < at+len(b)+NUM_BYTES_FOR_SIZE { - return io.ErrShortBuffer + return ErrRegionTooSmall } binary.BigEndian.PutUint32(self.Slice()[at:], uint32(len(b))) copy(self.Slice()[at+NUM_BYTES_FOR_SIZE:], b) @@ -118,12 +123,12 @@ func WriteWithSize(self MMap, b []byte, at int) error { func ReadWithSize(self MMap, at int) ([]byte, error) { s := self.Slice()[at:] if len(s) < NUM_BYTES_FOR_SIZE { - return nil, io.ErrShortBuffer + return nil, ErrRegionTooSmall } size := int(binary.BigEndian.Uint32(self.Slice()[at : at+NUM_BYTES_FOR_SIZE])) s = s[NUM_BYTES_FOR_SIZE:] if len(s) < size { - return nil, io.ErrShortBuffer + return nil, ErrRegionTooSmall } return s[:size], nil } @@ -158,6 +163,41 @@ func ReadWithSizeAndUnlink(name string, file_callback ...func(fs.FileInfo) error return ans, nil } +func Read(self MMap, b []byte) (n int, err error) { + pos, _ := self.Seek(0, io.SeekCurrent) + if pos < 0 { + pos = 0 + } + s := self.Slice() + sz := int64(len(s)) + if pos >= sz { + return 0, io.EOF + } + n = copy(b, s[pos:]) + self.Seek(int64(n), io.SeekCurrent) + return +} + +func Write(self MMap, b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + pos, _ := self.Seek(0, io.SeekCurrent) + if pos < 0 { + pos = 0 + } + s := self.Slice() + if pos >= int64(len(s)) { + return 0, io.ErrShortWrite + } + n = copy(s[pos:], b) + self.Seek(int64(n), io.SeekCurrent) + if n < len(b) { + return n, io.ErrShortWrite + } + return n, nil +} + func test_integration_with_python(args []string) (rc int, err error) { switch args[0] { default: diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index f0fd82003..bf0a817a1 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -7,6 +7,7 @@ import ( "crypto/sha256" "errors" "fmt" + "io" "io/fs" "os" "path/filepath" @@ -21,6 +22,7 @@ var _ = fmt.Print type file_based_mmap struct { f *os.File + pos int64 region []byte unlinked bool special_name string @@ -42,6 +44,26 @@ func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, speci return &file_based_mmap{f: f, region: region, special_name: special_name}, nil } +func (self *file_based_mmap) Seek(offset int64, whence int) (ret int64, err error) { + switch whence { + case io.SeekStart: + self.pos = offset + case os.SEEK_END: + self.pos = int64(len(self.region)) + offset + case os.SEEK_CUR: + self.pos += offset + } + return self.pos, nil +} + +func (self *file_based_mmap) Read(b []byte) (n int, err error) { + return Read(self, b) +} + +func (self *file_based_mmap) Write(b []byte) (n int, err error) { + return Write(self, b) +} + func (self *file_based_mmap) Stat() (fs.FileInfo, error) { return self.f.Stat() } diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index c1718e826..14203d2bd 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -6,6 +6,7 @@ package shm import ( "errors" "fmt" + "io" "io/fs" "os" "strings" @@ -65,6 +66,7 @@ func shm_open(name string, flags, perm int) (ans *os.File, err error) { type syscall_based_mmap struct { f *os.File + pos int64 region []byte unlinked bool } @@ -117,6 +119,26 @@ func (self *syscall_based_mmap) Unlink() (err error) { return shm_unlink(self.Name()) } +func (self *syscall_based_mmap) Seek(offset int64, whence int) (ret int64, err error) { + switch whence { + case io.SeekStart: + self.pos = offset + case os.SEEK_END: + self.pos = int64(len(self.region)) + offset + case os.SEEK_CUR: + self.pos += offset + } + return self.pos, nil +} + +func (self *syscall_based_mmap) Read(b []byte) (n int, err error) { + return Read(self, b) +} + +func (self *syscall_based_mmap) Write(b []byte) (n int, err error) { + return Write(self, b) +} + func (self *syscall_based_mmap) IsFileSystemBacked() bool { return false } func (self *syscall_based_mmap) FileSystemName() string { return "" }