Move SSH askpass implementation into kitten

This commit is contained in:
Kovid Goyal 2023-02-22 07:15:18 +05:30
parent 6f4d89045a
commit d656017f27
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
15 changed files with 293 additions and 71 deletions

View File

@ -452,6 +452,7 @@ type VersionType struct {{
const VersionString string = "{kc.str_version}" const VersionString string = "{kc.str_version}"
const WebsiteBaseURL string = "{kc.website_base_url}" const WebsiteBaseURL string = "{kc.website_base_url}"
const VCSRevision string = "" const VCSRevision string = ""
const SSHControlMasterTemplate = "{kc.ssh_control_master_template}"
const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}" const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}"
const IsFrozenBuild bool = false const IsFrozenBuild bool = false
const IsStandaloneBuild bool = false const IsStandaloneBuild bool = false

View File

@ -67,7 +67,7 @@ class TestBuild(BaseTest):
q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
return mode & q == q return mode & q == q
for x in ('kitty', 'kitten', 'askpass.py'): for x in ('kitty', 'kitten'):
x = os.path.join(shell_integration_dir, 'ssh', x) x = os.path.join(shell_integration_dir, 'ssh', x)
self.assertTrue(is_executable(x), f'{x} is not executable') self.assertTrue(is_executable(x), f'{x} is not executable')
if getattr(sys, 'frozen', False): if getattr(sys, 'frozen', False):

View File

@ -1,5 +1,5 @@
[tool.mypy] [tool.mypy]
files = 'kitty,kittens,glfw,*.py,docs/conf.py,shell-integration/ssh/askpass.py' files = 'kitty,kittens,glfw,*.py,docs/conf.py'
no_implicit_optional = true no_implicit_optional = true
sqlite_cache = true sqlite_cache = true
cache_fine_grained = true cache_fine_grained = true

View File

@ -1459,7 +1459,7 @@ def package(args: Options, bundle_type: str) -> None:
if path.endswith('.so'): if path.endswith('.so'):
return True return True
q = path.split(os.sep)[-2:] q = path.split(os.sep)[-2:]
if len(q) == 2 and q[0] == 'ssh' and q[1] in ('askpass.py', 'kitty', 'kitten'): if len(q) == 2 and q[0] == 'ssh' and q[1] in ('kitty', 'kitten'):
return True return True
return False return False

View File

@ -1,46 +0,0 @@
#!/usr/bin/env -S kitty +launch
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import json
import os
import sys
import time
from kitty.shm import SharedMemory
msg = sys.argv[-1]
prompt = os.environ.get('SSH_ASKPASS_PROMPT', '')
is_confirm = prompt == 'confirm'
is_fingerprint_check = '(yes/no/[fingerprint])' in msg
q = {
'message': msg,
'type': 'confirm' if is_confirm else 'get_line',
'is_password': not is_fingerprint_check,
}
data = json.dumps(q)
with SharedMemory(
size=len(data) + 1 + SharedMemory.num_bytes_for_size, unlink_on_exit=True, prefix=f'askpass-{os.getpid()}-') as shm, \
open(os.ctermid(), 'wb') as tty:
shm.write(b'\0')
shm.write_data_with_size(data)
shm.flush()
with open(os.ctermid(), 'wb') as f:
f.write(f'\x1bP@kitty-ask|{shm.name}\x1b\\'.encode('ascii'))
f.flush()
while True:
# TODO: Replace sleep() with a mutex and condition variable created in the shared memory
time.sleep(0.05)
shm.seek(0)
if shm.read(1) == b'\x01':
break
response = json.loads(shm.read_data_with_size())
if is_confirm:
response = 'yes' if response else 'no'
elif is_fingerprint_check:
if response.lower() in ('y', 'yes'):
response = 'yes'
if response.lower() in ('n', 'no'):
response = 'no'
if response:
print(response, flush=True)

View File

@ -3,12 +3,22 @@
package main package main
import ( import (
"os"
"kitty/tools/cli" "kitty/tools/cli"
"kitty/tools/cmd/completion" "kitty/tools/cmd/completion"
"kitty/tools/cmd/ssh"
"kitty/tools/cmd/tool" "kitty/tools/cmd/tool"
) )
func main() { func main() {
krm := os.Getenv("KITTY_KITTEN_RUN_MODULE")
os.Unsetenv("KITTY_KITTEN_RUN_MODULE")
switch krm {
case "ssh_askpass":
ssh.RunSSHAskpass()
return
}
root := cli.NewRootCommand() root := cli.NewRootCommand()
root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)" root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)"
root.Usage = "command [command options] [command args]" root.Usage = "command [command options] [command args]"

118
tools/cmd/ssh/askpass.go Normal file
View File

@ -0,0 +1,118 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"encoding/binary"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"kitty/tools/cli"
"kitty/tools/tty"
"kitty/tools/utils/shm"
)
var _ = fmt.Print
func fatal(err error) {
cli.ShowError(err)
os.Exit(1)
}
func trigger_ask(name string) {
term, err := tty.OpenControllingTerm()
if err != nil {
fatal(err)
}
defer term.Close()
_, err = term.WriteString("\x1bP@kitty-ask|" + name + "\x1b\\")
if err != nil {
fatal(err)
}
}
func RunSSHAskpass() {
msg := os.Args[len(os.Args)-1]
prompt := os.Getenv("SSH_ASKPASS_PROMPT")
is_confirm := prompt == "confirm"
q_type := "get_line"
if is_confirm {
q_type = "confirm"
}
is_fingerprint_check := strings.Contains(msg, "(yes/no/[fingerprint])")
q := map[string]any{
"message": msg,
"type": q_type,
"is_password": !is_fingerprint_check,
}
data, err := json.Marshal(q)
if err != nil {
fatal(err)
}
shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32))
if err != nil {
fatal(fmt.Errorf("Failed to create SHM file with error: %w", err))
}
defer shm.Close()
defer shm.Unlink()
shm.Slice()[0] = 0
binary.BigEndian.PutUint32(shm.Slice()[1:], uint32(len(data)))
copy(shm.Slice()[5:], data)
err = shm.Flush()
if err != nil {
fatal(fmt.Errorf("Failed to flush SHM file with error: %w", err))
}
trigger_ask(shm.Name())
buf := []byte{0}
for {
time.Sleep(50 * time.Millisecond)
_, err = shm.Seek(0, os.SEEK_SET)
if err != nil {
fatal(fmt.Errorf("Failed to seek into SHM file while waiting for response with error: %w", err))
}
_, err = shm.Read(buf)
if err != nil {
fatal(fmt.Errorf("Failed to read from SHM file while waiting for response with error: %w", err))
}
if buf[0] == 1 {
break
}
}
data, err = shm.ReadWithSize()
if err != nil {
fatal(fmt.Errorf("Failed to read response data from SHM file with error: %w", err))
}
response := ""
if is_confirm {
var ok bool
err = json.Unmarshal(data, &ok)
if err != nil {
fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err))
}
response = "no"
if ok {
response = "yes"
}
} else {
err = json.Unmarshal(data, &response)
if err != nil {
fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err))
}
if is_fingerprint_check {
response = strings.ToLower(response)
if response == "y" {
response = "yes"
} else if response == "n" {
response = "no"
}
}
}
if response != "" {
fmt.Println(response)
}
}

