From 7e161ea94ba86757f9f76cfe006760561e718971 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 20 Dec 2022 20:29:46 +0530 Subject: [PATCH] Get syscall based SHM working --- tools/utils/shm/shm.go | 42 ++++++++++++++++++---------------- tools/utils/shm/shm_fs.go | 11 +++------ tools/utils/shm/shm_syscall.go | 32 +++++++++++++------------- tools/utils/shm/shm_test.go | 30 +++++++++++++++++------- 4 files changed, 63 insertions(+), 52 deletions(-) diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index 3aa423112..f7c287f58 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -51,34 +51,27 @@ type MMap interface { Unlink() error Slice() []byte Name() string + IsFilesystemBacked() bool } -type ProtectionFlags int +type AccessFlags int const ( - READ ProtectionFlags = iota + READ AccessFlags = iota + WRITE COPY - RDWR - EXEC - ANON ) -func mmap(sz int, inprot ProtectionFlags, anonymous bool, fd int, off int64) ([]byte, error) { +func mmap(sz int, access AccessFlags, fd int, off int64) ([]byte, error) { flags := unix.MAP_SHARED prot := unix.PROT_READ - switch { - case inprot© != 0: + switch access { + case COPY: prot |= unix.PROT_WRITE flags = unix.MAP_PRIVATE - case inprot&RDWR != 0: + case WRITE: prot |= unix.PROT_WRITE } - if inprot&EXEC != 0 { - prot |= unix.PROT_EXEC - } - if anonymous { - flags |= unix.MAP_ANON - } b, err := unix.Mmap(fd, off, sz, prot, flags) if err != nil { @@ -87,20 +80,24 @@ func mmap(sz int, inprot ProtectionFlags, anonymous bool, fd int, off int64) ([] return b, nil } +func munmap(s []byte) error { + return unix.Munmap(s) +} + type file_based_mmap struct { f *os.File region []byte unlinked bool } -func file_mmap(f *os.File, size uint64, access ProtectionFlags, truncate bool) (MMap, error) { +func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (MMap, error) { if truncate { err := truncate_or_unlink(f, size) if err != nil { return nil, err } } - region, err := mmap(int(size), access, false, int(f.Fd()), 0) + region, err := mmap(int(size), access, int(f.Fd()), 0) if err != nil { f.Close() os.Remove(f.Name()) @@ -117,9 +114,12 @@ func (self *file_based_mmap) Slice() []byte { return self.region } -func (self *file_based_mmap) Close() error { - err := self.f.Close() - self.region = nil +func (self *file_based_mmap) Close() (err error) { + if self.region != nil { + self.f.Close() + err = munmap(self.region) + self.region = nil + } return err } @@ -131,6 +131,8 @@ func (self *file_based_mmap) Unlink() (err error) { return os.Remove(self.f.Name()) } +func (self *file_based_mmap) IsFilesystemBacked() bool { return true } + func CreateTemp(pattern string, size uint64) (MMap, error) { return create_temp(pattern, size) } diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index 6c3290344..dbadff30c 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -16,10 +16,10 @@ func create_temp(pattern string, size uint64) (MMap, error) { if err != nil { return nil, err } - return file_mmap(ans, size, RDWR, true) + return file_mmap(ans, size, WRITE, true) } -func Open(name string) (MMap, error) { +func Open(name string, size uint64) (MMap, error) { if !filepath.IsAbs(name) { name = filepath.Join(SHM_DIR, name) } @@ -27,10 +27,5 @@ func Open(name string) (MMap, error) { if err != nil { return nil, err } - s, err := os.Stat(name) - if err != nil { - ans.Close() - return nil, err - } - return file_mmap(ans, uint64(s.Size()), READ, false) + return file_mmap(ans, size, READ, false) } diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index c7493937d..1dc5c6196 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -31,7 +31,7 @@ func BytePtrFromString(s string) *byte { func shm_unlink(name string) (err error) { bname := BytePtrFromString(name) for { - _, _, errno := unix.Syscall(unix.SYS_SHM_OPEN, uintptr(unsafe.Pointer(bname)), 0, 0) + _, _, errno := unix.Syscall(unix.SYS_SHM_UNLINK, uintptr(unsafe.Pointer(bname)), 0, 0) if errno != unix.EINTR { if errno != 0 { err = fmt.Errorf("shm_unlink() failed with error: %w", errno) @@ -67,14 +67,14 @@ type syscall_based_mmap struct { unlinked bool } -func syscall_mmap(f *os.File, size uint64, access ProtectionFlags, truncate bool) (MMap, error) { +func syscall_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (MMap, error) { if truncate { err := truncate_or_unlink(f, size) if err != nil { return nil, fmt.Errorf("truncate failed with error: %w", err) } } - region, err := mmap(int(size), access, false, int(f.Fd()), 0) + region, err := mmap(int(size), access, int(f.Fd()), 0) if err != nil { f.Close() os.Remove(f.Name()) @@ -91,10 +91,13 @@ func (self *syscall_based_mmap) Slice() []byte { return self.region } -func (self *syscall_based_mmap) Close() error { - err := self.f.Close() - self.region = nil - return err +func (self *syscall_based_mmap) Close() (err error) { + if self.region != nil { + self.f.Close() + munmap(self.region) + self.region = nil + } + return } func (self *syscall_based_mmap) Unlink() (err error) { @@ -105,6 +108,8 @@ func (self *syscall_based_mmap) Unlink() (err error) { return shm_unlink(self.Name()) } +func (self *syscall_based_mmap) IsFilesystemBacked() bool { return false } + func create_temp(pattern string, size uint64) (ans MMap, err error) { var prefix, suffix string prefix, suffix, err = prefix_and_suffix(pattern) @@ -126,7 +131,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) { if err != nil && (errors.Is(err, fs.ErrExist) || errors.Unwrap(err) == unix.EEXIST) { try += 1 if try > 10000 { - return nil, &os.PathError{Op: "createtemp", Path: prefix + "*" + suffix, Err: ErrExist} + return nil, &os.PathError{Op: "createtemp", Path: prefix + "*" + suffix, Err: fs.ErrExist} } continue } @@ -135,18 +140,13 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) { if err != nil { return nil, err } - return syscall_mmap(f, size, RDWR, true) + return syscall_mmap(f, size, WRITE, true) } -func Open(name string) (MMap, error) { +func Open(name string, size uint64) (MMap, error) { ans, err := shm_open(name, os.O_RDONLY, 0) if err != nil { return nil, err } - s, err := os.Stat(name) - if err != nil { - ans.Close() - return nil, err - } - return syscall_mmap(ans, uint64(s.Size()), READ, false) + return syscall_mmap(ans, size, READ, false) } diff --git a/tools/utils/shm/shm_test.go b/tools/utils/shm/shm_test.go index af9837112..0574e7ba1 100644 --- a/tools/utils/shm/shm_test.go +++ b/tools/utils/shm/shm_test.go @@ -23,9 +23,12 @@ func TestSHM(t *testing.T) { } copy(mm.Slice(), data) - mm.Close() + err = mm.Close() + if err != nil { + t.Fatalf("Failed to close with error: %v", err) + } - g, err := Open(mm.Name()) + g, err := Open(mm.Name(), uint64(len(data))) if err != nil { t.Fatal(err) } @@ -33,11 +36,22 @@ func TestSHM(t *testing.T) { if !reflect.DeepEqual(data, data2) { t.Fatalf("Could not read back written data: Written data length: %d Read data length: %d", len(data), len(data2)) } - g.Close() - g.Unlink() - _, err = os.Stat(mm.Name()) - if !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("Unlinking %s did not work", mm.Name()) + err = g.Close() + if err != nil { + t.Fatalf("Failed to close with error: %v", err) + } + err = g.Unlink() + if err != nil { + t.Fatalf("Failed to unlink with error: %v", err) + } + g, err = Open(mm.Name(), uint64(len(data))) + if err == nil { + t.Fatalf("Unlinking failed could re-open the SHM data. Data equal: %v Data length: %d", reflect.DeepEqual(g.Slice(), data), len(g.Slice())) + } + if mm.IsFilesystemBacked() { + _, err = os.Stat(mm.Name()) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("Unlinking %s did not work", mm.Name()) + } } - }