Get syscall based SHM working

This commit is contained in:
Kovid Goyal 2022-12-20 20:29:46 +05:30
parent d01d5297b8
commit 7e161ea94b
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 63 additions and 52 deletions

View File

@ -51,34 +51,27 @@ type MMap interface {
Unlink() error
Slice() []byte
Name() string
IsFilesystemBacked() bool
}
type ProtectionFlags int
type AccessFlags int
const (
READ ProtectionFlags = iota
READ AccessFlags = iota
WRITE
COPY
RDWR
EXEC
ANON
)
func mmap(sz int, inprot ProtectionFlags, anonymous bool, fd int, off int64) ([]byte, error) {
func mmap(sz int, access AccessFlags, fd int, off int64) ([]byte, error) {
flags := unix.MAP_SHARED
prot := unix.PROT_READ
switch {
case inprot&COPY != 0:
switch access {
case COPY:
prot |= unix.PROT_WRITE
flags = unix.MAP_PRIVATE
case inprot&RDWR != 0:
case WRITE:
prot |= unix.PROT_WRITE
}
if inprot&EXEC != 0 {
prot |= unix.PROT_EXEC
}
if anonymous {
flags |= unix.MAP_ANON
}
b, err := unix.Mmap(fd, off, sz, prot, flags)
if err != nil {
@ -87,20 +80,24 @@ func mmap(sz int, inprot ProtectionFlags, anonymous bool, fd int, off int64) ([]
return b, nil
}
func munmap(s []byte) error {
return unix.Munmap(s)
}
type file_based_mmap struct {
f *os.File
region []byte
unlinked bool
}
func file_mmap(f *os.File, size uint64, access ProtectionFlags, truncate bool) (MMap, error) {
func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (MMap, error) {
if truncate {
err := truncate_or_unlink(f, size)
if err != nil {
return nil, err
}
}
region, err := mmap(int(size), access, false, int(f.Fd()), 0)
region, err := mmap(int(size), access, int(f.Fd()), 0)
if err != nil {
f.Close()
os.Remove(f.Name())
@ -117,9 +114,12 @@ func (self *file_based_mmap) Slice() []byte {
return self.region
}
func (self *file_based_mmap) Close() error {
err := self.f.Close()
self.region = nil
func (self *file_based_mmap) Close() (err error) {
if self.region != nil {
self.f.Close()
err = munmap(self.region)
self.region = nil
}
return err
}
@ -131,6 +131,8 @@ func (self *file_based_mmap) Unlink() (err error) {
return os.Remove(self.f.Name())
}
func (self *file_based_mmap) IsFilesystemBacked() bool { return true }
func CreateTemp(pattern string, size uint64) (MMap, error) {
return create_temp(pattern, size)
}

View File

@ -16,10 +16,10 @@ func create_temp(pattern string, size uint64) (MMap, error) {
if err != nil {
return nil, err
}
return file_mmap(ans, size, RDWR, true)
return file_mmap(ans, size, WRITE, true)
}
func Open(name string) (MMap, error) {
func Open(name string, size uint64) (MMap, error) {
if !filepath.IsAbs(name) {
name = filepath.Join(SHM_DIR, name)
}
@ -27,10 +27,5 @@ func Open(name string) (MMap, error) {
if err != nil {
return nil, err
}
s, err := os.Stat(name)
if err != nil {
ans.Close()
return nil, err
}
return file_mmap(ans, uint64(s.Size()), READ, false)
return file_mmap(ans, size, READ, false)
}

View File

@ -31,7 +31,7 @@ func BytePtrFromString(s string) *byte {
func shm_unlink(name string) (err error) {
bname := BytePtrFromString(name)
for {
_, _, errno := unix.Syscall(unix.SYS_SHM_OPEN, uintptr(unsafe.Pointer(bname)), 0, 0)
_, _, errno := unix.Syscall(unix.SYS_SHM_UNLINK, uintptr(unsafe.Pointer(bname)), 0, 0)
if errno != unix.EINTR {
if errno != 0 {
err = fmt.Errorf("shm_unlink() failed with error: %w", errno)
@ -67,14 +67,14 @@ type syscall_based_mmap struct {
unlinked bool
}
func syscall_mmap(f *os.File, size uint64, access ProtectionFlags, truncate bool) (MMap, error) {
func syscall_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (MMap, error) {
if truncate {
err := truncate_or_unlink(f, size)
if err != nil {
return nil, fmt.Errorf("truncate failed with error: %w", err)
}
}
region, err := mmap(int(size), access, false, int(f.Fd()), 0)
region, err := mmap(int(size), access, int(f.Fd()), 0)
if err != nil {
f.Close()
os.Remove(f.Name())
@ -91,10 +91,13 @@ func (self *syscall_based_mmap) Slice() []byte {
return self.region
}
func (self *syscall_based_mmap) Close() error {
err := self.f.Close()
self.region = nil
return err
func (self *syscall_based_mmap) Close() (err error) {
if self.region != nil {
self.f.Close()
munmap(self.region)
self.region = nil
}
return
}
func (self *syscall_based_mmap) Unlink() (err error) {
@ -105,6 +108,8 @@ func (self *syscall_based_mmap) Unlink() (err error) {
return shm_unlink(self.Name())
}
func (self *syscall_based_mmap) IsFilesystemBacked() bool { return false }
func create_temp(pattern string, size uint64) (ans MMap, err error) {
var prefix, suffix string
prefix, suffix, err = prefix_and_suffix(pattern)
@ -126,7 +131,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
if err != nil && (errors.Is(err, fs.ErrExist) || errors.Unwrap(err) == unix.EEXIST) {
try += 1
if try > 10000 {
return nil, &os.PathError{Op: "createtemp", Path: prefix + "*" + suffix, Err: ErrExist}
return nil, &os.PathError{Op: "createtemp", Path: prefix + "*" + suffix, Err: fs.ErrExist}
}
continue
}
@ -135,18 +140,13 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
if err != nil {
return nil, err
}
return syscall_mmap(f, size, RDWR, true)
return syscall_mmap(f, size, WRITE, true)
}
func Open(name string) (MMap, error) {
func Open(name string, size uint64) (MMap, error) {
ans, err := shm_open(name, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
s, err := os.Stat(name)
if err != nil {
ans.Close()
return nil, err
}
return syscall_mmap(ans, uint64(s.Size()), READ, false)
return syscall_mmap(ans, size, READ, false)
}

View File

@ -23,9 +23,12 @@ func TestSHM(t *testing.T) {
}
copy(mm.Slice(), data)
mm.Close()
err = mm.Close()
if err != nil {
t.Fatalf("Failed to close with error: %v", err)
}
g, err := Open(mm.Name())
g, err := Open(mm.Name(), uint64(len(data)))
if err != nil {
t.Fatal(err)
}
@ -33,11 +36,22 @@ func TestSHM(t *testing.T) {
if !reflect.DeepEqual(data, data2) {
t.Fatalf("Could not read back written data: Written data length: %d Read data length: %d", len(data), len(data2))
}
g.Close()
g.Unlink()
_, err = os.Stat(mm.Name())
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("Unlinking %s did not work", mm.Name())
err = g.Close()
if err != nil {
t.Fatalf("Failed to close with error: %v", err)
}
err = g.Unlink()
if err != nil {
t.Fatalf("Failed to unlink with error: %v", err)
}
g, err = Open(mm.Name(), uint64(len(data)))
if err == nil {
t.Fatalf("Unlinking failed could re-open the SHM data. Data equal: %v Data length: %d", reflect.DeepEqual(g.Slice(), data), len(g.Slice()))
}
if mm.IsFilesystemBacked() {
_, err = os.Stat(mm.Name())
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("Unlinking %s did not work", mm.Name())
}
}
}