View File

@ -7,16 +7,22 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"kitty"
"net/url" "net/url"
"os" "os"
"os/exec"
"os/user" "os/user"
"path/filepath"
"strconv"
"strings" "strings"
"kitty/tools/cli" "kitty/tools/cli"
"kitty/tools/tty" "kitty/tools/tty"
"kitty/tools/utils"
"kitty/tools/utils/shm" "kitty/tools/utils/shm"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -102,8 +108,58 @@ func parse_kitten_args(found_extra_args []string, username, hostname_for_match s
return return
} }
func connection_sharing_args(kitty_pid int) ([]string, error) {
rd := utils.RuntimeDir()
// Bloody OpenSSH generates a 40 char hash and in creating the socket
// appends a 27 char temp suffix to it. Socket max path length is approx
// ~104 chars. And on idiotic Apple the path length to the runtime dir
// (technically the cache dir since Apple has no runtime dir and thinks it's
// a great idea to delete files in /tmp) is ~48 chars.
if len(rd) > 35 {
idiotic_design := fmt.Sprintf("/tmp/kssh-rdir-%d", os.Geteuid())
if err := utils.AtomicCreateSymlink(rd, idiotic_design); err != nil {
return nil, err
}
rd = idiotic_design
}
cp := strings.Replace(kitty.SSHControlMasterTemplate, "{kitty_pid}", strconv.Itoa(kitty_pid), 1)
cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1)
return []string{
"-o", "ControlMaster=auto",
"-o", "ControlPath=" + cp,
"-o", "ControlPersist=yes",
"-o", "ServerAliveInterval=60",
"-o", "ServerAliveCountMax=5",
"-o", "TCPKeepAlive=no",
}, nil
}
func set_askpass() (need_to_request_data bool) {
need_to_request_data = true
sentinel := filepath.Join(utils.CacheDir(), "openssh-is-new-enough-for-askpass")
_, err := os.Stat(sentinel)
sentinel_exists := err == nil
if sentinel_exists || GetSSHVersion().SupportsAskpassRequire() {
if !sentinel_exists {
os.WriteFile(sentinel, []byte{0}, 0o644)
}
need_to_request_data = false
}
exe, err := os.Executable()
if err == nil {
os.Setenv("SSH_ASKPASS", exe)
os.Setenv("KITTY_KITTEN_RUN_MODULE", "ssh_askpass")
if !need_to_request_data {
os.Setenv("SSH_ASKPASS_REQUIRE", "force")
}
} else {
need_to_request_data = true
}
return
}
func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) {
cmd := append([]string{ssh_exe()}, ssh_args...) cmd := append([]string{SSHExe()}, ssh_args...)
hostname, remote_args := server_args[0], server_args[1:] hostname, remote_args := server_args[0], server_args[1:]
if len(remote_args) == 0 { if len(remote_args) == 0 {
cmd = append(cmd, "-t") cmd = append(cmd, "-t")
@ -115,10 +171,35 @@ func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err erro
if err != nil { if err != nil {
return 1, err return 1, err
} }
if insertion_point > 0 && overrides != nil && literal_env != nil { host_opts, err := load_config(hostname_for_match, uname, overrides)
if err != nil {
return 1, err
} }
// TODO: Implement me if host_opts.Share_connections {
return kpid, err := strconv.Atoi(os.Getenv("KITTY_PID"))
if err != nil {
return 1, fmt.Errorf("Invalid KITTY_PID env var not an integer: %#v", os.Getenv("KITTY_PID"))
}
cpargs, err := connection_sharing_args(kpid)
if err != nil {
return 1, err
}
cmd = slices.Insert(cmd, insertion_point, cpargs...)
}
use_kitty_askpass := host_opts.Askpass == Askpass_native || (host_opts.Askpass == Askpass_unless_set && os.Getenv("SSH_ASKPASS") == "")
need_to_request_data := true
if use_kitty_askpass {
need_to_request_data = set_askpass()
}
if need_to_request_data && host_opts.Share_connections {
check_cmd := slices.Insert(cmd, 1, "-O", "check")
err = exec.Command(check_cmd[0], check_cmd[1:]...).Run()
if err == nil {
need_to_request_data = false
}
}
_ = literal_env
return 0, nil
} }
func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
@ -139,7 +220,7 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
if invargs.Msg != "" { if invargs.Msg != "" {
fmt.Fprintln(os.Stderr, invargs.Msg) fmt.Fprintln(os.Stderr, invargs.Msg)
} }
return 1, unix.Exec(ssh_exe(), []string{"ssh"}, os.Environ()) return 1, unix.Exec(SSHExe(), []string{"ssh"}, os.Environ())
} }
return 1, err return 1, err
} }
@ -147,7 +228,7 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
if len(found_extra_args) > 0 { if len(found_extra_args) > 0 {
return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " ")) return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " "))
} }
return 1, unix.Exec(ssh_exe(), append([]string{"ssh"}, args...), os.Environ()) return 1, unix.Exec(SSHExe(), append([]string{"ssh"}, args...), os.Environ())
} }
if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" { if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" {
return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window") return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window")

