DRYer
This commit is contained in:
parent
1b2fe90ed1
commit
944e036611
@ -72,7 +72,10 @@ func RunSSHAskpass() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data = shm.ReadWithSize(data_shm, 1)
|
data, err = shm.ReadWithSize(data_shm, 1)
|
||||||
|
if err != nil {
|
||||||
|
fatal(fmt.Errorf("Failed to read from SHM file with error: %w", err))
|
||||||
|
}
|
||||||
response := ""
|
response := ""
|
||||||
if is_confirm {
|
if is_confirm {
|
||||||
var ok bool
|
var ok bool
|
||||||
|
|||||||
@ -65,11 +65,7 @@ func get_destination(hostname string) (username, hostname_for_match string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func read_data_from_shared_memory(shm_name string) ([]byte, error) {
|
func read_data_from_shared_memory(shm_name string) ([]byte, error) {
|
||||||
data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error {
|
data, err := shm.ReadWithSizeAndUnlink(shm_name, func(s fs.FileInfo) error {
|
||||||
s, err := f.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Failed to stat SHM file with error: %w", err)
|
|
||||||
}
|
|
||||||
if stat, ok := s.Sys().(unix.Stat_t); ok {
|
if stat, ok := s.Sys().(unix.Stat_t); ok {
|
||||||
if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) {
|
if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) {
|
||||||
return fmt.Errorf("Incorrect owner on SHM file")
|
return fmt.Errorf("Incorrect owner on SHM file")
|
||||||
|
|||||||
@ -104,43 +104,58 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) {
|
const NUM_BYTES_FOR_SIZE = 4
|
||||||
p := buf
|
|
||||||
for len(p) > 0 {
|
|
||||||
n, err := f.Read(p)
|
|
||||||
p = p[n:]
|
|
||||||
if err != nil {
|
|
||||||
if len(p) == 0 && errors.Is(err, io.EOF) {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
err = fmt.Errorf("Failed to read from SHM file with error: %w", err)
|
|
||||||
return buf[:len(buf)-len(p)], err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func read_with_size(f *os.File) ([]byte, error) {
|
|
||||||
szbuf := []byte{0, 0, 0, 0}
|
|
||||||
szbuf, err := read_till_buf_full(f, szbuf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
size := int(binary.BigEndian.Uint32(szbuf))
|
|
||||||
return read_till_buf_full(f, make([]byte, size))
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteWithSize(self MMap, b []byte, at int) error {
|
func WriteWithSize(self MMap, b []byte, at int) error {
|
||||||
szbuf := []byte{0, 0, 0, 0}
|
if len(self.Slice()) < at+len(b)+NUM_BYTES_FOR_SIZE {
|
||||||
binary.BigEndian.PutUint32(szbuf, uint32(len(b)))
|
return io.ErrShortBuffer
|
||||||
copy(self.Slice()[at:], szbuf)
|
}
|
||||||
copy(self.Slice()[at+4:], b)
|
binary.BigEndian.PutUint32(self.Slice()[at:], uint32(len(b)))
|
||||||
|
copy(self.Slice()[at+NUM_BYTES_FOR_SIZE:], b)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadWithSize(self MMap, at int) []byte {
|
func ReadWithSize(self MMap, at int) ([]byte, error) {
|
||||||
size := int(binary.BigEndian.Uint32(self.Slice()[at : at+4]))
|
s := self.Slice()[at:]
|
||||||
return self.Slice()[at+4 : at+4+size]
|
if len(s) < NUM_BYTES_FOR_SIZE {
|
||||||
|
return nil, io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
size := int(binary.BigEndian.Uint32(self.Slice()[at : at+NUM_BYTES_FOR_SIZE]))
|
||||||
|
s = s[NUM_BYTES_FOR_SIZE:]
|
||||||
|
if len(s) < size {
|
||||||
|
return nil, io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
return s[:size], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadWithSizeAndUnlink(name string, file_callback ...func(fs.FileInfo) error) ([]byte, error) {
|
||||||
|
mmap, err := Open(name, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(file_callback) > 0 {
|
||||||
|
s, err := mmap.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
|
||||||
|
}
|
||||||
|
for _, f := range file_callback {
|
||||||
|
err = f(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
mmap.Close()
|
||||||
|
mmap.Unlink()
|
||||||
|
}()
|
||||||
|
slice, err := ReadWithSize(mmap, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ans := make([]byte, len(slice))
|
||||||
|
copy(ans, slice)
|
||||||
|
return ans, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func test_integration_with_python(args []string) (rc int, err error) {
|
func test_integration_with_python(args []string) (rc int, err error) {
|
||||||
@ -161,7 +176,7 @@ func test_integration_with_python(args []string) (rc int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 1, err
|
return 1, err
|
||||||
}
|
}
|
||||||
mmap, err := CreateTemp("shmtest-", uint64(len(data)+4))
|
mmap, err := CreateTemp("shmtest-", uint64(len(data)+NUM_BYTES_FOR_SIZE))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 1, err
|
return 1, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -142,21 +142,13 @@ func Open(name string, size uint64) (MMap, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if size == 0 {
|
||||||
|
s, err := ans.Stat()
|
||||||
|
if err != nil {
|
||||||
|
ans.Close()
|
||||||
|
return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
|
||||||
|
}
|
||||||
|
size = uint64(s.Size())
|
||||||
|
}
|
||||||
return file_mmap(ans, size, READ, false, name)
|
return file_mmap(ans, size, READ, false, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
|
|
||||||
f, err := open(name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
defer os.Remove(f.Name())
|
|
||||||
for _, cb := range file_callback {
|
|
||||||
err = cb(f)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return read_with_size(f)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@
|
|||||||
package shm
|
package shm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
@ -159,23 +158,13 @@ func Open(name string, size uint64) (MMap, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if size == 0 {
|
||||||
|
s, err := ans.Stat()
|
||||||
|
if err != nil {
|
||||||
|
ans.Close()
|
||||||
|
return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
|
||||||
|
}
|
||||||
|
size = uint64(s.Size())
|
||||||
|
}
|
||||||
return syscall_mmap(ans, size, READ, false)
|
return syscall_mmap(ans, size, READ, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
|
|
||||||
mmap, err := Open(name, 4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
size := uint64(binary.BigEndian.Uint32(mmap.Slice()))
|
|
||||||
mmap.Close()
|
|
||||||
mmap, err = Open(name, 4+size)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ans := make([]byte, size)
|
|
||||||
copy(ans, mmap.Slice()[4:])
|
|
||||||
mmap.Close()
|
|
||||||
mmap.Unlink()
|
|
||||||
return ans, nil
|
|
||||||
}
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user