diff --git a/tools/cmd/ssh/askpass.go b/tools/cmd/ssh/askpass.go index edc7bbf31..304d369d7 100644 --- a/tools/cmd/ssh/askpass.go +++ b/tools/cmd/ssh/askpass.go @@ -3,7 +3,6 @@ package ssh import ( - "encoding/binary" "encoding/json" "fmt" "os" @@ -53,40 +52,27 @@ func RunSSHAskpass() { if err != nil { fatal(err) } - shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32)) + data_shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32)) if err != nil { fatal(fmt.Errorf("Failed to create SHM file with error: %w", err)) } - defer shm.Close() - defer shm.Unlink() + defer data_shm.Close() + defer data_shm.Unlink() - shm.Slice()[0] = 0 - binary.BigEndian.PutUint32(shm.Slice()[1:], uint32(len(data))) - copy(shm.Slice()[5:], data) - err = shm.Flush() + data_shm.Slice()[0] = 0 + shm.WriteWithSize(data_shm, data, 1) + err = data_shm.Flush() if err != nil { fatal(fmt.Errorf("Failed to flush SHM file with error: %w", err)) } - trigger_ask(shm.Name()) - buf := []byte{0} + trigger_ask(data_shm.Name()) for { time.Sleep(50 * time.Millisecond) - _, err = shm.Seek(0, os.SEEK_SET) - if err != nil { - fatal(fmt.Errorf("Failed to seek into SHM file while waiting for response with error: %w", err)) - } - _, err = shm.Read(buf) - if err != nil { - fatal(fmt.Errorf("Failed to read from SHM file while waiting for response with error: %w", err)) - } - if buf[0] == 1 { + if data_shm.Slice()[0] == 1 { break } } - data, err = shm.ReadWithSize() - if err != nil { - fatal(fmt.Errorf("Failed to read response data from SHM file with error: %w", err)) - } + data = shm.ReadWithSize(data_shm, 1) response := "" if is_confirm { var ok bool diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index e0ddc0010..fb48137f1 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -439,7 +439,7 @@ func bootstrap_script(cd *connection_data) (err error) { if err == nil && !cd.dont_create_shm { data_shm, err = shm.CreateTemp(fmt.Sprintf("kssh-%d-", os.Getpid()), uint64(len(encoded_data)+8)) if err == nil { - err = data_shm.WriteWithSize(encoded_data) + err = shm.WriteWithSize(data_shm, encoded_data, 0) if err == nil { err = data_shm.Flush() } diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index c0fa3f110..5f8618e29 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -53,11 +53,6 @@ type MMap interface { FileSystemName() string Stat() (fs.FileInfo, error) Flush() error - Seek(offset int64, whence int) (int64, error) - Read(b []byte) (int, error) - ReadWithSize() ([]byte, error) - Write(p []byte) (n int, err error) - WriteWithSize([]byte) error } type AccessFlags int @@ -135,14 +130,17 @@ func read_with_size(f *os.File) ([]byte, error) { return read_till_buf_full(f, make([]byte, size)) } -func write_with_size(f *os.File, b []byte) error { +func WriteWithSize(self MMap, b []byte, at int) error { szbuf := []byte{0, 0, 0, 0} binary.BigEndian.PutUint32(szbuf, uint32(len(b))) - _, err := f.Write(szbuf) - if err == nil { - _, err = f.Write(b) - } - return err + copy(self.Slice()[at:], szbuf) + copy(self.Slice()[at+4:], b) + return nil +} + +func ReadWithSize(self MMap, at int) []byte { + size := int(binary.BigEndian.Uint32(self.Slice()[at : at+4])) + return self.Slice()[at+4 : at+4+size] } func test_integration_with_python(args []string) (rc int, err error) { @@ -167,7 +165,7 @@ func test_integration_with_python(args []string) (rc int, err error) { if err != nil { return 1, err } - mmap.WriteWithSize(data) + WriteWithSize(mmap, data, 0) mmap.Close() fmt.Println(mmap.Name()) } diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index 15a7d8321..f505cf3ae 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -57,26 +57,6 @@ func (self *file_based_mmap) Flush() error { return unix.Msync(self.region, unix.MS_SYNC) } -func (self *file_based_mmap) Seek(offset int64, whence int) (int64, error) { - return self.f.Seek(offset, whence) -} - -func (self *file_based_mmap) Read(b []byte) (int, error) { - return self.f.Read(b) -} - -func (self *file_based_mmap) Write(b []byte) (int, error) { - return self.f.Write(b) -} - -func (self *file_based_mmap) WriteWithSize(b []byte) error { - return write_with_size(self.f, b) -} - -func (self *file_based_mmap) ReadWithSize() ([]byte, error) { - return read_with_size(self.f) -} - func (self *file_based_mmap) FileSystemName() string { return self.f.Name() } diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index 4e0eb8154..b19b4cdda 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -101,30 +101,6 @@ func (self *syscall_based_mmap) Slice() []byte { return self.region } -func (self *syscall_based_mmap) Seek(offset int64, whence int) (int64, error) { - return self.f.Seek(offset, whence) -} - -func (self *syscall_based_mmap) Read(b []byte) (int, error) { - return self.f.Read(b) -} - -func (self *syscall_based_mmap) Write(b []byte) (int, error) { - return self.f.Write(b) -} - -func (self *syscall_based_mmap) WriteWithSize(b []byte) error { - szbuf := []byte{0, 0, 0, 0} - binary.BigEndian.PutUint32(szbuf, uint32(len(b))) - copy(self.Slice(), szbuf) - copy(self.Slice()[4:], b) - return nil -} - -func (self *syscall_based_mmap) ReadWithSize() ([]byte, error) { - return read_with_size(self.f) -} - func (self *syscall_based_mmap) Close() (err error) { if self.region != nil { self.f.Close()