View File

@ -7,28 +7,26 @@ import (
"io" "io"
"kitty/tools/utils" "kitty/tools/utils"
"os/exec" "os/exec"
"regexp"
"strconv"
"strings" "strings"
"sync"
) )
var _ = fmt.Print var _ = fmt.Print
var ssh_options map[string]string var SSHExe = (&utils.Once[string]{Run: func() string {
var query_ssh_for_options_once sync.Once
func ssh_exe() string {
ans := utils.Which("ssh") ans := utils.Which("ssh")
if ans != "" { if ans != "" {
return ans return ans
} }
ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin") ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin")
if ans == "" { if ans == "" {
ans = "ssh" ans = "ssh"
} }
return ans return ans
} }}).Get
func get_ssh_options() { var SSHOptions = (&utils.Once[map[string]string]{Run: func() (ssh_options map[string]string) {
defer func() { defer func() {
if ssh_options == nil { if ssh_options == nil {
ssh_options = map[string]string{ ssh_options = map[string]string{
@ -42,7 +40,7 @@ func get_ssh_options() {
} }
} }
}() }()
cmd := exec.Command(ssh_exe()) cmd := exec.Command(SSHExe())
stderr, err := cmd.StderrPipe() stderr, err := cmd.StderrPipe()
if err != nil { if err != nil {
return return
@ -86,12 +84,8 @@ func get_ssh_options() {
} }
} }
} }
} return
}}).Get
func SSHOptions() map[string]string {
query_ssh_for_options_once.Do(get_ssh_options)
return ssh_options
}
func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) { func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) {
other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32) other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32)
@ -205,3 +199,23 @@ func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, serve
} }
return return
} }
type SSHVersion struct{ Major, Minor int }
func (self SSHVersion) SupportsAskpassRequire() bool {
return self.Major > 8 || (self.Major == 8 && self.Minor >= 4)
}
var GetSSHVersion = (&utils.Once[SSHVersion]{Run: func() SSHVersion {
b, err := exec.Command(SSHExe(), "-V").CombinedOutput()
if err != nil {
return SSHVersion{}
}
m := regexp.MustCompile(`OpenSSH_(\d+).(\d+)`).FindSubmatch(b)
if len(m) == 3 {
maj, _ := strconv.Atoi(utils.UnsafeBytesToString(m[1]))
min, _ := strconv.Atoi(utils.UnsafeBytesToString(m[2]))
return SSHVersion{Major: maj, Minor: min}
}
return SSHVersion{}
}}).Get

View File

@ -50,5 +50,4 @@ func TestParseSSHArgs(t *testing.T) {
p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false) p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false)
p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false) p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false)
p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true) p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true)
} }

View File

@ -20,6 +20,9 @@ func AtomicCreateSymlink(oldname, newname string) (err error) {
if !errors.Is(err, fs.ErrExist) { if !errors.Is(err, fs.ErrExist) {
return err return err
} }
if et, err := os.Readlink(newname); err == nil && et == oldname {
return nil
}
for { for {
tempname := newname + RandomFilename() tempname := newname + RandomFilename()
err = os.Symlink(oldname, tempname) err = os.Symlink(oldname, tempname)

View File

@ -52,6 +52,10 @@ type MMap interface {
IsFileSystemBacked() bool IsFileSystemBacked() bool
FileSystemName() string FileSystemName() string
Stat() (fs.FileInfo, error) Stat() (fs.FileInfo, error)
Flush() error
Seek(offset int64, whence int) (int64, error)
Read(b []byte) (int, error)
ReadWithSize() ([]byte, error)
} }
type AccessFlags int type AccessFlags int

View File

@ -13,6 +13,8 @@ import (
"runtime" "runtime"
"kitty/tools/utils" "kitty/tools/utils"
"golang.org/x/sys/unix"
) )
var _ = fmt.Print var _ = fmt.Print
@ -51,6 +53,22 @@ func (self *file_based_mmap) Name() string {
return filepath.Base(self.f.Name()) return filepath.Base(self.f.Name())
} }
func (self *file_based_mmap) Flush() error {
return unix.Msync(self.region, unix.MS_SYNC)
}
func (self *file_based_mmap) Seek(offset int64, whence int) (int64, error) {
return self.f.Seek(offset, whence)
}
func (self *file_based_mmap) Read(b []byte) (int, error) {
return self.f.Read(b)
}
func (self *file_based_mmap) ReadWithSize() ([]byte, error) {
return read_with_size(self.f)
}
func (self *file_based_mmap) FileSystemName() string { func (self *file_based_mmap) FileSystemName() string {
return self.f.Name() return self.f.Name()
} }

View File

@ -92,10 +92,26 @@ func (self *syscall_based_mmap) Stat() (fs.FileInfo, error) {
return self.f.Stat() return self.f.Stat()
} }
func (self *syscall_based_mmap) Flush() error {
return unix.Msync(self.region, unix.MS_SYNC)
}
func (self *syscall_based_mmap) Slice() []byte { func (self *syscall_based_mmap) Slice() []byte {
return self.region return self.region
} }
func (self *syscall_based_mmap) Seek(offset int64, whence int) (int64, error) {
return self.f.Seek(offset, whence)
}
func (self *syscall_based_mmap) Read(b []byte) (int, error) {
return self.f.Read(b)
}
func (self *syscall_based_mmap) ReadWithSize() ([]byte, error) {
return read_with_size(self.f)
}
func (self *syscall_based_mmap) Close() (err error) { func (self *syscall_based_mmap) Close() (err error) {
if self.region != nil { if self.region != nil {
self.f.Close() self.f.Close()

View File

@ -23,6 +23,10 @@ func TestSHM(t *testing.T) {
} }
copy(mm.Slice(), data) copy(mm.Slice(), data)
err = mm.Flush()
if err != nil {
t.Fatalf("Failed to msync() with error: %v", err)
}
err = mm.Close() err = mm.Close()
if err != nil { if err != nil {
t.Fatalf("Failed to close with error: %v", err) t.Fatalf("Failed to close with error: %v", err)