Merge branch 'ssh'

This commit is contained in:
Kovid Goyal 2023-02-28 12:45:51 +05:30
commit 9135ba138e
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
69 changed files with 4521 additions and 1567 deletions

View File

@ -62,6 +62,9 @@ Detailed list of changes
- macOS: Fix the maximized window not taking up full space when the title bar is hidden or when :opt:`resize_in_steps` is configured (:iss:`6021`)
- ssh kitten: Change the syntax of glob patterns slightly to match common usage
elsewhere. Now the syntax is the same a "extendedglob" in most shells.
0.27.1 [2023-02-07]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -226,15 +226,16 @@ def commit_role(
# CLI docs {{{
def write_cli_docs(all_kitten_names: Iterable[str]) -> None:
from kittens.ssh.copy import option_text
from kittens.ssh.options.definition import copy_message
from kittens.ssh.main import copy_message, option_text
from kitty.cli import option_spec_as_rst
from kitty.launch import options_spec as launch_options_spec
with open('generated/ssh-copy.rst', 'w') as f:
f.write(option_spec_as_rst(
appname='copy', ospec=option_text, heading_char='^',
usage='file-or-dir-to-copy ...', message=copy_message
))
del sys.modules['kittens.ssh.main']
from kitty.launch import options_spec as launch_options_spec
with open('generated/launch.rst', 'w') as f:
f.write(option_spec_as_rst(
appname='launch', ospec=launch_options_spec, heading_char='_',
@ -525,9 +526,9 @@ def write_conf_docs(app: Any, all_kitten_names: Iterable[str]) -> None:
from kittens.runner import get_kitten_conf_docs
for kitten in all_kitten_names:
definition = get_kitten_conf_docs(kitten)
if definition:
generate_default_config(definition, f'kitten-{kitten}')
defn = get_kitten_conf_docs(kitten)
if defn is not None:
generate_default_config(defn, f'kitten-{kitten}')
from kitty.actions import as_rst
with open('generated/actions.rst', 'w', encoding='utf-8') as f:

View File

@ -51,8 +51,6 @@ def main() -> None:
from kittens.diff.options.definition import definition as kd
write_output('kittens.diff', kd)
from kittens.ssh.options.definition import definition as sd
write_output('kittens.ssh', sd)
if __name__ == '__main__':

View File

@ -1,15 +1,17 @@
#!./kitty/launcher/kitty +launch
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import bz2
import io
import json
import os
import struct
import subprocess
import sys
import zlib
import tarfile
from contextlib import contextmanager, suppress
from functools import lru_cache
from itertools import chain
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Union
import kitty.constants as kc
@ -22,6 +24,8 @@ from kitty.cli import (
parse_option_spec,
serialize_as_go_string,
)
from kitty.conf.generate import gen_go_code
from kitty.conf.types import Definition
from kitty.guess_mime_type import text_mimes
from kitty.key_encoding import config_mod_map
from kitty.key_names import character_key_name_aliases, functional_key_name_aliases
@ -38,7 +42,7 @@ def newer(dest: str, *sources: str) -> bool:
dtime = os.path.getmtime(dest)
except OSError:
return True
for s in sources:
for s in chain(sources, (__file__,)):
with suppress(FileNotFoundError):
if os.path.getmtime(s) >= dtime:
return True
@ -318,8 +322,45 @@ def wrapped_kittens() -> Sequence[str]:
raise Exception('Failed to read wrapped kittens from kitty wrapper script')
def generate_conf_parser(kitten: str, defn: Definition) -> None:
with replace_if_needed(f'tools/cmd/{kitten}/conf_generated.go'):
print(f'package {kitten}')
print(gen_go_code(defn))
def generate_extra_cli_parser(name: str, spec: str) -> None:
print('import "kitty/tools/cli"')
go_opts = tuple(go_options_for_seq(parse_option_spec(spec)[0]))
print(f'type {name}_options struct ''{')
for opt in go_opts:
print(opt.struct_declaration())
print('}')
print(f'func parse_{name}_args(args []string) (*{name}_options, []string, error) ''{')
print(f'root := cli.Command{{Name: `{name}` }}')
for opt in go_opts:
print(opt.as_option('root'))
print('cmd, err := root.ParseArgs(args)')
print('if err != nil { return nil, nil, err }')
print(f'var opts {name}_options')
print('err = cmd.GetOptionValues(&opts)')
print('if err != nil { return nil, nil, err }')
print('return &opts, cmd.Args, nil')
print('}')
def kitten_clis() -> None:
from kittens.runner import get_kitten_conf_docs, get_kitten_extra_cli_parsers
for kitten in wrapped_kittens():
defn = get_kitten_conf_docs(kitten)
if defn is not None:
generate_conf_parser(kitten, defn)
ecp = get_kitten_extra_cli_parsers(kitten)
if ecp:
for name, spec in ecp.items():
with replace_if_needed(f'tools/cmd/{kitten}/{name}_cli_generated.go'):
print(f'package {kitten}')
generate_extra_cli_parser(name, spec)
with replace_if_needed(f'tools/cmd/{kitten}/cli_generated.go'):
od = []
kcd = kitten_cli_docs(kitten)
@ -329,10 +370,11 @@ def kitten_clis() -> None:
print('func create_cmd(root *cli.Command, run_func func(*cli.Command, *Options, []string)(int, error)) {')
print('ans := root.AddSubCommand(&cli.Command{')
print(f'Name: "{kitten}",')
print(f'ShortDescription: "{serialize_as_go_string(kcd["short_desc"])}",')
if kcd['usage']:
print(f'Usage: "[options] {serialize_as_go_string(kcd["usage"])}",')
print(f'HelpText: "{serialize_as_go_string(kcd["help_text"])}",')
if kcd:
print(f'ShortDescription: "{serialize_as_go_string(kcd["short_desc"])}",')
if kcd['usage']:
print(f'Usage: "[options] {serialize_as_go_string(kcd["usage"])}",')
print(f'HelpText: "{serialize_as_go_string(kcd["help_text"])}",')
print('Run: func(cmd *cli.Command, args []string) (int, error) {')
print('opts := Options{}')
print('err := cmd.GetOptionValues(&opts)')
@ -351,6 +393,8 @@ def kitten_clis() -> None:
print("clone := root.AddClone(ans.Group, ans)")
print('clone.Hidden = false')
print(f'clone.Name = "{serialize_as_go_string(kitten.replace("_", "-"))}"')
if not kcd:
print('specialize_command(ans)')
print('}')
print('type Options struct {')
print('\n'.join(od))
@ -383,11 +427,24 @@ def generate_spinners() -> str:
def generate_color_names() -> str:
selfg = "" if Options.selection_foreground is None else Options.selection_foreground.as_sharp
selbg = "" if Options.selection_background is None else Options.selection_background.as_sharp
cursor = "" if Options.cursor is None else Options.cursor.as_sharp
return 'package style\n\nvar ColorNames = map[string]RGBA{' + '\n'.join(
f'\t"{name}": RGBA{{ Red:{val.red}, Green:{val.green}, Blue:{val.blue} }},'
for name, val in color_names.items()
) + '\n}' + '\n\nvar ColorTable = [256]uint32{' + ', '.join(
f'{x}' for x in Options.color_table) + '}\n'
f'{x}' for x in Options.color_table) + '}\n' + f'''
var DefaultColors = struct {{
Foreground, Background, Cursor, SelectionFg, SelectionBg string
}}{{
Foreground: "{Options.foreground.as_sharp}",
Background: "{Options.background.as_sharp}",
Cursor: "{cursor}",
SelectionFg: "{selfg}",
SelectionBg: "{selbg}",
}}
'''
def load_ref_map() -> Dict[str, Dict[str, str]]:
@ -399,6 +456,8 @@ def load_ref_map() -> Dict[str, Dict[str, str]]:
def generate_constants() -> str:
from kitty.options.types import Options
from kitty.options.utils import allowed_shell_integration_values
ref_map = load_ref_map()
dp = ", ".join(map(lambda x: f'"{serialize_as_go_string(x)}"', kc.default_pager_for_help))
return f'''\
@ -410,6 +469,7 @@ type VersionType struct {{
const VersionString string = "{kc.str_version}"
const WebsiteBaseURL string = "{kc.website_base_url}"
const VCSRevision string = ""
const SSHControlMasterTemplate = "{kc.ssh_control_master_template}"
const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}"
const IsFrozenBuild bool = false
const IsStandaloneBuild bool = false
@ -421,6 +481,12 @@ var CharacterKeyNameAliases = map[string]string{serialize_go_dict(character_key_
var ConfigModMap = map[string]uint16{serialize_go_dict(config_mod_map)}
var RefMap = map[string]string{serialize_go_dict(ref_map['ref'])}
var DocTitleMap = map[string]string{serialize_go_dict(ref_map['doc'])}
var AllowedShellIntegrationValues = []string{{ {str(sorted(allowed_shell_integration_values))[1:-1].replace("'", '"')} }}
var KittyConfigDefaults = struct {{
Term, Shell_integration string
}}{{
Term: "{Options.term}", Shell_integration: "{' '.join(Options.shell_integration)}",
}}
''' # }}}
@ -598,6 +664,11 @@ def generate_textual_mimetypes() -> str:
return '\n'.join(ans)
def write_compressed_data(data: bytes, d: BinaryIO) -> None:
d.write(struct.pack('<I', len(data)))
d.write(bz2.compress(data))
def generate_unicode_names(src: TextIO, dest: BinaryIO) -> None:
num_names, num_of_words = map(int, next(src).split())
gob = io.BytesIO()
@ -612,9 +683,31 @@ def generate_unicode_names(src: TextIO, dest: BinaryIO) -> None:
if aliases:
record += aliases.encode()
gob.write(struct.pack('<H', len(record)) + record)
data = gob.getvalue()
dest.write(struct.pack('<I', len(data)))
dest.write(zlib.compress(data, zlib.Z_BEST_COMPRESSION))
write_compressed_data(gob.getvalue(), dest)
def generate_ssh_kitten_data() -> None:
files = {
'terminfo/kitty.terminfo', 'terminfo/x/xterm-kitty',
}
for dirpath, dirnames, filenames in os.walk('shell-integration'):
for f in filenames:
path = os.path.join(dirpath, f)
files.add(path.replace(os.sep, '/'))
dest = 'tools/cmd/ssh/data_generated.bin'
def normalize(t: tarfile.TarInfo) -> tarfile.TarInfo:
t.uid = t.gid = 0
t.uname = t.gname = ''
return t
if newer(dest, *files):
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode='w') as tf:
for f in sorted(files):
tf.add(f, filter=normalize)
with open(dest, 'wb') as d:
write_compressed_data(buf.getvalue(), d)
def main() -> None:
@ -633,6 +726,7 @@ def main() -> None:
if newer('tools/unicode_names/data_generated.bin', 'tools/unicode_names/names.txt'):
with open('tools/unicode_names/data_generated.bin', 'wb') as dest, open('tools/unicode_names/names.txt') as src:
generate_unicode_names(src, dest)
generate_ssh_kitten_data()
update_completion()
update_at_commands()

14
go.mod
View File

@ -4,14 +4,24 @@ go 1.20
require (
github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924
github.com/bmatcuk/doublestar v1.3.4
github.com/disintegration/imaging v1.6.2
github.com/google/go-cmp v0.5.8
github.com/google/go-cmp v0.5.9
github.com/google/uuid v1.3.0
github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f
github.com/seancfoley/ipaddress-go v1.5.3
github.com/shirou/gopsutil/v3 v3.23.1
golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b
golang.org/x/image v0.5.0
golang.org/x/sys v0.4.0
)
require github.com/seancfoley/bintree v1.2.1 // indirect
require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/seancfoley/bintree v1.2.1 // indirect
github.com/tklauser/go-sysconf v0.3.11 // indirect
github.com/tklauser/numcpus v0.6.0 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
)

41
go.sum
View File

@ -1,18 +1,47 @@
github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 h1:DG4UyTVIujioxwJc8Zj8Nabz1L1wTgQ/xNBSQDfdP3I=
github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924/go.mod h1:+NaH2gLeY6RPBPPQf4aRotPPStg+eXc8f9ZaE4vRfD4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f h1:Ko4+g6K16vSyUrtd/pPXuQnWsiHe5BYptEtTxfwYwCc=
github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f/go.mod h1:eHzfhOKbTGJEGPSdMHzU6jft192tHHt2Bu2vIZArvC0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/seancfoley/bintree v1.2.1 h1:Z/iNjRKkXnn0CTW7jDQYtjW5fz2GH1yWvOTJ4MrMvdo=
github.com/seancfoley/bintree v1.2.1/go.mod h1:hIUabL8OFYyFVTQ6azeajbopogQc2l5C/hiXMcemWNU=
github.com/seancfoley/ipaddress-go v1.5.3 h1:fLnn4nsatd2rp3IJsVWriXv5gXn2Qiy8uxjxe4iZtTg=
github.com/seancfoley/ipaddress-go v1.5.3/go.mod h1:fpvVPC+Jso+YEhNcNiww8HQmBgKP8T4T6BTp1SLxxIo=
github.com/shirou/gopsutil/v3 v3.23.1 h1:a9KKO+kGLKEvcPIs4W62v0nu3sciVDOOOPUD0Hz7z/4=
github.com/shirou/gopsutil/v3 v3.23.1/go.mod h1:NN6mnm5/0k8jw4cBfCnJtr5L7ErOTg18tMNpgFkn0hA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI=
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg=
github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b h1:EqBVA+nNsObCwQoBEHy4wLU0pi7i8a4AL3pbItPdPkE=
@ -27,10 +56,13 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
@ -43,3 +75,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -7,7 +7,7 @@ import os
import sys
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, cast
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, Optional, cast
from kitty.constants import list_kitty_resources
from kitty.types import run_once
@ -171,7 +171,7 @@ def get_kitten_completer(kitten: str) -> Any:
return ans
def get_kitten_conf_docs(kitten: str) -> Definition:
def get_kitten_conf_docs(kitten: str) -> Optional[Definition]:
setattr(sys, 'options_definition', None)
run_kitten(kitten, run_name='__conf__')
ans = getattr(sys, 'options_definition')
@ -179,6 +179,14 @@ def get_kitten_conf_docs(kitten: str) -> Definition:
return cast(Definition, ans)
def get_kitten_extra_cli_parsers(kitten: str) -> Dict[str,str]:
setattr(sys, 'extra_cli_parsers', {})
run_kitten(kitten, run_name='__extra_cli_parsers__')
ans = getattr(sys, 'extra_cli_parsers')
delattr(sys, 'extra_cli_parsers')
return cast(Dict[str, str], ans)
def main() -> None:
try:
args = sys.argv[1:]

View File

@ -1,70 +0,0 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import fnmatch
import os
from typing import Any, Dict, Iterable, Optional
from kitty.conf.utils import load_config as _load_config
from kitty.conf.utils import parse_config_base, resolve_config
from kitty.constants import config_dir
from .options.types import Options as SSHOptions
from .options.types import defaults
SYSTEM_CONF = '/etc/xdg/kitty/ssh.conf'
defconf = os.path.join(config_dir, 'ssh.conf')
def host_matches(mpat: str, hostname: str, username: str) -> bool:
for pat in mpat.split():
upat = '*'
if '@' in pat:
upat, pat = pat.split('@', 1)
if fnmatch.fnmatchcase(hostname, pat) and fnmatch.fnmatchcase(username, upat):
return True
return False
def load_config(*paths: str, overrides: Optional[Iterable[str]] = None, hostname: str = '!', username: str = '') -> SSHOptions:
from .options.parse import create_result_dict, merge_result_dicts, parse_conf_item
from .options.utils import first_seen_positions, get_per_hosts_dict, init_results_dict
def merge_dicts(base: Dict[str, Any], vals: Dict[str, Any]) -> Dict[str, Any]:
base_phd = get_per_hosts_dict(base)
vals_phd = get_per_hosts_dict(vals)
for hostname in base_phd:
vals_phd[hostname] = merge_result_dicts(base_phd[hostname], vals_phd.get(hostname, {}))
ans: Dict[str, Any] = vals_phd.pop(vals['hostname'])
ans['per_host_dicts'] = vals_phd
return ans
def parse_config(lines: Iterable[str]) -> Dict[str, Any]:
ans: Dict[str, Any] = init_results_dict(create_result_dict())
parse_config_base(lines, parse_conf_item, ans)
return ans
overrides = tuple(overrides) if overrides is not None else ()
first_seen_positions.clear()
first_seen_positions['*'] = 0
opts_dict, paths = _load_config(
defaults, parse_config, merge_dicts, *paths, overrides=overrides, initialize_defaults=init_results_dict)
phd = get_per_hosts_dict(opts_dict)
final_dict: Dict[str, Any] = {}
for hostname_pat in sorted(phd, key=first_seen_positions.__getitem__):
if host_matches(hostname_pat, hostname, username):
od = phd[hostname_pat]
for k, v in od.items():
if isinstance(v, dict):
bv = final_dict.setdefault(k, {})
bv.update(v)
else:
final_dict[k] = v
first_seen_positions.clear()
return SSHOptions(final_dict)
def init_config(hostname: str, username: str, overrides: Optional[Iterable[str]] = None) -> SSHOptions:
config = tuple(resolve_config(SYSTEM_CONF, defconf))
return load_config(*config, overrides=overrides, hostname=hostname, username=username)

View File

@ -1,117 +0,0 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import glob
import os
import shlex
import uuid
from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Sequence, Tuple
from kitty.cli import parse_args
from kitty.cli_stub import CopyCLIOptions
from kitty.types import run_once
from ..transfer.utils import expand_home, home_path
@run_once
def option_text() -> str:
return '''
--glob
type=bool-set
Interpret file arguments as glob patterns.
--dest
The destination on the remote host to copy to. Relative paths are resolved
relative to HOME on the remote host. When this option is not specified, the
local file path is used as the remote destination (with the HOME directory
getting automatically replaced by the remote HOME). Note that environment
variables and ~ are not expanded.
--exclude
type=list
A glob pattern. Files with names matching this pattern are excluded from being
transferred. Useful when adding directories. Can
be specified multiple times, if any of the patterns match the file will be
excluded. To exclude a directory use a pattern like */directory_name/*.
--symlink-strategy
default=preserve
choices=preserve,resolve,keep-path
Control what happens if the specified path is a symlink. The default is to preserve
the symlink, re-creating it on the remote machine. Setting this to :code:`resolve`
will cause the symlink to be followed and its target used as the file/directory to copy.
The value of :code:`keep-path` is the same as :code:`resolve` except that the remote
file path is derived from the symlink's path instead of the path of the symlink's target.
Note that this option does not apply to symlinks encountered while recursively copying directories.
'''
def parse_copy_args(args: Optional[Sequence[str]] = None) -> Tuple[CopyCLIOptions, List[str]]:
args = list(args or ())
try:
opts, args = parse_args(result_class=CopyCLIOptions, args=args, ospec=option_text)
except SystemExit as e:
raise CopyCLIError from e
return opts, args
def resolve_file_spec(spec: str, is_glob: bool) -> Iterator[str]:
ans = os.path.expandvars(expand_home(spec))
if not os.path.isabs(ans):
ans = expand_home(f'~/{ans}')
if is_glob:
files = glob.glob(ans)
if not files:
raise CopyCLIError(f'{spec} does not exist')
else:
if not os.path.exists(ans):
raise CopyCLIError(f'{spec} does not exist')
files = [ans]
for x in files:
yield os.path.normpath(x).replace(os.sep, '/')
class CopyCLIError(ValueError):
pass
def get_arcname(loc: str, dest: Optional[str], home: str) -> str:
if dest:
arcname = dest
else:
arcname = os.path.normpath(loc)
if arcname.startswith(home):
arcname = os.path.relpath(arcname, home)
arcname = os.path.normpath(arcname).replace(os.sep, '/')
prefix = 'root' if arcname.startswith('/') else 'home/'
return prefix + arcname
class CopyInstruction(NamedTuple):
local_path: str
arcname: str
exclude_patterns: Tuple[str, ...]
def parse_copy_instructions(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]:
opts, args = parse_copy_args(shlex.split(val))
locations: List[str] = []
for a in args:
locations.extend(resolve_file_spec(a, opts.glob))
if not locations:
raise CopyCLIError('No files to copy specified')
if len(locations) > 1 and opts.dest:
raise CopyCLIError('Specifying a remote location with more than one file is not supported')
home = home_path()
for loc in locations:
if opts.symlink_strategy != 'preserve':
rp = os.path.realpath(loc)
else:
rp = loc
arcname = get_arcname(rp if opts.symlink_strategy == 'resolve' else loc, opts.dest, home)
yield str(uuid.uuid4()), CopyInstruction(rp, arcname, tuple(opts.exclude))

View File

@ -1,728 +1,214 @@
#!/usr/bin/env python3
# License: GPL v3 Copyright: 2018, Kovid Goyal <kovid at kovidgoyal.net>
import fnmatch
import glob
import io
import json
import os
import re
import secrets
import shlex
import shutil
import stat
import subprocess
import sys
import tarfile
import tempfile
import termios
import time
import traceback
from base64 import standard_b64decode, standard_b64encode
from contextlib import contextmanager, suppress
from getpass import getuser
from select import select
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Tuple, Union, cast
from typing import List, Optional
from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_control_master_template, str_version, terminfo_dir
from kitty.shell_integration import as_str_literal
from kitty.shm import SharedMemory
from kitty.conf.types import Definition
from kitty.types import run_once
from kitty.utils import SSHConnectionData, expandvars, resolve_abs_or_config_path
from kitty.utils import set_echo as turn_off_echo
from ..tui.operations import RESTORE_PRIVATE_MODE_VALUES, SAVE_PRIVATE_MODE_VALUES, Mode, restore_colors, save_colors, set_mode
from ..tui.utils import kitty_opts, running_in_tmux
from .config import init_config
from .copy import CopyInstruction
from .options.types import Options as SSHOptions
from .options.utils import DELETE_ENV_VAR
from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args
copy_message = '''\
Copy files and directories from local to remote hosts. The specified files are
assumed to be relative to the HOME directory and copied to the HOME on the
remote. Directories are copied recursively. If absolute paths are used, they are
copied as is.'''
@run_once
def ssh_exe() -> str:
return shutil.which('ssh') or 'ssh'
def read_data_from_shared_memory(shm_name: str) -> Any:
with SharedMemory(shm_name, readonly=True) as shm:
shm.unlink()
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD | stat.S_IWRITE:
raise ValueError('Incorrect permissions on pwfile')
return json.loads(shm.read_data_with_size())
# See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
quote_pat = re.compile('([\\`"])')
def quote_env_val(x: str, literal_quote: bool = False) -> str:
if literal_quote:
return as_str_literal(x)
x = quote_pat.sub(r'\\\1', x)
x = x.replace('$(', r'\$(') # prevent execution with $()
return f'"{x}"'
def serialize_env(literal_env: Dict[str, str], env: Dict[str, str], base_env: Dict[str, str], for_python: bool = False) -> bytes:
lines = []
literal_quote = True
if for_python:
def a(k: str, val: str = '', prefix: str = 'export') -> None:
if val:
lines.append(f'{prefix} {json.dumps((k, val, literal_quote))}')
else:
lines.append(f'{prefix} {json.dumps((k,))}')
else:
def a(k: str, val: str = '', prefix: str = 'export') -> None:
if val:
lines.append(f'{prefix} {shlex.quote(k)}={quote_env_val(val, literal_quote)}')
else:
lines.append(f'{prefix} {shlex.quote(k)}')
for k, v in literal_env.items():
a(k, v)
literal_quote = False
for k in sorted(env):
v = env[k]
if v == DELETE_ENV_VAR:
a(k, prefix='unset')
elif v == '_kitty_copy_env_var_':
q = base_env.get(k)
if q is not None:
a(k, q)
else:
a(k, v)
return '\n'.join(lines).encode('utf-8')
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz', literal_env: Dict[str, str] = {}) -> bytes:
def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo:
tarinfo.uname = tarinfo.gname = ''
tarinfo.uid = tarinfo.gid = 0
# some distro's like nix mess with installed file permissions so ensure
# files are at least readable and writable by owning user
tarinfo.mode |= stat.S_IWUSR | stat.S_IRUSR
return tarinfo
def add_data_as_file(tf: tarfile.TarFile, arcname: str, data: Union[str, bytes]) -> tarfile.TarInfo:
ans = tarfile.TarInfo(arcname)
ans.mtime = 0
ans.type = tarfile.REGTYPE
if isinstance(data, str):
data = data.encode('utf-8')
ans.size = len(data)
normalize_tarinfo(ans)
tf.addfile(ans, io.BytesIO(data))
return ans
def filter_from_globs(*pats: str) -> Callable[[tarfile.TarInfo], Optional[tarfile.TarInfo]]:
def filter(tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]:
for junk_dir in ('.DS_Store', '__pycache__'):
for pat in (f'*/{junk_dir}', f'*/{junk_dir}/*'):
if fnmatch.fnmatch(tarinfo.name, pat):
return None
for pat in pats:
if fnmatch.fnmatch(tarinfo.name, pat):
return None
return normalize_tarinfo(tarinfo)
return filter
from kitty.shell_integration import get_effective_ksi_env_var
if ssh_opts.shell_integration == 'inherited':
ksi = get_effective_ksi_env_var(kitty_opts())
else:
from kitty.options.types import Options
from kitty.options.utils import shell_integration
ksi = get_effective_ksi_env_var(Options({'shell_integration': shell_integration(ssh_opts.shell_integration)}))
env = {
'TERM': os.environ.get('TERM') or kitty_opts().term,
'COLORTERM': 'truecolor',
}
env.update(ssh_opts.env)
for q in ('KITTY_WINDOW_ID', 'WINDOWID'):
val = os.environ.get(q)
if val is not None:
env[q] = val
env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR
env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir
if ssh_opts.login_shell:
env['KITTY_LOGIN_SHELL'] = ssh_opts.login_shell
if ssh_opts.cwd:
env['KITTY_LOGIN_CWD'] = ssh_opts.cwd
if ssh_opts.remote_kitty != 'no':
env['KITTY_REMOTE'] = ssh_opts.remote_kitty
if os.environ.get('KITTY_PUBLIC_KEY'):
env.pop('KITTY_PUBLIC_KEY', None)
literal_env['KITTY_PUBLIC_KEY'] = os.environ['KITTY_PUBLIC_KEY']
env_script = serialize_env(literal_env, env, base_env, for_python=compression != 'gz')
buf = io.BytesIO()
with tarfile.open(mode=f'w:{compression}', fileobj=buf, encoding='utf-8') as tf:
rd = ssh_opts.remote_dir.rstrip('/')
for ci in ssh_opts.copy.values():
tf.add(ci.local_path, arcname=ci.arcname, filter=filter_from_globs(*ci.exclude_patterns))
add_data_as_file(tf, 'data.sh', env_script)
if compression == 'gz':
tf.add(f'{shell_integration_dir}/ssh/bootstrap-utils.sh', arcname='bootstrap-utils.sh', filter=normalize_tarinfo)
if ksi:
arcname = 'home/' + rd + '/shell-integration'
tf.add(shell_integration_dir, arcname=arcname, filter=filter_from_globs(
f'{arcname}/ssh/*', # bootstrap files are sent as command line args
f'{arcname}/zsh/kitty.zsh', # present for legacy compat not needed by ssh kitten
))
if ssh_opts.remote_kitty != 'no':
arcname = 'home/' + rd + '/kitty'
add_data_as_file(tf, arcname + '/version', str_version.encode('ascii'))
tf.add(shell_integration_dir + '/ssh/kitty', arcname=arcname + '/bin/kitty', filter=normalize_tarinfo)
tf.add(shell_integration_dir + '/ssh/kitten', arcname=arcname + '/bin/kitten', filter=normalize_tarinfo)
tf.add(f'{terminfo_dir}/kitty.terminfo', arcname='home/.terminfo/kitty.terminfo', filter=normalize_tarinfo)
tf.add(glob.glob(f'{terminfo_dir}/*/xterm-kitty')[0], arcname='home/.terminfo/x/xterm-kitty', filter=normalize_tarinfo)
return buf.getvalue()
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
yield b'\nKITTY_DATA_START\n' # to discard leading data
try:
msg = standard_b64decode(msg).decode('utf-8')
md = dict(x.split('=', 1) for x in msg.split(':'))
pw = md['pw']
pwfilename = md['pwfile']
rq_id = md['id']
except Exception:
traceback.print_exc()
yield b'invalid ssh data request message\n'
else:
try:
env_data = read_data_from_shared_memory(pwfilename)
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
except Exception as e:
traceback.print_exc()
yield f'{e}\n'.encode('utf-8')
else:
yield b'OK\n'
ssh_opts = SSHOptions(env_data['opts'])
ssh_opts.copy = {k: CopyInstruction(*v) for k, v in ssh_opts.copy.items()}
encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
# macOS has a 255 byte limit on its input queue as per man stty.
# Not clear if that applies to canonical mode input as well, but
# better to be safe.
line_sz = 254
while encoded_data:
yield encoded_data[:line_sz]
yield b'\n'
encoded_data = encoded_data[line_sz:]
yield b'KITTY_DATA_END\n'
def safe_remove(x: str) -> None:
with suppress(OSError):
os.remove(x)
def prepare_script(ans: str, replacements: Dict[str, str], script_type: str) -> str:
for k in ('EXEC_CMD', 'EXPORT_HOME_CMD'):
replacements[k] = replacements.get(k, '')
def sub(m: 're.Match[str]') -> str:
return replacements[m.group()]
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
# ssh simply concatenates multiple commands using a space see
# line 1129 of ssh.c and on the remote side sshd.c runs the
# concatenated command as shell -c cmd
if is_python:
return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args)
return f"""unset KITTY_SHELL_INTEGRATION; exec "$login_shell" -c '{args}'"""
def prepare_export_home_cmd(ssh_opts: SSHOptions, is_python: bool) -> str:
home = ssh_opts.env.get('HOME')
if home == '_kitty_copy_env_var_':
home = os.environ.get('HOME')
if home:
if is_python:
return standard_b64encode(home.encode('utf-8')).decode('ascii')
else:
return f'export HOME={quote_env_val(home)}; cd "$HOME"'
return ''
def bootstrap_script(
ssh_opts: SSHOptions, script_type: str = 'sh', remote_args: Sequence[str] = (),
test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = '',
request_data: bool = False, echo_on: bool = True, literal_env: Dict[str, str] = {}
) -> Tuple[str, Dict[str, str], str]:
if request_id is None:
request_id = os.environ['KITTY_PID'] + '-' + os.environ['KITTY_WINDOW_ID']
is_python = script_type == 'py'
export_home_cmd = prepare_export_home_cmd(ssh_opts, is_python) if 'HOME' in ssh_opts.env else ''
exec_cmd = prepare_exec_cmd(remote_args, is_python) if remote_args else ''
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
ans = f.read()
pw = secrets.token_hex()
tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2', literal_env=literal_env)).decode('ascii')
data = {'pw': pw, 'opts': ssh_opts._asdict(), 'hostname': cli_hostname, 'uname': cli_uname, 'tarfile': tfd}
shm_name = create_shared_memory(data, prefix=f'kssh-{os.getpid()}-')
sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm_name}
replacements = {
'EXPORT_HOME_CMD': export_home_cmd,
'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
'REQUEST_DATA': '1' if request_data else '0', 'ECHO_ON': '1' if echo_on else '0',
}
sd = replacements.copy()
if request_data:
sd.update(sensitive_data)
replacements.update(sensitive_data)
return prepare_script(ans, sd, script_type), replacements, shm_name
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
expecting_port = expecting_identity = False
expecting_option_val = False
expecting_hostname = False
expecting_extra_val = ''
host_name = identity_file = found_ssh = ''
found_extra_args: List[Tuple[str, str]] = []
for i, arg in enumerate(args):
if not found_ssh:
if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
found_ssh = arg
continue
if expecting_hostname:
host_name = arg
continue
if arg.startswith('-') and not expecting_option_val:
if arg in boolean_ssh_args:
continue
if arg == '--':
expecting_hostname = True
if arg.startswith('-p'):
if arg[2:].isdigit():
with suppress(Exception):
port = int(arg[2:])
continue
elif arg == '-p':
expecting_port = True
elif arg.startswith('-i'):
if arg == '-i':
expecting_identity = True
else:
identity_file = arg[2:]
continue
if arg.startswith('--') and extra_args:
matching_ex = is_extra_arg(arg, extra_args)
if matching_ex:
if '=' in arg:
exval = arg.partition('=')[-1]
found_extra_args.append((matching_ex, exval))
continue
expecting_extra_val = matching_ex
expecting_option_val = True
continue
if expecting_option_val:
if expecting_port:
with suppress(Exception):
port = int(arg)
expecting_port = False
elif expecting_identity:
identity_file = arg
elif expecting_extra_val:
found_extra_args.append((expecting_extra_val, arg))
expecting_extra_val = ''
expecting_option_val = False
continue
if not host_name:
host_name = arg
if not host_name:
return None
if host_name.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(host_name)
if purl.hostname:
host_name = purl.hostname
if purl.username:
host_name = f'{purl.username}@{host_name}'
if port is None and purl.port:
port = purl.port
if identity_file:
if not os.path.isabs(identity_file):
identity_file = os.path.expanduser(identity_file)
if not os.path.isabs(identity_file):
identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))
class InvalidSSHArgs(ValueError):
def __init__(self, msg: str = ''):
super().__init__(msg)
self.err_msg = msg
def system_exit(self) -> None:
if self.err_msg:
print(self.err_msg, file=sys.stderr)
os.execlp(ssh_exe(), 'ssh')
def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
ssh_args = []
server_args: List[str] = []
expecting_option_val = False
passthrough = False
stop_option_processing = False
found_extra_args: List[str] = []
expecting_extra_val = ''
for argument in args:
if len(server_args) > 1 or stop_option_processing:
server_args.append(argument)
continue
if argument.startswith('-') and not expecting_option_val:
if argument == '--':
stop_option_processing = True
continue
if extra_args:
matching_ex = is_extra_arg(argument, extra_args)
if matching_ex:
if '=' in argument:
exval = argument.partition('=')[-1]
found_extra_args.extend((matching_ex, exval))
else:
expecting_extra_val = matching_ex
expecting_option_val = True
continue
# could be a multi-character option
all_args = argument[1:]
for i, arg in enumerate(all_args):
arg = f'-{arg}'
if arg in passthrough_args:
passthrough = True
if arg in boolean_ssh_args:
ssh_args.append(arg)
continue
if arg in other_ssh_args:
ssh_args.append(arg)
rest = all_args[i+1:]
if rest:
ssh_args.append(rest)
else:
expecting_option_val = True
break
raise InvalidSSHArgs(f'unknown option -- {arg[1:]}')
continue
if expecting_option_val:
if expecting_extra_val:
found_extra_args.extend((expecting_extra_val, argument))
expecting_extra_val = ''
else:
ssh_args.append(argument)
expecting_option_val = False
continue
server_args.append(argument)
if not server_args:
raise InvalidSSHArgs()
return ssh_args, server_args, passthrough, tuple(found_extra_args)
def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]:
# sshd will execute the command we pass it by join all command line
# arguments with a space and passing it as a single argument to the users
# login shell with -c. If the user has a non POSIX login shell it might
# have different escaping semantics and syntax, so the command it should
# execute has to be as simple as possible, basically of the form
# interpreter -c unwrap_script escaped_bootstrap_script
# The unwrap_script is responsible for unescaping the bootstrap script and
# executing it.
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
if is_python:
es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii')
unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"'''
else:
# We cant rely on base64 being available on the remote system, so instead
# we quote the bootstrap script by replacing ' and \ with \v and \f
# also replacing \n and ! with \r and \b for tcsh
# finally surrounding with '
es = "'" + sh_script.replace("'", '\v').replace('\\', '\f').replace('\n', '\r').replace('!', '\b') + "'"
unwrap_script = r"""'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' """
# exec is supported by all sh like shells, and fish and csh
return ['exec', interpreter, '-c', unwrap_script, es]
def get_remote_command(
remote_args: List[str], ssh_opts: SSHOptions, cli_hostname: str = '', cli_uname: str = '',
echo_on: bool = True, request_data: bool = False, literal_env: Dict[str, str] = {}
) -> Tuple[List[str], Dict[str, str], str]:
interpreter = ssh_opts.interpreter
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
sh_script, replacements, shm_name = bootstrap_script(
ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args, literal_env=literal_env,
cli_hostname=cli_hostname, cli_uname=cli_uname, echo_on=echo_on, request_data=request_data)
return wrap_bootstrap_script(sh_script, interpreter), replacements, shm_name
def connection_sharing_args(kitty_pid: int) -> List[str]:
rd = runtime_dir()
# Bloody OpenSSH generates a 40 char hash and in creating the socket
# appends a 27 char temp suffix to it. Socket max path length is approx
# ~104 chars. macOS has no system runtime dir so we use a cache dir in
# /Users/WHY_DOES_ANYONE_USE_MACOS/Library/Caches/APPLE_ARE_IDIOTIC
if len(rd) > 35 and os.path.isdir('/tmp'):
idiotic_design = f'/tmp/kssh-rdir-{os.getuid()}'
try:
os.symlink(rd, idiotic_design)
except FileExistsError:
try:
dest = os.readlink(idiotic_design)
except OSError as e:
raise ValueError(f'The {idiotic_design} symlink could not be created as something with that name exists already') from e
else:
if dest != rd:
with tempfile.TemporaryDirectory(dir='/tmp') as tdir:
tlink = os.path.join(tdir, 'sigh')
os.symlink(rd, tlink)
os.rename(tlink, idiotic_design)
rd = idiotic_design
cp = os.path.join(rd, ssh_control_master_template.format(kitty_pid=kitty_pid, ssh_placeholder='%C'))
ans: List[str] = [
'-o', 'ControlMaster=auto',
'-o', f'ControlPath={cp}',
'-o', 'ControlPersist=yes',
'-o', 'ServerAliveInterval=60',
'-o', 'ServerAliveCountMax=5',
'-o', 'TCPKeepAlive=no',
]
return ans
@contextmanager
def restore_terminal_state() -> Iterator[bool]:
with open(os.ctermid()) as f:
val = termios.tcgetattr(f.fileno())
print(end=SAVE_PRIVATE_MODE_VALUES)
print(end=set_mode(Mode.HANDLE_TERMIOS_SIGNALS), flush=True)
try:
yield bool(val[3] & termios.ECHO)
finally:
termios.tcsetattr(f.fileno(), termios.TCSAFLUSH, val)
print(end=RESTORE_PRIVATE_MODE_VALUES, flush=True)
def dcs_to_kitty(payload: Union[bytes, str], type: str = 'ssh') -> bytes:
if isinstance(payload, str):
payload = payload.encode('utf-8')
payload = standard_b64encode(payload)
ans = b'\033P@kitty-' + type.encode('ascii') + b'|' + payload
tmux = running_in_tmux()
if tmux:
cp = subprocess.run([tmux, 'set', '-p', 'allow-passthrough', 'on'])
if cp.returncode != 0:
raise SystemExit(cp.returncode)
ans = b'\033Ptmux;\033' + ans + b'\033\033\\\033\\'
else:
ans += b'\033\\'
return ans
@run_once
def ssh_version() -> Tuple[int, int]:
o = subprocess.check_output([ssh_exe(), '-V'], stderr=subprocess.STDOUT).decode()
m = re.match(r'OpenSSH_(\d+).(\d+)', o)
if m is None:
raise ValueError(f'Invalid version string for OpenSSH: {o}')
return int(m.group(1)), int(m.group(2))
@contextmanager
def drain_potential_tty_garbage(p: 'subprocess.Popen[bytes]', data_request: str) -> Iterator[None]:
with open(os.open(os.ctermid(), os.O_CLOEXEC | os.O_RDWR | os.O_NOCTTY), 'wb') as tty:
if data_request:
turn_off_echo(tty.fileno())
tty.write(dcs_to_kitty(data_request))
tty.flush()
try:
yield
finally:
# discard queued input data on tty in case data transmission was
# interrupted due to SSH failure, avoids spewing garbage to screen
from uuid import uuid4
canary = uuid4().hex.encode('ascii')
turn_off_echo(tty.fileno())
tty.write(dcs_to_kitty(canary + b'\n\r', type='echo'))
tty.flush()
data = b''
give_up_at = time.monotonic() + 2
tty_fd = tty.fileno()
while time.monotonic() < give_up_at and canary not in data:
with suppress(KeyboardInterrupt):
rd, wr, err = select([tty_fd], [], [tty_fd], max(0, give_up_at - time.monotonic()))
if err or not rd:
break
q = os.read(tty_fd, io.DEFAULT_BUFFER_SIZE)
if not q:
break
data += q
def change_colors(color_scheme: str) -> bool:
if not color_scheme:
return False
from kittens.themes.collection import NoCacheFound, load_themes, text_as_opts
from kittens.themes.main import colors_as_escape_codes
if color_scheme.endswith('.conf'):
conf_file = resolve_abs_or_config_path(color_scheme)
try:
with open(conf_file) as f:
opts = text_as_opts(f.read())
except FileNotFoundError:
raise SystemExit(f'Failed to find the color conf file: {expandvars(conf_file)}')
else:
try:
themes = load_themes(-1)
except NoCacheFound:
themes = load_themes()
cs = expandvars(color_scheme)
try:
theme = themes[cs]
except KeyError:
raise SystemExit(f'Failed to find the color theme: {cs}')
opts = theme.kitty_opts
raw = colors_as_escape_codes(opts)
print(save_colors(), sep='', end=raw, flush=True)
return True
def add_cloned_env(shm_name: str) -> Dict[str, str]:
try:
return cast(Dict[str, str], read_data_from_shared_memory(shm_name))
except FileNotFoundError:
pass
return {}
def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple[str, ...]) -> NoReturn:
cmd = [ssh_exe()] + ssh_args
hostname, remote_args = server_args[0], server_args[1:]
if not remote_args:
cmd.append('-t')
insertion_point = len(cmd)
cmd.append('--')
cmd.append(hostname)
uname = getuser()
if hostname.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(hostname)
hostname_for_match = purl.hostname or hostname[6:].split('/', 1)[0]
uname = purl.username or uname
elif '@' in hostname and hostname[0] != '@':
uname, hostname_for_match = hostname.split('@', 1)
else:
hostname_for_match = hostname
hostname_for_match = hostname_for_match.split('@', 1)[-1].split(':', 1)[0]
overrides: List[str] = []
literal_env: Dict[str, str] = {}
pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=')
for i, a in enumerate(found_extra_args):
if i % 2 == 1:
aq = pat.sub(r'\1 ', a.lstrip())
key = aq.split(maxsplit=1)[0]
if key == 'clone_env':
literal_env = add_cloned_env(aq.split(maxsplit=1)[1])
elif key != 'hostname':
overrides.append(aq)
if overrides:
overrides.insert(0, f'hostname {uname}@{hostname_for_match}')
host_opts = init_config(hostname_for_match, uname, overrides)
if host_opts.share_connections:
cmd[insertion_point:insertion_point] = connection_sharing_args(int(os.environ['KITTY_PID']))
use_kitty_askpass = host_opts.askpass == 'native' or (host_opts.askpass == 'unless-set' and 'SSH_ASKPASS' not in os.environ)
need_to_request_data = True
if use_kitty_askpass:
sentinel = os.path.join(cache_dir(), 'openssh-is-new-enough-for-askpass')
sentinel_exists = os.path.exists(sentinel)
if sentinel_exists or ssh_version() >= (8, 4):
if not sentinel_exists:
open(sentinel, 'w').close()
# SSH_ASKPASS_REQUIRE was introduced in 8.4 release on 2020-09-27
need_to_request_data = False
os.environ['SSH_ASKPASS_REQUIRE'] = 'force'
os.environ['SSH_ASKPASS'] = os.path.join(shell_integration_dir, 'ssh', 'askpass.py')
if need_to_request_data and host_opts.share_connections:
cp = subprocess.run(cmd[:1] + ['-O', 'check'] + cmd[1:], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
if cp.returncode == 0:
# we will use the master connection so SSH does not need to use the tty
need_to_request_data = False
with restore_terminal_state() as echo_on:
rcmd, replacements, shm_name = get_remote_command(
remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data, literal_env=literal_env)
cmd += rcmd
colors_changed = change_colors(host_opts.color_scheme)
try:
p = subprocess.Popen(cmd)
except FileNotFoundError:
raise SystemExit('Could not find the ssh executable, is it in your PATH?')
else:
rq = '' if need_to_request_data else 'id={REQUEST_ID}:pwfile={PASSWORD_FILENAME}:pw={DATA_PASSWORD}'.format(**replacements)
with drain_potential_tty_garbage(p, rq):
raise SystemExit(p.wait())
finally:
if colors_changed:
print(end=restore_colors(), flush=True)
def main(args: List[str]) -> None:
args = args[1:]
if args and args[0] == 'use-python':
args = args[1:] # backwards compat from when we had a python implementation
try:
ssh_args, server_args, passthrough, found_extra_args = parse_ssh_args(args, extra_args=('--kitten',))
except InvalidSSHArgs as e:
e.system_exit()
if passthrough:
if found_extra_args:
raise SystemExit(f'The SSH kitten cannot work with the options: {", ".join(passthrough_args)}')
os.execlp(ssh_exe(), 'ssh', *args)
if not os.environ.get('KITTY_WINDOW_ID') or not os.environ.get('KITTY_PID'):
raise SystemExit('The SSH kitten is meant to run inside a kitty window')
if not sys.stdin.isatty():
raise SystemExit('The SSH kitten is meant for interactive use only, STDIN must be a terminal')
try:
run_ssh(ssh_args, server_args, found_extra_args)
except KeyboardInterrupt:
sys.excepthook = lambda *a: None
raise
def option_text() -> str:
return '''
--glob
type=bool-set
Interpret file arguments as glob patterns. Globbing is based on
Based on standard wildcards with the addition that ``/**/`` matches any number of directories.
See the :link:`detailed syntax <https://github.com/bmatcuk/doublestar#patterns>`.
--dest
The destination on the remote host to copy to. Relative paths are resolved
relative to HOME on the remote host. When this option is not specified, the
local file path is used as the remote destination (with the HOME directory
getting automatically replaced by the remote HOME). Note that environment
variables and ~ are not expanded.
--exclude
type=list
A glob pattern. Files with names matching this pattern are excluded from being
transferred. Useful when adding directories. Can
be specified multiple times, if any of the patterns match the file will be
excluded. If the pattern includes a :code:`/` then it will match against the full
path, not just the filename. In such patterns you can use :code:`/**/` to match zero
or more directories. For example, to exclude a directory and everything under it use
:code:`**/directory_name`.
See the :link:`detailed syntax <https://github.com/bmatcuk/doublestar#patterns>` for
how wildcards match.
--symlink-strategy
default=preserve
choices=preserve,resolve,keep-path
Control what happens if the specified path is a symlink. The default is to preserve
the symlink, re-creating it on the remote machine. Setting this to :code:`resolve`
will cause the symlink to be followed and its target used as the file/directory to copy.
The value of :code:`keep-path` is the same as :code:`resolve` except that the remote
file path is derived from the symlink's path instead of the path of the symlink's target.
Note that this option does not apply to symlinks encountered while recursively copying directories,
those are always preserved.
'''
definition = Definition(
'!kittens.ssh',
)
agr = definition.add_group
egr = definition.end_group
opt = definition.add_option
agr('bootstrap', 'Host bootstrap configuration') # {{{
opt('hostname', '*', long_text='''
The hostname that the following options apply to. A glob pattern to match
multiple hosts can be used. Multiple hostnames can also be specified, separated
by spaces. The hostname can include an optional username in the form
:code:`user@host`. When not specified options apply to all hosts, until the
first hostname specification is found. Note that matching of hostname is done
against the name you specify on the command line to connect to the remote host.
If you wish to include the same basic configuration for many different hosts,
you can do so with the :ref:`include <include>` directive.
''')
opt('interpreter', 'sh', long_text='''
The interpreter to use on the remote host. Must be either a POSIX complaint
shell or a :program:`python` executable. If the default :program:`sh` is not
available or broken, using an alternate interpreter can be useful.
''')
opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text='''
The location on the remote host where the files needed for this kitten are
installed. Relative paths are resolved with respect to :code:`$HOME`.
''')
opt('+copy', '', add_to_default=False, ctype='CopyInstruction', long_text=f'''
{copy_message} For example::
copy .vimrc .zshrc .config/some-dir
Use :code:`--dest` to copy a file to some other destination on the remote host::
copy --dest some-other-name some-file
Glob patterns can be specified to copy multiple files, with :code:`--glob`::
copy --glob images/*.png
Files can be excluded when copying with :code:`--exclude`::
copy --glob --exclude *.jpg --exclude *.bmp images/*
Files whose remote name matches the exclude pattern will not be copied.
For more details, see :ref:`ssh_copy_command`.
''')
egr() # }}}
agr('shell', 'Login shell environment') # {{{
opt('shell_integration', 'inherited', long_text='''
Control the shell integration on the remote host. See :ref:`shell_integration`
for details on how this setting works. The special value :code:`inherited` means
use the setting from :file:`kitty.conf`. This setting is useful for overriding
integration on a per-host basis.
''')
opt('login_shell', '', long_text='''
The login shell to execute on the remote host. By default, the remote user
account's login shell is used.
''')
opt('+env', '', add_to_default=False, ctype='EnvInstruction', long_text='''
Specify the environment variables to be set on the remote host. Using the
name with an equal sign (e.g. :code:`env VAR=`) will set it to the empty string.
Specifying only the name (e.g. :code:`env VAR`) will remove the variable from
the remote shell environment. The special value :code:`_kitty_copy_env_var_`
will cause the value of the variable to be copied from the local environment.
The definitions are processed alphabetically. Note that environment variables
are expanded recursively, for example::
env VAR1=a
env VAR2=${HOME}/${VAR1}/b
The value of :code:`VAR2` will be :code:`<path to home directory>/a/b`.
''')
opt('cwd', '', long_text='''
The working directory on the remote host to change to. Environment variables in
this value are expanded. The default is empty so no changing is done, which
usually means the HOME directory is used.
''')
opt('color_scheme', '', long_text='''
Specify a color scheme to use when connecting to the remote host. If this option
ends with :code:`.conf`, it is assumed to be the name of a config file to load
from the kitty config directory, otherwise it is assumed to be the name of a
color theme to load via the :doc:`themes kitten </kittens/themes>`. Note that
only colors applying to the text/background are changed, other config settings
in the .conf files/themes are ignored.
''')
opt('remote_kitty', 'if-needed', choices=('if-needed', 'no', 'yes'), long_text='''
Make :program:`kitty` available on the remote host. Useful to run kittens such
as the :doc:`icat kitten </kittens/icat>` to display images or the
:doc:`transfer file kitten </kittens/transfer>` to transfer files. Only works if
the remote host has an architecture for which :link:`pre-compiled kitty binaries
<https://github.com/kovidgoyal/kitty/releases>` are available. Note that kitty
is not actually copied to the remote host, instead a small bootstrap script is
copied which will download and run kitty when kitty is first executed on the
remote host. A value of :code:`if-needed` means kitty is installed only if not
already present in the system-wide PATH. A value of :code:`yes` means that kitty
is installed even if already present, and the installed kitty takes precedence.
Finally, :code:`no` means no kitty is installed on the remote host. The
installed kitty can be updated by running: :code:`kitty +update-kitty` on the
remote host.
''')
egr() # }}}
agr('ssh', 'SSH configuration') # {{{
opt('share_connections', 'yes', option_type='to_bool', long_text='''
Within a single kitty instance, all connections to a particular server can be
shared. This reduces startup latency for subsequent connections and means that
you have to enter the password only once. Under the hood, it uses SSH
ControlMasters and these are automatically cleaned up by kitty when it quits.
You can map a shortcut to :ac:`close_shared_ssh_connections` to disconnect all
active shared connections.
''')
opt('askpass', 'unless-set', choices=('unless-set', 'ssh', 'native'), long_text='''
Control the program SSH uses to ask for passwords or confirmation of host keys
etc. The default is to use kitty's native :program:`askpass`, unless the
:envvar:`SSH_ASKPASS` environment variable is set. Set this option to
:code:`ssh` to not interfere with the normal ssh askpass mechanism at all, which
typically means that ssh will prompt at the terminal. Set it to :code:`native`
to always use kitty's native, built-in askpass implementation. Note that not
using the kitty askpass implementation means that SSH might need to use the
terminal before the connection is established, so the kitten cannot use the
terminal to send data without an extra roundtrip, adding to initial connection
latency.
''')
egr() # }}}
def main(args: List[str]) -> Optional[str]:
raise SystemExit('This should be run as kitten unicode_input')
if __name__ == '__main__':
main(sys.argv)
main([])
elif __name__ == '__wrapper_of__':
cd = sys.cli_docs # type: ignore
cd = getattr(sys, 'cli_docs')
cd['wrapper_of'] = 'ssh'
elif __name__ == '__conf__':
from .options.definition import definition
sys.options_definition = definition # type: ignore
setattr(sys, 'options_definition', definition)
elif __name__ == '__extra_cli_parsers__':
setattr(sys, 'extra_cli_parsers', {'copy': option_text()})

View File

@ -1,153 +0,0 @@
#!/usr/bin/env python
# vim:fileencoding=utf-8
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
# After editing this file run ./gen-config.py to apply the changes
from kitty.conf.types import Definition
copy_message = '''\
Copy files and directories from local to remote hosts. The specified files are
assumed to be relative to the HOME directory and copied to the HOME on the
remote. Directories are copied recursively. If absolute paths are used, they are
copied as is.'''
definition = Definition(
'kittens.ssh',
)
agr = definition.add_group
egr = definition.end_group
opt = definition.add_option
agr('bootstrap', 'Host bootstrap configuration') # {{{
opt('hostname', '*', option_type='hostname', long_text='''
The hostname that the following options apply to. A glob pattern to match
multiple hosts can be used. Multiple hostnames can also be specified, separated
by spaces. The hostname can include an optional username in the form
:code:`user@host`. When not specified options apply to all hosts, until the
first hostname specification is found. Note that matching of hostname is done
against the name you specify on the command line to connect to the remote host.
If you wish to include the same basic configuration for many different hosts,
you can do so with the :ref:`include <include>` directive.
''')
opt('interpreter', 'sh', long_text='''
The interpreter to use on the remote host. Must be either a POSIX complaint
shell or a :program:`python` executable. If the default :program:`sh` is not
available or broken, using an alternate interpreter can be useful.
''')
opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text='''
The location on the remote host where the files needed for this kitten are
installed. Relative paths are resolved with respect to :code:`$HOME`.
''')
opt('+copy', '', option_type='copy', add_to_default=False, long_text=f'''
{copy_message} For example::
copy .vimrc .zshrc .config/some-dir
Use :code:`--dest` to copy a file to some other destination on the remote host::
copy --dest some-other-name some-file
Glob patterns can be specified to copy multiple files, with :code:`--glob`::
copy --glob images/*.png
Files can be excluded when copying with :code:`--exclude`::
copy --glob --exclude *.jpg --exclude *.bmp images/*
Files whose remote name matches the exclude pattern will not be copied.
For more details, see :ref:`ssh_copy_command`.
''')
egr() # }}}
agr('shell', 'Login shell environment') # {{{
opt('shell_integration', 'inherited', long_text='''
Control the shell integration on the remote host. See :ref:`shell_integration`
for details on how this setting works. The special value :code:`inherited` means
use the setting from :file:`kitty.conf`. This setting is useful for overriding
integration on a per-host basis.
''')
opt('login_shell', '', long_text='''
The login shell to execute on the remote host. By default, the remote user
account's login shell is used.
''')
opt('+env', '', option_type='env', add_to_default=False, long_text='''
Specify the environment variables to be set on the remote host. Using the
name with an equal sign (e.g. :code:`env VAR=`) will set it to the empty string.
Specifying only the name (e.g. :code:`env VAR`) will remove the variable from
the remote shell environment. The special value :code:`_kitty_copy_env_var_`
will cause the value of the variable to be copied from the local environment.
The definitions are processed alphabetically. Note that environment variables
are expanded recursively, for example::
env VAR1=a
env VAR2=${HOME}/${VAR1}/b
The value of :code:`VAR2` will be :code:`<path to home directory>/a/b`.
''')
opt('cwd', '', long_text='''
The working directory on the remote host to change to. Environment variables in
this value are expanded. The default is empty so no changing is done, which
usually means the HOME directory is used.
''')
opt('color_scheme', '', long_text='''
Specify a color scheme to use when connecting to the remote host. If this option
ends with :code:`.conf`, it is assumed to be the name of a config file to load
from the kitty config directory, otherwise it is assumed to be the name of a
color theme to load via the :doc:`themes kitten </kittens/themes>`. Note that
only colors applying to the text/background are changed, other config settings
in the .conf files/themes are ignored.
''')
opt('remote_kitty', 'if-needed', choices=('if-needed', 'no', 'yes'), long_text='''
Make :program:`kitty` available on the remote host. Useful to run kittens such
as the :doc:`icat kitten </kittens/icat>` to display images or the
:doc:`transfer file kitten </kittens/transfer>` to transfer files. Only works if
the remote host has an architecture for which :link:`pre-compiled kitty binaries
<https://github.com/kovidgoyal/kitty/releases>` are available. Note that kitty
is not actually copied to the remote host, instead a small bootstrap script is
copied which will download and run kitty when kitty is first executed on the
remote host. A value of :code:`if-needed` means kitty is installed only if not
already present in the system-wide PATH. A value of :code:`yes` means that kitty
is installed even if already present, and the installed kitty takes precedence.
Finally, :code:`no` means no kitty is installed on the remote host. The
installed kitty can be updated by running: :code:`kitty +update-kitty` on the
remote host.
''')
egr() # }}}
agr('ssh', 'SSH configuration') # {{{
opt('share_connections', 'yes', option_type='to_bool', long_text='''
Within a single kitty instance, all connections to a particular server can be
shared. This reduces startup latency for subsequent connections and means that
you have to enter the password only once. Under the hood, it uses SSH
ControlMasters and these are automatically cleaned up by kitty when it quits.
You can map a shortcut to :ac:`close_shared_ssh_connections` to disconnect all
active shared connections.
''')
opt('askpass', 'unless-set', choices=('unless-set', 'ssh', 'native'), long_text='''
Control the program SSH uses to ask for passwords or confirmation of host keys
etc. The default is to use kitty's native :program:`askpass`, unless the
:envvar:`SSH_ASKPASS` environment variable is set. Set this option to
:code:`ssh` to not interfere with the normal ssh askpass mechanism at all, which
typically means that ssh will prompt at the terminal. Set it to :code:`native`
to always use kitty's native, built-in askpass implementation. Note that not
using the kitty askpass implementation means that SSH might need to use the
terminal before the connection is established, so the kitten cannot use the
terminal to send data without an extra roundtrip, adding to initial connection
latency.
''')
egr() # }}}

View File

@ -1,90 +0,0 @@
# generated by gen-config.py DO NOT edit
# isort: skip_file
import typing
from kittens.ssh.options.utils import copy, env, hostname
from kitty.conf.utils import merge_dicts, to_bool
class Parser:
def askpass(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
val = val.lower()
if val not in self.choices_for_askpass:
raise ValueError(f"The value {val} is not a valid choice for askpass")
ans["askpass"] = val
choices_for_askpass = frozenset(('unless-set', 'ssh', 'native'))
def color_scheme(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['color_scheme'] = str(val)
def copy(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
for k, v in copy(val, ans["copy"]):
ans["copy"][k] = v
def cwd(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['cwd'] = str(val)
def env(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
for k, v in env(val, ans["env"]):
ans["env"][k] = v
def hostname(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
hostname(val, ans)
def interpreter(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['interpreter'] = str(val)
def login_shell(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['login_shell'] = str(val)
def remote_dir(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['remote_dir'] = str(val)
def remote_kitty(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
val = val.lower()
if val not in self.choices_for_remote_kitty:
raise ValueError(f"The value {val} is not a valid choice for remote_kitty")
ans["remote_kitty"] = val
choices_for_remote_kitty = frozenset(('if-needed', 'no', 'yes'))
def share_connections(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['share_connections'] = to_bool(val)
def shell_integration(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['shell_integration'] = str(val)
def create_result_dict() -> typing.Dict[str, typing.Any]:
return {
'copy': {},
'env': {},
}
actions: typing.FrozenSet[str] = frozenset(())
def merge_result_dicts(defaults: typing.Dict[str, typing.Any], vals: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
ans = {}
for k, v in defaults.items():
if isinstance(v, dict):
ans[k] = merge_dicts(v, vals.get(k, {}))
elif k in actions:
ans[k] = v + vals.get(k, [])
else:
ans[k] = vals.get(k, v)
return ans
parser = Parser()
def parse_conf_item(key: str, val: str, ans: typing.Dict[str, typing.Any]) -> bool:
func = getattr(parser, key, None)
if func is not None:
func(val, ans)
return True
return False

View File

@ -1,93 +0,0 @@
# generated by gen-config.py DO NOT edit
# isort: skip_file
import typing
import kittens.ssh.copy
if typing.TYPE_CHECKING:
choices_for_askpass = typing.Literal['unless-set', 'ssh', 'native']
choices_for_remote_kitty = typing.Literal['if-needed', 'no', 'yes']
else:
choices_for_askpass = str
choices_for_remote_kitty = str
option_names = ( # {{{
'askpass',
'color_scheme',
'copy',
'cwd',
'env',
'hostname',
'interpreter',
'login_shell',
'remote_dir',
'remote_kitty',
'share_connections',
'shell_integration') # }}}
class Options:
askpass: choices_for_askpass = 'unless-set'
color_scheme: str = ''
cwd: str = ''
hostname: str = '*'
interpreter: str = 'sh'
login_shell: str = ''
remote_dir: str = '.local/share/kitty-ssh-kitten'
remote_kitty: choices_for_remote_kitty = 'if-needed'
share_connections: bool = True
shell_integration: str = 'inherited'
copy: typing.Dict[str, kittens.ssh.copy.CopyInstruction] = {}
env: typing.Dict[str, str] = {}
config_paths: typing.Tuple[str, ...] = ()
config_overrides: typing.Tuple[str, ...] = ()
def __init__(self, options_dict: typing.Optional[typing.Dict[str, typing.Any]] = None) -> None:
if options_dict is not None:
null = object()
for key in option_names:
val = options_dict.get(key, null)
if val is not null:
setattr(self, key, val)
@property
def _fields(self) -> typing.Tuple[str, ...]:
return option_names
def __iter__(self) -> typing.Iterator[str]:
return iter(self._fields)
def __len__(self) -> int:
return len(self._fields)
def _copy_of_val(self, name: str) -> typing.Any:
ans = getattr(self, name)
if isinstance(ans, dict):
ans = ans.copy()
elif isinstance(ans, list):
ans = ans[:]
return ans
def _asdict(self) -> typing.Dict[str, typing.Any]:
return {k: self._copy_of_val(k) for k in self}
def _replace(self, **kw: typing.Any) -> "Options":
ans = Options()
for name in self:
setattr(ans, name, self._copy_of_val(name))
for name, val in kw.items():
setattr(ans, name, val)
return ans
def __getitem__(self, key: typing.Union[int, str]) -> typing.Any:
k = option_names[key] if isinstance(key, int) else key
try:
return getattr(self, k)
except AttributeError:
pass
raise KeyError(f"No option named: {k}")
defaults = Options()
defaults.copy = {}
defaults.env = {}

View File

@ -1,56 +0,0 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
from typing import Any, Dict, Iterable, Optional, Tuple
from ..copy import CopyInstruction, parse_copy_instructions
DELETE_ENV_VAR = '_delete_this_env_var_'
def env(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, str]]:
val = val.strip()
if val:
if '=' in val:
key, v = val.split('=', 1)
key, v = key.strip(), v.strip()
if key:
yield key, v
else:
yield val, DELETE_ENV_VAR
def copy(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]:
yield from parse_copy_instructions(val, current_val)
def init_results_dict(ans: Dict[str, Any]) -> Dict[str, Any]:
ans['hostname'] = '*'
ans['per_host_dicts'] = {}
return ans
def get_per_hosts_dict(results_dict: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
ans: Dict[str, Dict[str, Any]] = results_dict.get('per_host_dicts', {}).copy()
h = results_dict['hostname']
hd = {k: v for k, v in results_dict.items() if k != 'per_host_dicts'}
ans[h] = hd
return ans
first_seen_positions: Dict[str, int] = {}
def hostname(val: str, dict_with_parse_results: Optional[Dict[str, Any]] = None) -> str:
if dict_with_parse_results is not None:
ch = dict_with_parse_results['hostname']
if val != ch:
from .parse import create_result_dict
phd = get_per_hosts_dict(dict_with_parse_results)
dict_with_parse_results.clear()
dict_with_parse_results.update(phd.pop(val, create_result_dict()))
dict_with_parse_results['per_host_dicts'] = phd
dict_with_parse_results['hostname'] = val
if val not in first_seen_positions:
first_seen_positions[val] = len(first_seen_positions)
return val

View File

@ -4,9 +4,12 @@
import os
import subprocess
from typing import Any, Dict, List, Sequence, Set, Tuple
import traceback
from contextlib import suppress
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple
from kitty.types import run_once
from kitty.utils import SSHConnectionData
@run_once
@ -94,6 +97,57 @@ def create_shared_memory(data: Any, prefix: str) -> str:
return shm.name
def read_data_from_shared_memory(shm_name: str) -> Any:
import json
import stat
from kitty.shm import SharedMemory
with SharedMemory(shm_name, readonly=True) as shm:
shm.unlink()
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD | stat.S_IWRITE:
raise ValueError('Incorrect permissions on pwfile')
return json.loads(shm.read_data_with_size())
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
from base64 import standard_b64decode
yield b'\nKITTY_DATA_START\n' # to discard leading data
try:
msg = standard_b64decode(msg).decode('utf-8')
md = dict(x.split('=', 1) for x in msg.split(':'))
pw = md['pw']
pwfilename = md['pwfile']
rq_id = md['id']
except Exception:
traceback.print_exc()
yield b'invalid ssh data request message\n'
else:
try:
env_data = read_data_from_shared_memory(pwfilename)
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
except Exception as e:
traceback.print_exc()
yield f'{e}\n'.encode('utf-8')
else:
yield b'OK\n'
encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
# macOS has a 255 byte limit on its input queue as per man stty.
# Not clear if that applies to canonical mode input as well, but
# better to be safe.
line_sz = 254
while encoded_data:
yield encoded_data[:line_sz]
yield b'\n'
encoded_data = encoded_data[line_sz:]
yield b'KITTY_DATA_END\n'
def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None:
patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
@ -183,3 +237,86 @@ def set_server_args_in_cmdline(
ans.insert(i, '-t')
break
argv[:] = ans + server_args
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
expecting_port = expecting_identity = False
expecting_option_val = False
expecting_hostname = False
expecting_extra_val = ''
host_name = identity_file = found_ssh = ''
found_extra_args: List[Tuple[str, str]] = []
for i, arg in enumerate(args):
if not found_ssh:
if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
found_ssh = arg
continue
if expecting_hostname:
host_name = arg
continue
if arg.startswith('-') and not expecting_option_val:
if arg in boolean_ssh_args:
continue
if arg == '--':
expecting_hostname = True
if arg.startswith('-p'):
if arg[2:].isdigit():
with suppress(Exception):
port = int(arg[2:])
continue
elif arg == '-p':
expecting_port = True
elif arg.startswith('-i'):
if arg == '-i':
expecting_identity = True
else:
identity_file = arg[2:]
continue
if arg.startswith('--') and extra_args:
matching_ex = is_extra_arg(arg, extra_args)
if matching_ex:
if '=' in arg:
exval = arg.partition('=')[-1]
found_extra_args.append((matching_ex, exval))
continue
expecting_extra_val = matching_ex
expecting_option_val = True
continue
if expecting_option_val:
if expecting_port:
with suppress(Exception):
port = int(arg)
expecting_port = False
elif expecting_identity:
identity_file = arg
elif expecting_extra_val:
found_extra_args.append((expecting_extra_val, arg))
expecting_extra_val = ''
expecting_option_val = False
continue
if not host_name:
host_name = arg
if not host_name:
return None
if host_name.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(host_name)
if purl.hostname:
host_name = purl.hostname
if purl.username:
host_name = f'{purl.username}@{host_name}'
if port is None and purl.port:
port = purl.port
if identity_file:
if not os.path.isabs(identity_file):
identity_file = os.path.expanduser(identity_file)
if not os.path.isabs(identity_file):
identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))

View File

@ -13,7 +13,7 @@ LaunchCLIOptions = AskCLIOptions = ClipboardCLIOptions = DiffCLIOptions = CLIOpt
HintsCLIOptions = IcatCLIOptions = PanelCLIOptions = ResizeCLIOptions = CLIOptions
ErrorCLIOptions = UnicodeCLIOptions = RCOptions = RemoteFileCLIOptions = CLIOptions
QueryTerminalCLIOptions = BroadcastCLIOptions = ShowKeyCLIOptions = CLIOptions
ThemesCLIOptions = TransferCLIOptions = CopyCLIOptions = CLIOptions
ThemesCLIOptions = TransferCLIOptions = CLIOptions
def generate_stub() -> None:
@ -78,9 +78,6 @@ def generate_stub() -> None:
from kittens.transfer.main import option_text as OPTIONS
do(OPTIONS(), 'TransferCLIOptions')
from kittens.ssh.copy import option_text as OPTIONS
do(OPTIONS(), 'CopyCLIOptions')
from kitty.rc.base import all_command_names, command_for_name
for cmd_name in all_command_names():
cmd = command_for_name(cmd_name)

View File

@ -9,7 +9,7 @@ import re
import textwrap
from typing import Any, Callable, Dict, Iterator, List, Set, Tuple, Union, get_type_hints
from kitty.conf.types import Definition, MultiOption, Option, unset
from kitty.conf.types import Definition, MultiOption, Option, ParserFuncType, unset
from kitty.types import _T
@ -442,6 +442,121 @@ def write_output(loc: str, defn: Definition) -> None:
f.write(f'{c}\n')
def go_type_data(parser_func: ParserFuncType, ctype: str) -> Tuple[str, str]:
if ctype:
return f'*{ctype}', f'Parse{ctype}(val)'
p = parser_func.__name__
if p == 'int':
return 'int64', 'strconv.ParseInt(val, 10, 64)'
if p == 'str':
return 'string', 'val, nil'
if p == 'float':
return 'float64', 'strconv.ParseFloat(val, 10, 64)'
if p == 'to_bool':
return 'bool', 'config.StringToBool(val), nil'
th = get_type_hints(parser_func)
rettype = th['return']
return {int: 'int64', str: 'string', float: 'float64'}[rettype], f'{p}(val)'
def gen_go_code(defn: Definition) -> str:
lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"',
'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi']
a = lines.append
choices = {}
go_types = {}
go_parsers = {}
defaults = {}
multiopts = {''}
for option in sorted(defn.iter_all_options(), key=lambda a: natural_keys(a.name)):
name = option.name.capitalize()
if isinstance(option, MultiOption):
go_types[name], go_parsers[name] = go_type_data(option.parser_func, option.ctype)
multiopts.add(name)
else:
defaults[name] = option.parser_func(option.defval_as_string)
if option.choices:
choices[name] = option.choices
go_types[name] = f'{name}_Choice_Type'
go_parsers[name] = f'Parse_{name}(val)'
continue
go_types[name], go_parsers[name] = go_type_data(option.parser_func, option.ctype)
for oname in choices:
a(f'type {go_types[oname]} int')
a('type Config struct {')
for name, gotype in go_types.items():
if name in multiopts:
a(f'{name} []{gotype}')
else:
a(f'{name} {gotype}')
a('}')
def cval(x: str) -> str:
return x.replace('-', '_')
a('func NewConfig() *Config {')
a('return &Config{')
for name, pname in go_parsers.items():
if name in multiopts:
continue
d = defaults[name]
if not d:
continue
if isinstance(d, str):
dval = f'{name}_{cval(d)}' if name in choices else f'`{d}`'
elif isinstance(d, bool):
dval = repr(d).lower()
else:
dval = repr(d)
a(f'{name}: {dval},')
a('}''}')
for oname, choice_vals in choices.items():
a('const (')
for i, c in enumerate(choice_vals):
c = cval(c)
if i == 0:
a(f'{oname}_{c} {oname}_Choice_Type = iota')
else:
a(f'{oname}_{c}')
a(')')
a(f'func (x {oname}_Choice_Type) String() string'' {')
a('switch x {')
a('default: return ""')
for c in choice_vals:
a(f'case {oname}_{cval(c)}: return "{c}"')
a('}''}')
a(f'func {go_parsers[oname].split("(")[0]}(val string) (ans {go_types[oname]}, err error) ''{')
a('switch val {')
for c in choice_vals:
a(f'case "{c}": return {oname}_{cval(c)}, nil')
vals = ', '.join(choice_vals)
a(f'default: return ans, fmt.Errorf("%#v is not a valid value for %s. Valid values are: %s", val, "{c}", "{vals}")')
a('}''}')
a('func (c *Config) Parse(key, val string) (err error) {')
a('switch key {')
a('default: return fmt.Errorf("Unknown configuration key: %#v", key)')
for oname, pname in go_parsers.items():
ol = oname.lower()
is_multiple = oname in multiopts
a(f'case "{ol}":')
if is_multiple:
a(f'var temp_val []{go_types[oname]}')
else:
a(f'var temp_val {go_types[oname]}')
a(f'temp_val, err = {pname}')
a(f'if err != nil {{ return fmt.Errorf("Failed to parse {ol} = %#v with error: %w", val, err) }}')
if is_multiple:
a(f'c.{oname} = append(c.{oname}, temp_val...)')
else:
a(f'c.{oname} = temp_val')
a('}')
a('return}')
return '\n'.join(lines)
def main() -> None:
# To use run it as:
# kitty +runpy 'from kitty.conf.generate import main; main()' /path/to/kitten/file.py

View File

@ -854,12 +854,14 @@ def store_multiple(val: str, current_val: Container[str]) -> Iterable[Tuple[str,
yield val, val
allowed_shell_integration_values = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'})
def shell_integration(x: str) -> FrozenSet[str]:
s = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'})
q = frozenset(x.lower().split())
if not q.issubset(s):
log_error(f'Invalid shell integration options: {q - s}, ignoring')
return q & s or frozenset({'invalid'})
if not q.issubset(allowed_shell_integration_values):
log_error(f'Invalid shell integration options: {q - allowed_shell_integration_values}, ignoring')
return q & allowed_shell_integration_values or frozenset({'invalid'})
return q

View File

@ -959,7 +959,7 @@ class Window:
def handle_remote_file(self, netloc: str, remote_path: str) -> None:
from kittens.remote_file.main import is_ssh_kitten_sentinel
from kittens.ssh.main import get_connection_data
from kittens.ssh.utils import get_connection_data
from .utils import SSHConnectionData
args = self.ssh_kitten_cmdline()
@ -1156,7 +1156,7 @@ class Window:
self.write_to_child(data)
def handle_remote_ssh(self, msg: str) -> None:
from kittens.ssh.main import get_ssh_data
from kittens.ssh.utils import get_ssh_data
for line in get_ssh_data(msg, f'{os.getpid()}-{self.id}'):
self.write_to_child(line)

View File

@ -112,7 +112,7 @@ class Callbacks:
self.current_clone_data += rest
def handle_remote_ssh(self, msg):
from kittens.ssh.main import get_ssh_data
from kittens.ssh.utils import get_ssh_data
if self.pty:
for line in get_ssh_data(msg, "testing"):
self.pty.write_to_child(line)

View File

@ -67,7 +67,7 @@ class TestBuild(BaseTest):
q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
return mode & q == q
for x in ('kitty', 'kitten', 'askpass.py'):
for x in ('kitty', 'kitten'):
x = os.path.join(shell_integration_dir, 'ssh', x)
self.assertTrue(is_executable(x), f'{x} is not executable')
if getattr(sys, 'frozen', False):

30
kitty_tests/shm.py Normal file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2023, Kovid Goyal <kovid at kovidgoyal.net>
import os
import subprocess
from kitty.constants import kitten_exe
from kitty.fast_data_types import shm_unlink
from kitty.shm import SharedMemory
from . import BaseTest
class SHMTest(BaseTest):
def test_shm_with_kitten(self):
data = os.urandom(333)
with SharedMemory(size=363) as shm:
shm.write_data_with_size(data)
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'read', shm.name], stdout=subprocess.PIPE)
self.assertEqual(cp.returncode, 0)
self.assertEqual(cp.stdout, data)
self.assertRaises(FileNotFoundError, shm_unlink, shm.name)
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'write'], input=data, stdout=subprocess.PIPE)
self.assertEqual(cp.returncode, 0)
name = cp.stdout.decode().strip()
with SharedMemory(name=name, unlink_on_exit=True) as shm:
q = shm.read_data_with_size()
self.assertEqual(data, q)

View File

@ -3,18 +3,16 @@
import glob
import json
import os
import shutil
import subprocess
import tempfile
from contextlib import suppress
from functools import lru_cache
from kittens.ssh.config import load_config
from kittens.ssh.main import bootstrap_script, get_connection_data, wrap_bootstrap_script
from kittens.ssh.options.types import Options as SSHOptions
from kittens.ssh.options.utils import DELETE_ENV_VAR
from kittens.transfer.utils import set_paths
from kitty.constants import is_macos, runtime_dir
from kittens.ssh.utils import get_connection_data
from kitty.constants import is_macos, kitten_exe, runtime_dir
from kitty.fast_data_types import CURSOR_BEAM, shm_unlink
from kitty.utils import SSHConnectionData
@ -58,32 +56,6 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two')))
self.assertTrue(runtime_dir())
def test_ssh_config_parsing(self):
def parse(conf, hostname='unmatched_host', username=''):
return load_config(overrides=conf.splitlines(), hostname=hostname, username=username)
self.ae(parse('').env, {})
self.ae(parse('env a=b').env, {'a': 'b'})
conf = 'env a=b\nhostname 2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'})
self.ae(parse('env a=').env, {'a': ''})
self.ae(parse('env a').env, {'a': '_delete_this_env_var_'})
conf = 'env a=b\nhostname test@2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '2').env, {'a': 'b'})
self.ae(parse(conf, '2', 'test').env, {'a': 'c', 'b': 'b'})
conf = 'env a=b\nhostname 1 2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '1').env, {'a': 'c', 'b': 'b'})
self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'})
def test_ssh_bootstrap_sh_cmd_limit(self):
# dropbear has a 9000 bytes maximum command length limit
sh_script, _, _ = bootstrap_script(SSHOptions({'interpreter': 'sh'}), script_type='sh', remote_args=[], request_id='123-123')
rcmd = wrap_bootstrap_script(sh_script, 'sh')
self.assertLessEqual(sum(len(x) for x in rcmd), 9000)
@property
@lru_cache()
def all_possible_sh(self):
@ -98,11 +70,13 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
f.write(simple_data)
for sh in self.all_possible_sh:
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as remote_home, tempfile.TemporaryDirectory() as local_home, set_paths(home=local_home):
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as remote_home, tempfile.TemporaryDirectory() as local_home:
tuple(map(touch, 'simple-file g.1 g.2'.split()))
os.makedirs(f'{local_home}/d1/d2/d3')
touch('d1/d2/x')
touch('d1/d2/w.exclude')
os.mkdir(f'{local_home}/d1/r')
touch('d1/r/noooo')
os.symlink('d2/x', f'{local_home}/d1/y')
os.symlink('simple-file', f'{local_home}/s1')
os.symlink('simple-file', f'{local_home}/s2')
@ -110,15 +84,13 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
conf = '''\
copy simple-file
copy s1
copy --symlink-strategy=keep-name s2
copy --symlink-strategy=keep-path s2
copy --dest=a/sfa simple-file
copy --glob g.*
copy --exclude */w.* d1
copy --exclude **/w.* --exclude **/r d1
'''
copy = load_config(overrides=filter(None, conf.splitlines())).copy
self.check_bootstrap(
sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='',
ssh_opts={'copy': copy}
sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf, home=local_home,
)
tname = '.terminfo'
if os.path.exists('/usr/share/misc/terminfo.cdb'):
@ -148,17 +120,18 @@ copy --exclude */w.* d1
self.ae(len(glob.glob(f'{remote_home}/{tname}/*/xterm-kitty')), 2)
def test_ssh_env_vars(self):
tset = '$A-$(echo no)-`echo no2` !Q5 "something\nelse"'
tset = '$A-$(echo no)-`echo no2` !Q5 "something else"'
for sh in self.all_possible_sh:
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
os.mkdir(os.path.join(tdir, 'cwd'))
conf = f'''
cwd $HOME/cwd
env A=AAA
env TSET={tset}
env COLORTERM
'''
pty = self.check_bootstrap(
sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='',
ssh_opts={'cwd': '$HOME/cwd', 'env': {
'A': 'AAA',
'TSET': tset,
'COLORTERM': DELETE_ENV_VAR,
}}
sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf
)
pty.wait_till(lambda: 'TSET={}'.format(tset.replace('$A', 'AAA')) in pty.screen_contents())
self.assertNotIn('COLORTERM', pty.screen_contents())
@ -240,34 +213,34 @@ copy --exclude */w.* d1
self.assertEqual(pty.screen.cursor.shape, 0)
self.assertNotIn(b'\x1b]133;', pty.received_bytes)
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None, launcher='sh'):
ssh_opts = ssh_opts or {}
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', conf='', launcher='sh', home=''):
if login_shell:
ssh_opts['login_shell'] = login_shell
conf += f'\nlogin_shell {login_shell}'
if 'python' in sh:
if test_script.startswith('env;'):
test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})'
test_script = f'print("UNTAR_DONE", flush=True); {test_script}'
else:
test_script = f'echo "UNTAR_DONE"; {test_script}'
ssh_opts['shell_integration'] = SHELL_INTEGRATION_VALUE or 'disabled'
script, replacements, shm_name = bootstrap_script(
SSHOptions(ssh_opts), script_type='py' if 'python' in sh else 'sh', request_id="testing", test_script=test_script,
request_data=True
)
conf += '\nshell_integration ' + (SHELL_INTEGRATION_VALUE or 'disabled')
conf += '\ninterpreter ' + sh
env = os.environ.copy()
if home:
env['HOME'] = home
cp = subprocess.run([kitten_exe(), '__pytest__', 'ssh', test_script], env=env, stdout=subprocess.PIPE, input=conf.encode('utf-8'))
self.assertEqual(cp.returncode, 0)
self.rdata = json.loads(cp.stdout)
del cp
try:
env = basic_shell_env(home_dir)
# Avoid generating unneeded completion scripts
os.makedirs(os.path.join(home_dir, '.local', 'share', 'fish', 'generated_completions'), exist_ok=True)
# prevent newuser-install from running
open(os.path.join(home_dir, '.zshrc'), 'w').close()
cmd = wrap_bootstrap_script(script, sh)
pty = self.create_pty([launcher, '-c', ' '.join(cmd)], cwd=home_dir, env=env)
pty = self.create_pty([launcher, '-c', ' '.join(self.rdata['cmd'])], cwd=home_dir, env=env)
pty.turn_off_echo()
del cmd
if pre_data:
pty.write_buf = pre_data.encode('utf-8')
del script
def check_untar_or_fail():
q = pty.screen_contents()
@ -284,4 +257,4 @@ copy --exclude */w.* d1
return pty
finally:
with suppress(FileNotFoundError):
shm_unlink(shm_name)
shm_unlink(self.rdata['shm_name'])

View File

@ -1,5 +1,5 @@
[tool.mypy]
files = 'kitty,kittens,glfw,*.py,docs/conf.py,shell-integration/ssh/askpass.py'
files = 'kitty,kittens,glfw,*.py,docs/conf.py'
no_implicit_optional = true
sqlite_cache = true
cache_fine_grained = true

View File

@ -1459,7 +1459,7 @@ def package(args: Options, bundle_type: str) -> None:
if path.endswith('.so'):
return True
q = path.split(os.sep)[-2:]
if len(q) == 2 and q[0] == 'ssh' and q[1] in ('askpass.py', 'kitty', 'kitten'):
if len(q) == 2 and q[0] == 'ssh' and q[1] in ('kitty', 'kitten'):
return True
return False

View File

@ -1,46 +0,0 @@
#!/usr/bin/env -S kitty +launch
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import json
import os
import sys
import time
from kitty.shm import SharedMemory
msg = sys.argv[-1]
prompt = os.environ.get('SSH_ASKPASS_PROMPT', '')
is_confirm = prompt == 'confirm'
is_fingerprint_check = '(yes/no/[fingerprint])' in msg
q = {
'message': msg,
'type': 'confirm' if is_confirm else 'get_line',
'is_password': not is_fingerprint_check,
}
data = json.dumps(q)
with SharedMemory(
size=len(data) + 1 + SharedMemory.num_bytes_for_size, unlink_on_exit=True, prefix=f'askpass-{os.getpid()}-') as shm, \
open(os.ctermid(), 'wb') as tty:
shm.write(b'\0')
shm.write_data_with_size(data)
shm.flush()
with open(os.ctermid(), 'wb') as f:
f.write(f'\x1bP@kitty-ask|{shm.name}\x1b\\'.encode('ascii'))
f.flush()
while True:
# TODO: Replace sleep() with a mutex and condition variable created in the shared memory
time.sleep(0.05)
shm.seek(0)
if shm.read(1) == b'\x01':
break
response = json.loads(shm.read_data_with_size())
if is_confirm:
response = 'yes' if response else 'no'
elif is_fingerprint_check:
if response.lower() in ('y', 'yes'):
response = 'yes'
if response.lower() in ('n', 'no'):
response = 'no'
if response:
print(response, flush=True)

View File

@ -24,7 +24,7 @@ exec_kitty() {
is_wrapped_kitten() {
wrapped_kittens="clipboard icat unicode_input"
wrapped_kittens="clipboard icat unicode_input ssh"
[ -n "$1" ] && {
case " $wrapped_kittens " in
*" $1 "*) printf "%s" "$1" ;;

View File

@ -33,9 +33,11 @@ type Command struct {
ArgCompleter CompletionFunc
// Stop completion processing at this arg num
StopCompletingAtArg int
// Consider all args as non-options args
// Consider all args as non-options args when parsing for completion
OnlyArgsAllowed bool
// Specialised arg aprsing
// Pass through all args, useful for wrapper commands
IgnoreAllArgs bool
// Specialised arg parsing
ParseArgsForCompletion func(cmd *Command, args []string, completions *Completions)
SubCommandGroups []*CommandGroup

View File

@ -13,6 +13,10 @@ func (self *Command) parse_args(ctx *Context, args []string) error {
args_to_parse := make([]string, len(args))
copy(args_to_parse, args)
ctx.SeenCommands = append(ctx.SeenCommands, self)
if self.IgnoreAllArgs {
self.Args = args
return nil
}
var expecting_arg_for *Option
options_allowed := true

View File

@ -66,8 +66,8 @@ func complete_plus_open(completions *cli.Completions, word string, arg_num int)
}
func complete_themes(completions *cli.Completions, word string, arg_num int) {
kitty, err := utils.KittyExe()
if err == nil {
kitty := utils.KittyExe()
if kitty != "" {
out, err := exec.Command(kitty, "+runpy", "from kittens.themes.collection import *; print_theme_names()").Output()
if err == nil {
mg := completions.AddMatchGroup("Themes")

View File

@ -14,7 +14,6 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"kitty/tools/tui/graphics"
"kitty/tools/utils"
@ -24,12 +23,13 @@ import (
var _ = fmt.Print
var find_exe_lock sync.Once
var magick_exe string = ""
func find_magick_exe() {
magick_exe = utils.Which("magick")
}
var MagickExe = (&utils.Once[string]{Run: func() string {
ans := utils.Which("magick")
if ans == "" {
ans = utils.Which("magick", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin")
}
return ans
}}).Get
func run_magick(path string, cmd []string) ([]byte, error) {
c := exec.Command(cmd[0], cmd[1:]...)
@ -154,10 +154,9 @@ func parse_identify_record(ans *IdentifyRecord, raw *IdentifyOutput) (err error)
}
func Identify(path string) (ans []IdentifyRecord, err error) {
find_exe_lock.Do(find_magick_exe)
cmd := []string{"identify"}
if magick_exe != "" {
cmd = []string{magick_exe, cmd[0]}
if MagickExe() != "" {
cmd = []string{MagickExe(), cmd[0]}
}
q := `{"fmt":"%m","canvas":"%g","transparency":"%A","gap":"%T","index":"%p","size":"%wx%h",` +
`"dpi":"%xx%y","dispose":"%D","orientation":"%[EXIF:Orientation]"},`
@ -227,10 +226,9 @@ func check_resize(frame *image_frame) error {
}
func Render(path string, ro *RenderOptions, frames []IdentifyRecord) (ans []*image_frame, err error) {
find_exe_lock.Do(find_magick_exe)
cmd := []string{"convert"}
if magick_exe != "" {
cmd = []string{magick_exe, cmd[0]}
if MagickExe() != "" {
cmd = []string{MagickExe(), cmd[0]}
}
ans = make([]*image_frame, 0, len(frames))
defer func() {

View File

@ -3,12 +3,22 @@
package main
import (
"os"
"kitty/tools/cli"
"kitty/tools/cmd/completion"
"kitty/tools/cmd/ssh"
"kitty/tools/cmd/tool"
)
func main() {
krm := os.Getenv("KITTY_KITTEN_RUN_MODULE")
os.Unsetenv("KITTY_KITTEN_RUN_MODULE")
switch krm {
case "ssh_askpass":
ssh.RunSSHAskpass()
return
}
root := cli.NewRootCommand()
root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)"
root.Usage = "command [command options] [command args]"

22
tools/cmd/pytest/main.go Normal file
View File

@ -0,0 +1,22 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package pytest
import (
"fmt"
"kitty/tools/cli"
"kitty/tools/cmd/ssh"
"kitty/tools/utils/shm"
)
var _ = fmt.Print
func EntryPoint(root *cli.Command) {
root = root.AddSubCommand(&cli.Command{
Name: "__pytest__",
Hidden: true,
})
shm.TestEntryPoint(root)
ssh.TestEntryPoint(root)
}

118
tools/cmd/ssh/askpass.go Normal file
View File

@ -0,0 +1,118 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"encoding/binary"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"kitty/tools/cli"
"kitty/tools/tty"
"kitty/tools/utils/shm"
)
var _ = fmt.Print
func fatal(err error) {
cli.ShowError(err)
os.Exit(1)
}
func trigger_ask(name string) {
term, err := tty.OpenControllingTerm()
if err != nil {
fatal(err)
}
defer term.Close()
_, err = term.WriteString("\x1bP@kitty-ask|" + name + "\x1b\\")
if err != nil {
fatal(err)
}
}
func RunSSHAskpass() {
msg := os.Args[len(os.Args)-1]
prompt := os.Getenv("SSH_ASKPASS_PROMPT")
is_confirm := prompt == "confirm"
q_type := "get_line"
if is_confirm {
q_type = "confirm"
}
is_fingerprint_check := strings.Contains(msg, "(yes/no/[fingerprint])")
q := map[string]any{
"message": msg,
"type": q_type,
"is_password": !is_fingerprint_check,
}
data, err := json.Marshal(q)
if err != nil {
fatal(err)
}
shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32))
if err != nil {
fatal(fmt.Errorf("Failed to create SHM file with error: %w", err))
}
defer shm.Close()
defer shm.Unlink()
shm.Slice()[0] = 0
binary.BigEndian.PutUint32(shm.Slice()[1:], uint32(len(data)))
copy(shm.Slice()[5:], data)
err = shm.Flush()
if err != nil {
fatal(fmt.Errorf("Failed to flush SHM file with error: %w", err))
}
trigger_ask(shm.Name())
buf := []byte{0}
for {
time.Sleep(50 * time.Millisecond)
_, err = shm.Seek(0, os.SEEK_SET)
if err != nil {
fatal(fmt.Errorf("Failed to seek into SHM file while waiting for response with error: %w", err))
}
_, err = shm.Read(buf)
if err != nil {
fatal(fmt.Errorf("Failed to read from SHM file while waiting for response with error: %w", err))
}
if buf[0] == 1 {
break
}
}
data, err = shm.ReadWithSize()
if err != nil {
fatal(fmt.Errorf("Failed to read response data from SHM file with error: %w", err))
}
response := ""
if is_confirm {
var ok bool
err = json.Unmarshal(data, &ok)
if err != nil {
fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err))
}
response = "no"
if ok {
response = "yes"
}
} else {
err = json.Unmarshal(data, &response)
if err != nil {
fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err))
}
if is_fingerprint_check {
response = strings.ToLower(response)
if response == "y" {
response = "yes"
} else if response == "n" {
response = "no"
}
}
}
if response != "" {
fmt.Println(response)
}
}

409
tools/cmd/ssh/config.go Normal file
View File

@ -0,0 +1,409 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"archive/tar"
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"time"
"kitty/tools/config"
"kitty/tools/utils"
"kitty/tools/utils/paths"
"kitty/tools/utils/shlex"
"github.com/bmatcuk/doublestar"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
type EnvInstruction struct {
key, val string
delete_on_remote, copy_from_local, literal_quote bool
}
func quote_for_sh(val string, literal_quote bool) string {
if literal_quote {
return utils.QuoteStringForSH(val)
}
// See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
b := strings.Builder{}
b.Grow(len(val) + 16)
b.WriteRune('"')
runes := []rune(val)
for i, ch := range runes {
if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') {
// special chars are escaped
// $( is escaped to prevent execution
b.WriteRune('\\')
}
b.WriteRune(ch)
}
b.WriteRune('"')
return b.String()
}
func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string {
var unset func() string
var export func(string) string
if for_python {
dumps := func(x ...any) string {
ans, _ := json.Marshal(x)
return utils.UnsafeBytesToString(ans)
}
export = func(val string) string {
if val == "" {
return fmt.Sprintf("export %s", dumps(self.key))
}
return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote))
}
unset = func() string {
return fmt.Sprintf("unset %s", dumps(self.key))
}
} else {
kq := utils.QuoteStringForSH(self.key)
unset = func() string {
return fmt.Sprintf("unset %s", kq)
}
export = func(val string) string {
return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote))
}
}
if self.delete_on_remote {
return unset()
}
if self.copy_from_local {
val, found := get_local_env(self.key)
if !found {
return ""
}
return export(val)
}
return export(self.val)
}
func final_env_instructions(for_python bool, get_local_env func(string) (string, bool), env ...*EnvInstruction) string {
seen := make(map[string]int, len(env))
ans := make([]string, 0, len(env))
for _, ei := range env {
q := ei.Serialize(for_python, get_local_env)
if q != "" {
if pos, found := seen[ei.key]; found {
ans[pos] = q
} else {
seen[ei.key] = len(ans)
ans = append(ans, q)
}
}
}
return strings.Join(ans, "\n")
}
type CopyInstruction struct {
local_path, arcname string
exclude_patterns []string
}
func ParseEnvInstruction(spec string) (ans []*EnvInstruction, err error) {
const COPY_FROM_LOCAL string = "_kitty_copy_env_var_"
ei := &EnvInstruction{}
found := false
ei.key, ei.val, found = strings.Cut(spec, "=")
ei.key = strings.TrimSpace(ei.key)
if found {
ei.val = strings.TrimSpace(ei.val)
if ei.val == COPY_FROM_LOCAL {
ei.val = ""
ei.copy_from_local = true
}
} else {
ei.delete_on_remote = true
}
if ei.key == "" {
err = fmt.Errorf("The env directive must not be empty")
}
ans = []*EnvInstruction{ei}
return
}
var paths_ctx *paths.Ctx
func resolve_file_spec(spec string, is_glob bool) ([]string, error) {
if paths_ctx == nil {
paths_ctx = &paths.Ctx{}
}
ans := os.ExpandEnv(paths_ctx.ExpandHome(spec))
if !filepath.IsAbs(ans) {
ans = paths_ctx.AbspathFromHome(ans)
}
if is_glob {
files, err := doublestar.Glob(ans)
if err != nil {
return nil, fmt.Errorf("%s is not a valid glob pattern with error: %w", spec, err)
}
if len(files) == 0 {
return nil, fmt.Errorf("%s matches no files", spec)
}
return files, nil
}
err := unix.Access(ans, unix.R_OK)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("%s does not exist", spec)
}
return nil, fmt.Errorf("Cannot read from: %s with error: %w", spec, err)
}
return []string{ans}, nil
}
func get_arcname(loc, dest, home string) (arcname string) {
if dest != "" {
arcname = dest
} else {
arcname = filepath.Clean(loc)
if filepath.HasPrefix(arcname, home) {
ra, err := filepath.Rel(home, arcname)
if err == nil {
arcname = ra
}
}
}
prefix := "home/"
if strings.HasPrefix(arcname, "/") {
prefix = "root"
}
return prefix + arcname
}
func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) {
args, err := shlex.Split("copy " + spec)
if err != nil {
return nil, err
}
opts, args, err := parse_copy_args(args)
if err != nil {
return nil, err
}
locations := make([]string, 0, len(args))
for _, arg := range args {
locs, err := resolve_file_spec(arg, opts.Glob)
if err != nil {
return nil, err
}
locations = append(locations, locs...)
}
if len(locations) == 0 {
return nil, fmt.Errorf("No files to copy specified")
}
if len(locations) > 1 && opts.Dest != "" {
return nil, fmt.Errorf("Specifying a remote location with more than one file is not supported")
}
home := paths_ctx.HomePath()
ans = make([]*CopyInstruction, 0, len(locations))
for _, loc := range locations {
ci := CopyInstruction{local_path: loc, exclude_patterns: opts.Exclude}
if opts.SymlinkStrategy != "preserve" {
ci.local_path, err = filepath.EvalSymlinks(loc)
if err != nil {
return nil, fmt.Errorf("Failed to resolve symlinks in %#v with error: %w", loc, err)
}
}
if opts.SymlinkStrategy == "resolve" {
ci.arcname = get_arcname(ci.local_path, opts.Dest, home)
} else {
ci.arcname = get_arcname(loc, opts.Dest, home)
}
ans = append(ans, &ci)
}
return
}
type file_unique_id struct {
dev, inode uint64
}
func excluded(pattern, path string) bool {
if !strings.ContainsRune(pattern, '/') {
path = filepath.Base(path)
}
if matched, err := doublestar.PathMatch(pattern, path); matched && err == nil {
return true
}
return false
}
func get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string, local_path, arcname string, exclude_patterns []string) error {
s, err := os.Lstat(local_path)
if err != nil {
return err
}
u, ok := s.Sys().(unix.Stat_t)
cb := func(h *tar.Header, data []byte, arcname string) error {
h.Name = arcname
if h.Typeflag == tar.TypeDir {
h.Name = strings.TrimRight(h.Name, "/") + "/"
}
h.Size = int64(len(data))
h.Mode = int64(s.Mode().Perm())
h.ModTime = s.ModTime()
h.Format = tar.FormatPAX
if ok {
h.AccessTime = time.Unix(0, u.Atim.Nano())
h.ChangeTime = time.Unix(0, u.Ctim.Nano())
}
return callback(h, data)
}
// we only copy regular files, directories and symlinks
switch s.Mode().Type() {
case fs.ModeSymlink:
target, err := os.Readlink(local_path)
if err != nil {
return err
}
err = cb(&tar.Header{
Typeflag: tar.TypeSymlink,
Linkname: target,
}, nil, arcname)
if err != nil {
return err
}
case fs.ModeDir:
local_path = filepath.Clean(local_path)
type entry struct {
path, arcname string
}
stack := []entry{{local_path, arcname}}
for len(stack) > 0 {
x := stack[0]
stack = stack[1:]
entries, err := os.ReadDir(x.path)
if err != nil {
if x.path == local_path {
return err
}
continue
}
err = cb(&tar.Header{Typeflag: tar.TypeDir}, nil, x.arcname)
if err != nil {
return err
}
for _, e := range entries {
entry_path := filepath.Join(x.path, e.Name())
aname := path.Join(x.arcname, e.Name())
ok := true
for _, pat := range exclude_patterns {
if excluded(pat, entry_path) {
ok = false
break
}
}
if !ok {
continue
}
if e.IsDir() {
stack = append(stack, entry{entry_path, aname})
} else {
err = get_file_data(callback, seen, entry_path, aname, exclude_patterns)
if err != nil {
return err
}
}
}
}
case 0: // Regular file
fid := file_unique_id{dev: u.Dev, inode: u.Ino}
if prev, ok := seen[fid]; ok { // Hard link
err = cb(&tar.Header{Typeflag: tar.TypeLink, Linkname: prev}, nil, arcname)
if err != nil {
return err
}
}
seen[fid] = arcname
data, err := os.ReadFile(local_path)
if err != nil {
return err
}
err = cb(&tar.Header{Typeflag: tar.TypeReg}, data, arcname)
if err != nil {
return err
}
}
return nil
}
func (ci *CopyInstruction) get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string) (err error) {
ep := ci.exclude_patterns
for _, folder_name := range []string{"__pycache__", ".DS_Store"} {
ep = append(ep, "**/"+folder_name, "**/"+folder_name+"/**")
}
return get_file_data(callback, seen, ci.local_path, ci.arcname, ep)
}
type ConfigSet struct {
all_configs []*Config
}
func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config {
matcher := func(q *Config) bool {
for _, pat := range strings.Split(q.Hostname, " ") {
upat := "*"
if strings.Contains(pat, "@") {
upat, pat, _ = strings.Cut(pat, "@")
}
var host_matched, user_matched bool
if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil {
host_matched = true
}
if matched, err := filepath.Match(upat, username_to_match); matched && err == nil {
user_matched = true
}
if host_matched && user_matched {
return true
}
}
return false
}
for _, c := range utils.Reversed(cs.all_configs) {
if matcher(c) {
return c
}
}
return cs.all_configs[0]
}
func (self *ConfigSet) line_handler(key, val string) error {
c := self.all_configs[len(self.all_configs)-1]
if key == "hostname" {
c = NewConfig()
self.all_configs = append(self.all_configs, c)
}
return c.Parse(key, val)
}
func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, []config.ConfigLine, error) {
ans := &ConfigSet{all_configs: []*Config{NewConfig()}}
p := config.ConfigParser{LineHandler: ans.line_handler}
if len(paths) == 0 {
paths = []string{filepath.Join(utils.ConfigDir(), "ssh.conf")}
}
paths = utils.Filter(paths, func(x string) bool { return x != "" })
err := p.ParseFiles(paths...)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, nil, err
}
if len(overrides) > 0 {
err = p.ParseOverrides(overrides...)
if err != nil {
return nil, nil, err
}
}
return config_for_hostname(hostname_to_match, username_to_match, ans), p.BadLines(), nil
}

View File

@ -0,0 +1,110 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"kitty/tools/utils"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestSSHConfigParsing(t *testing.T) {
tdir := t.TempDir()
hostname := "unmatched"
username := ""
conf := ""
for_python := false
cf := filepath.Join(tdir, "ssh.conf")
rt := func(expected_env ...string) {
os.WriteFile(cf, []byte(conf), 0o600)
c, bad_lines, err := load_config(hostname, username, nil, cf)
if err != nil {
t.Fatal(err)
}
if len(bad_lines) != 0 {
t.Fatalf("Bad config line: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err)
}
actual := final_env_instructions(for_python, func(key string) (string, bool) {
if key == "LOCAL_ENV" {
return "LOCAL_VAL", true
}
return "", false
}, c.Env...)
if expected_env == nil {
expected_env = []string{}
}
diff := cmp.Diff(expected_env, utils.Splitlines(actual))
if diff != "" {
t.Fatalf("Unexpected env for\nhostname: %#v\nusername: %#v\nconf: %s\n%s", hostname, username, conf, diff)
}
}
rt()
conf = "env a=b"
rt(`export 'a'="b"`)
conf = "env a=b\nhostname 2\nenv a=c\nenv b=b"
rt(`export 'a'="b"`)
hostname = "2"
rt(`export 'a'="c"`, `export 'b'="b"`)
conf = "env a="
rt(`export 'a'=""`)
conf = "env a"
rt(`unset 'a'`)
conf = "env a=b\nhostname test@2\nenv a=c\nenv b=b"
hostname = "unmatched"
rt(`export 'a'="b"`)
hostname = "2"
rt(`export 'a'="b"`)
username = "test"
rt(`export 'a'="c"`, `export 'b'="b"`)
conf = "env a=b\nhostname 1 2\nenv a=c\nenv b=b"
username = ""
hostname = "unmatched"
rt(`export 'a'="b"`)
hostname = "1"
rt(`export 'a'="c"`, `export 'b'="b"`)
hostname = "2"
rt(`export 'a'="c"`, `export 'b'="b"`)
for_python = true
rt(`export ["a","c",false]`, `export ["b","b",false]`)
conf = "env a="
rt(`export ["a"]`)
conf = "env a"
rt(`unset ["a"]`)
conf = "env LOCAL_ENV=_kitty_copy_env_var_"
rt(`export ["LOCAL_ENV","LOCAL_VAL",false]`)
ci, err := ParseCopyInstruction("--exclude moose --dest=target " + cf)
if err != nil {
t.Fatal(err)
}
diff := cmp.Diff("home/target", ci[0].arcname)
if diff != "" {
t.Fatalf("Incorrect arcname:\n%s", diff)
}
diff = cmp.Diff(cf, ci[0].local_path)
if diff != "" {
t.Fatalf("Incorrect local_path:\n%s", diff)
}
diff = cmp.Diff([]string{"moose"}, ci[0].exclude_patterns)
if diff != "" {
t.Fatalf("Incorrect excludes:\n%s", diff)
}
ci, err = ParseCopyInstruction("--glob " + filepath.Join(filepath.Dir(cf), "*.conf"))
if err != nil {
t.Fatal(err)
}
diff = cmp.Diff(cf, ci[0].local_path)
if diff != "" {
t.Fatalf("Incorrect local_path:\n%s", diff)
}
if len(ci) != 1 {
t.Fatal(ci)
}
}

69
tools/cmd/ssh/data.go Normal file
View File

@ -0,0 +1,69 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"archive/tar"
_ "embed"
"errors"
"fmt"
"io"
"kitty/tools/utils"
"regexp"
"strings"
)
var _ = fmt.Print
//go:embed data_generated.bin
var embedded_data string
type Entry struct {
metadata *tar.Header
data []byte
}
type Container map[string]Entry
var Data = (&utils.Once[Container]{Run: func() Container {
tr := tar.NewReader(utils.ReaderForCompressedEmbeddedData(embedded_data))
ans := make(Container, 64)
for {
hdr, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
panic(err)
}
data, err := utils.ReadAll(tr, int(hdr.Size))
if err != nil {
panic(err)
}
ans[hdr.Name] = Entry{hdr, data}
}
return ans
}}).Get
func (self Container) files_matching(prefix string, exclude_patterns ...string) []string {
ans := make([]string, 0, len(self))
patterns := make([]*regexp.Regexp, len(exclude_patterns))
for i, exp := range exclude_patterns {
patterns[i] = regexp.MustCompile(exp)
}
for name := range self {
if strings.HasPrefix(name, prefix) {
excluded := false
for _, pat := range patterns {
if matched := pat.FindString(name); matched != "" {
excluded = true
break
}
}
if !excluded {
ans = append(ans, name)
}
}
}
return ans
}

801
tools/cmd/ssh/main.go Normal file
View File

@ -0,0 +1,801 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"archive/tar"
"bytes"
"compress/gzip"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"kitty"
"net/url"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"kitty/tools/cli"
"kitty/tools/themes"
"kitty/tools/tty"
"kitty/tools/tui"
"kitty/tools/tui/loop"
"kitty/tools/utils"
"kitty/tools/utils/secrets"
"kitty/tools/utils/shm"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
func get_destination(hostname string) (username, hostname_for_match string) {
u, err := user.Current()
if err == nil {
username = u.Username
}
hostname_for_match = hostname
if strings.HasPrefix(hostname, "ssh://") {
p, err := url.Parse(hostname)
if err == nil {
hostname_for_match = p.Hostname()
if p.User.Username() != "" {
username = p.User.Username()
}
}
} else if strings.Contains(hostname, "@") && hostname[0] != '@' {
username, hostname_for_match, _ = strings.Cut(hostname, "@")
}
if strings.Contains(hostname, "@") && hostname[0] != '@' {
_, hostname_for_match, _ = strings.Cut(hostname_for_match, "@")
}
hostname_for_match, _, _ = strings.Cut(hostname_for_match, ":")
return
}
func read_data_from_shared_memory(shm_name string) ([]byte, error) {
data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error {
s, err := f.Stat()
if err != nil {
return fmt.Errorf("Failed to stat SHM file with error: %w", err)
}
if stat, ok := s.Sys().(unix.Stat_t); ok {
if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) {
return fmt.Errorf("Incorrect owner on SHM file")
}
}
if s.Mode().Perm() != 0o600 {
return fmt.Errorf("Incorrect permissions on SHM file")
}
return nil
})
return data, err
}
func add_cloned_env(val string) (ans map[string]string, err error) {
data, err := read_data_from_shared_memory(val)
if err != nil {
return nil, err
}
err = json.Unmarshal(data, &ans)
return ans, err
}
func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string, ferr error) {
literal_env = make(map[string]string)
overrides = make([]string, 0, 4)
for i, a := range found_extra_args {
if i%2 == 0 {
continue
}
if key, val, found := strings.Cut(a, "="); found {
if key == "clone_env" {
le, err := add_cloned_env(val)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return nil, nil, ferr
}
} else if le != nil {
literal_env = le
}
} else if key != "hostname" {
overrides = append(overrides, key+" "+val)
}
}
}
if len(overrides) > 0 {
overrides = append([]string{"hostname " + username + "@" + hostname_for_match}, overrides...)
}
return
}
func connection_sharing_args(kitty_pid int) ([]string, error) {
rd := utils.RuntimeDir()
// Bloody OpenSSH generates a 40 char hash and in creating the socket
// appends a 27 char temp suffix to it. Socket max path length is approx
// ~104 chars. And on idiotic Apple the path length to the runtime dir
// (technically the cache dir since Apple has no runtime dir and thinks it's
// a great idea to delete files in /tmp) is ~48 chars.
if len(rd) > 35 {
idiotic_design := fmt.Sprintf("/tmp/kssh-rdir-%d", os.Geteuid())
if err := utils.AtomicCreateSymlink(rd, idiotic_design); err != nil {
return nil, err
}
rd = idiotic_design
}
cp := strings.Replace(kitty.SSHControlMasterTemplate, "{kitty_pid}", strconv.Itoa(kitty_pid), 1)
cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1)
return []string{
"-o", "ControlMaster=auto",
"-o", "ControlPath=" + filepath.Join(rd, cp),
"-o", "ControlPersist=yes",
"-o", "ServerAliveInterval=60",
"-o", "ServerAliveCountMax=5",
"-o", "TCPKeepAlive=no",
}, nil
}
func set_askpass() (need_to_request_data bool) {
need_to_request_data = true
sentinel := filepath.Join(utils.CacheDir(), "openssh-is-new-enough-for-askpass")
_, err := os.Stat(sentinel)
sentinel_exists := err == nil
if sentinel_exists || GetSSHVersion().SupportsAskpassRequire() {
if !sentinel_exists {
os.WriteFile(sentinel, []byte{0}, 0o644)
}
need_to_request_data = false
}
exe, err := os.Executable()
if err == nil {
os.Setenv("SSH_ASKPASS", exe)
os.Setenv("KITTY_KITTEN_RUN_MODULE", "ssh_askpass")
if !need_to_request_data {
os.Setenv("SSH_ASKPASS_REQUIRE", "force")
}
} else {
need_to_request_data = true
}
return
}
type connection_data struct {
remote_args []string
host_opts *Config
hostname_for_match string
username string
echo_on bool
request_data bool
literal_env map[string]string
test_script string
shm_name string
script_type string
rcmd []string
replacements map[string]string
request_id string
bootstrap_script string
}
func get_effective_ksi_env_var(x string) string {
parts := strings.Split(strings.TrimSpace(strings.ToLower(x)), " ")
current := utils.NewSetWithItems(parts...)
if current.Has("disabled") {
return ""
}
allowed := utils.NewSetWithItems(kitty.AllowedShellIntegrationValues...)
if !current.IsSubsetOf(allowed) {
return RelevantKittyOpts().Shell_integration
}
return x
}
func serialize_env(cd *connection_data, get_local_env func(string) (string, bool)) (string, string) {
ksi := ""
if cd.host_opts.Shell_integration == "inherited" {
ksi = get_effective_ksi_env_var(RelevantKittyOpts().Shell_integration)
} else {
ksi = get_effective_ksi_env_var(cd.host_opts.Shell_integration)
}
env := make([]*EnvInstruction, 0, 8)
add_env := func(key, val string, fallback ...string) *EnvInstruction {
if val == "" && len(fallback) > 0 {
val = fallback[0]
}
if val != "" {
env = append(env, &EnvInstruction{key: key, val: val, literal_quote: true})
return env[len(env)-1]
}
return nil
}
add_non_literal_env := func(key, val string, fallback ...string) *EnvInstruction {
ans := add_env(key, val, fallback...)
if ans != nil {
ans.literal_quote = false
}
return ans
}
for k, v := range cd.literal_env {
add_env(k, v)
}
add_env("TERM", os.Getenv("TERM"), RelevantKittyOpts().Term)
add_env("COLORTERM", "truecolor")
env = append(env, cd.host_opts.Env...)
add_env("KITTY_WINDOW_ID", os.Getenv("KITTY_WINDOW_ID"))
add_env("WINDOWID", os.Getenv("WINDOWID"))
if ksi != "" {
add_env("KITTY_SHELL_INTEGRATION", ksi)
} else {
env = append(env, &EnvInstruction{key: "KITTY_SHELL_INTEGRATION", delete_on_remote: true})
}
add_non_literal_env("KITTY_SSH_KITTEN_DATA_DIR", cd.host_opts.Remote_dir)
add_non_literal_env("KITTY_LOGIN_SHELL", cd.host_opts.Login_shell)
add_non_literal_env("KITTY_LOGIN_CWD", cd.host_opts.Cwd)
if cd.host_opts.Remote_kitty != Remote_kitty_no {
add_env("KITTY_REMOTE", cd.host_opts.Remote_kitty.String())
}
add_env("KITTY_PUBLIC_KEY", os.Getenv("KITTY_PUBLIC_KEY"))
return final_env_instructions(cd.script_type == "py", get_local_env, env...), ksi
}
func make_tarfile(cd *connection_data, get_local_env func(string) (string, bool)) ([]byte, error) {
env_script, ksi := serialize_env(cd, get_local_env)
w := bytes.Buffer{}
w.Grow(64 * 1024)
gw, err := gzip.NewWriterLevel(&w, gzip.BestCompression)
if err != nil {
return nil, err
}
tw := tar.NewWriter(gw)
rd := strings.TrimRight(cd.host_opts.Remote_dir, "/")
seen := make(map[file_unique_id]string, 32)
add := func(h *tar.Header, data []byte) (err error) {
// some distro's like nix mess with installed file permissions so ensure
// files are at least readable and writable by owning user
h.Mode |= 0o600
err = tw.WriteHeader(h)
if err != nil {
return
}
if data != nil {
_, err := tw.Write(data)
if err != nil {
return err
}
}
return
}
for _, ci := range cd.host_opts.Copy {
err = ci.get_file_data(add, seen)
if err != nil {
return nil, err
}
}
type fe struct {
arcname string
data []byte
}
now := time.Now()
add_data := func(items ...fe) error {
for _, item := range items {
err := add(
&tar.Header{
Typeflag: tar.TypeReg, Name: item.arcname, Format: tar.FormatPAX, Size: int64(len(item.data)),
Mode: 0o644, ModTime: now, ChangeTime: now, AccessTime: now,
}, item.data)
if err != nil {
return err
}
}
return nil
}
add_entries := func(prefix string, items ...Entry) error {
for _, item := range items {
err := add(
&tar.Header{
Typeflag: item.metadata.Typeflag, Name: path.Join(prefix, path.Base(item.metadata.Name)), Format: tar.FormatPAX,
Size: int64(len(item.data)), Mode: item.metadata.Mode, ModTime: item.metadata.ModTime,
AccessTime: item.metadata.AccessTime, ChangeTime: item.metadata.ChangeTime,
}, item.data)
if err != nil {
return err
}
}
return nil
}
add_data(fe{"data.sh", utils.UnsafeStringToBytes(env_script)})
if cd.script_type == "sh" {
add_data(fe{"bootstrap-utils.sh", Data()[path.Join("shell-integration/ssh/bootstrap-utils.sh")].data})
}
if ksi != "" {
for _, fname := range Data().files_matching(
"shell-integration/",
"shell-integration/ssh/.+", // bootstrap files are sent as command line args
"shell-integration/zsh/kitty.zsh", // backward compat file not needed by ssh kitten
) {
arcname := path.Join("home/", rd, "/", path.Dir(fname))
err = add_entries(arcname, Data()[fname])
if err != nil {
return nil, err
}
}
}
if cd.host_opts.Remote_kitty != Remote_kitty_no {
arcname := path.Join("home/", rd, "/kitty")
err = add_data(fe{arcname + "/version", utils.UnsafeStringToBytes(kitty.VersionString)})
if err != nil {
return nil, err
}
for _, x := range []string{"kitty", "kitten"} {
err = add_entries(path.Join(arcname, "bin"), Data()[path.Join("shell-integration", "ssh", x)])
if err != nil {
return nil, err
}
}
}
err = add_entries(path.Join("home", ".terminfo"), Data()["terminfo/kitty.terminfo"])
if err == nil {
err = add_entries(path.Join("home", ".terminfo", "x"), Data()["terminfo/x/xterm-kitty"])
}
if err == nil {
err = tw.Close()
if err == nil {
err = gw.Close()
}
}
return w.Bytes(), err
}
func prepare_home_command(cd *connection_data) string {
is_python := cd.script_type == "py"
homevar := ""
for _, ei := range cd.host_opts.Env {
if ei.key == "HOME" && !ei.delete_on_remote {
if ei.copy_from_local {
homevar = os.Getenv("HOME")
} else {
homevar = ei.val
}
}
}
export_home_cmd := ""
if homevar != "" {
if is_python {
export_home_cmd = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(homevar))
} else {
export_home_cmd = fmt.Sprintf("export HOME=%s; cd \"$HOME\"", utils.QuoteStringForSH(homevar))
}
}
return export_home_cmd
}
func prepare_exec_cmd(cd *connection_data) string {
// ssh simply concatenates multiple commands using a space see
// line 1129 of ssh.c and on the remote side sshd.c runs the
// concatenated command as shell -c cmd
if cd.script_type == "py" {
return base64.RawStdEncoding.EncodeToString(utils.UnsafeStringToBytes(strings.Join(cd.remote_args, " ")))
}
args := make([]string, len(cd.remote_args))
for i, arg := range cd.remote_args {
args[i] = strings.ReplaceAll(arg, "'", "'\"'\"'")
}
return "unset KITTY_SHELL_INTEGRATION; exec \"$login_shell\" -c '" + strings.Join(args, " ") + "'"
}
var data_shm shm.MMap
func prepare_script(script string, replacements map[string]string) string {
if _, found := replacements["EXEC_CMD"]; !found {
replacements["EXEC_CMD"] = ""
}
if _, found := replacements["EXPORT_HOME_CMD"]; !found {
replacements["EXPORT_HOME_CMD"] = ""
}
keys := maps.Keys(replacements)
for i, key := range keys {
keys[i] = "\\b" + key + "\\b"
}
pat := regexp.MustCompile(strings.Join(keys, "|"))
return pat.ReplaceAllStringFunc(script, func(key string) string { return replacements[key] })
}
func bootstrap_script(cd *connection_data) (err error) {
if cd.request_id == "" {
cd.request_id = os.Getenv("KITTY_PID") + "-" + os.Getenv("KITTY_WINDOW_ID")
}
export_home_cmd := prepare_home_command(cd)
exec_cmd := ""
if len(cd.remote_args) > 0 {
exec_cmd = prepare_exec_cmd(cd)
}
pw, err := secrets.TokenHex()
if err != nil {
return err
}
tfd, err := make_tarfile(cd, os.LookupEnv)
if err != nil {
return err
}
data := map[string]string{
"tarfile": base64.StdEncoding.EncodeToString(tfd),
"pw": pw,
"hostname": cd.hostname_for_match, "username": cd.username,
}
encoded_data, err := json.Marshal(data)
if err == nil {
data_shm, err = shm.CreateTemp(fmt.Sprintf("kssh-%d-", os.Getpid()), uint64(len(encoded_data)+8))
if err == nil {
err = data_shm.WriteWithSize(encoded_data)
if err == nil {
err = data_shm.Flush()
}
}
}
if err != nil {
return err
}
cd.shm_name = data_shm.Name()
sensitive_data := map[string]string{"REQUEST_ID": cd.request_id, "DATA_PASSWORD": pw, "PASSWORD_FILENAME": cd.shm_name}
replacements := map[string]string{
"EXPORT_HOME_CMD": export_home_cmd,
"EXEC_CMD": exec_cmd,
"TEST_SCRIPT": cd.test_script,
}
add_bool := func(ok bool, key string) {
if ok {
replacements[key] = "1"
} else {
replacements[key] = "0"
}
}
add_bool(cd.request_data, "REQUEST_DATA")
add_bool(cd.echo_on, "ECHO_ON")
sd := maps.Clone(replacements)
if cd.request_data {
maps.Copy(sd, sensitive_data)
}
maps.Copy(replacements, sensitive_data)
cd.replacements = replacements
cd.bootstrap_script = utils.UnsafeBytesToString(Data()["shell-integration/ssh/bootstrap."+cd.script_type].data)
cd.bootstrap_script = prepare_script(cd.bootstrap_script, sd)
return err
}
func wrap_bootstrap_script(cd *connection_data) {
// sshd will execute the command we pass it by join all command line
// arguments with a space and passing it as a single argument to the users
// login shell with -c. If the user has a non POSIX login shell it might
// have different escaping semantics and syntax, so the command it should
// execute has to be as simple as possible, basically of the form
// interpreter -c unwrap_script escaped_bootstrap_script
// The unwrap_script is responsible for unescaping the bootstrap script and
// executing it.
encoded_script := ""
unwrap_script := ""
if cd.script_type == "py" {
encoded_script = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(cd.bootstrap_script))
unwrap_script = `"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"`
} else {
// We cant rely on base64 being available on the remote system, so instead
// we quote the bootstrap script by replacing ' and \ with \v and \f
// also replacing \n and ! with \r and \b for tcsh
// finally surrounding with '
encoded_script = "'" + strings.NewReplacer("'", "\v", "\\", "\f", "\n", "\r", "!", "\b").Replace(cd.bootstrap_script) + "'"
unwrap_script = `'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' `
}
cd.rcmd = []string{"exec", cd.host_opts.Interpreter, "-c", unwrap_script, encoded_script}
}
func get_remote_command(cd *connection_data) error {
interpreter := cd.host_opts.Interpreter
q := strings.ToLower(path.Base(interpreter))
is_python := strings.Contains(q, "python")
cd.script_type = "sh"
if is_python {
cd.script_type = "py"
}
err := bootstrap_script(cd)
if err != nil {
return err
}
wrap_bootstrap_script(cd)
return nil
}
func drain_potential_tty_garbage(term *tty.Term) {
err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho)
if err != nil {
return
}
canary, err := secrets.TokenBase64()
if err != nil {
return
}
dcs, err := tui.DCSToKitty("echo", canary+"\n\r")
if err != nil {
return
}
err = term.WriteAllString(dcs)
if err != nil {
return
}
q := utils.UnsafeStringToBytes(canary)
data := make([]byte, 0)
give_up_at := time.Now().Add(2 * time.Second)
buf := make([]byte, 0, 8192)
for !bytes.Contains(data, q) {
buf = buf[:cap(buf)]
timeout := give_up_at.Sub(time.Now())
if timeout < 0 {
break
}
n, err := term.ReadWithTimeout(buf, timeout)
if err != nil {
return
}
data = append(data, buf[:n]...)
}
}
func change_colors(color_scheme string) (ans string, err error) {
if color_scheme == "" {
return
}
var theme *themes.Theme
if !strings.HasSuffix(color_scheme, ".conf") {
cs := os.ExpandEnv(color_scheme)
tc, closer, err := themes.LoadThemes(-1)
if err != nil && errors.Is(err, themes.ErrNoCacheFound) {
tc, closer, err = themes.LoadThemes(time.Hour * 24)
}
if err != nil {
return "", err
}
defer closer.Close()
theme = tc.ThemeByName(cs)
if theme == nil {
return "", fmt.Errorf("No theme named %#v found", cs)
}
} else {
theme, err = themes.ThemeFromFile(utils.ResolveConfPath(color_scheme))
if err != nil {
return "", err
}
}
ans, err = theme.AsEscapeCodes()
if err == nil {
ans = "\033[#P" + ans
}
return
}
func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) {
go Data()
go RelevantKittyOpts()
defer func() {
if data_shm != nil {
data_shm.Close()
data_shm.Unlink()
}
}()
cmd := append([]string{SSHExe()}, ssh_args...)
cd := connection_data{remote_args: server_args[1:]}
hostname := server_args[0]
if len(cd.remote_args) == 0 {
cmd = append(cmd, "-t")
}
insertion_point := len(cmd)
cmd = append(cmd, "--", hostname)
uname, hostname_for_match := get_destination(hostname)
overrides, literal_env, err := parse_kitten_args(found_extra_args, uname, hostname_for_match)
if err != nil {
return 1, err
}
host_opts, bad_lines, err := load_config(hostname_for_match, uname, overrides)
if err != nil {
return 1, err
}
if len(bad_lines) > 0 {
for _, x := range bad_lines {
fmt.Fprintf(os.Stderr, "Ignoring bad config line: %s:%d with error: %s", filepath.Base(x.Src_file), x.Line_number, x.Err)
}
}
if host_opts.Share_connections {
kpid, err := strconv.Atoi(os.Getenv("KITTY_PID"))
if err != nil {
return 1, fmt.Errorf("Invalid KITTY_PID env var not an integer: %#v", os.Getenv("KITTY_PID"))
}
cpargs, err := connection_sharing_args(kpid)
if err != nil {
return 1, err
}
cmd = slices.Insert(cmd, insertion_point, cpargs...)
}
use_kitty_askpass := host_opts.Askpass == Askpass_native || (host_opts.Askpass == Askpass_unless_set && os.Getenv("SSH_ASKPASS") == "")
need_to_request_data := true
if use_kitty_askpass {
need_to_request_data = set_askpass()
}
if need_to_request_data && host_opts.Share_connections {
check_cmd := slices.Insert(cmd, 1, "-O", "check")
err = exec.Command(check_cmd[0], check_cmd[1:]...).Run()
if err == nil {
need_to_request_data = false
}
}
term, err := tty.OpenControllingTerm(tty.SetNoEcho)
if err != nil {
return 1, fmt.Errorf("Failed to open controlling terminal with error: %w", err)
}
cd.echo_on = term.WasEchoOnOriginally()
cd.host_opts, cd.literal_env = host_opts, literal_env
cd.request_data = need_to_request_data
cd.hostname_for_match, cd.username = hostname_for_match, uname
escape_codes_to_set_colors, err := change_colors(cd.host_opts.Color_scheme)
if err == nil {
err = term.WriteAllString(escape_codes_to_set_colors + loop.SAVE_PRIVATE_MODE_VALUES + loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet())
}
if err != nil {
return 1, err
}
restore_escape_codes := loop.RESTORE_PRIVATE_MODE_VALUES
if escape_codes_to_set_colors != "" {
restore_escape_codes += "\x1b[#Q"
}
defer func() {
term.WriteAllString(restore_escape_codes)
term.RestoreAndClose()
}()
err = get_remote_command(&cd)
if err != nil {
return 1, err
}
cmd = append(cmd, cd.rcmd...)
c := exec.Command(cmd[0], cmd[1:]...)
c.Stdin, c.Stdout, c.Stderr = os.Stdin, os.Stdout, os.Stderr
err = c.Start()
if err != nil {
return 1, err
}
if !cd.request_data {
rq := fmt.Sprintf("id=%s:pwfile=%s:pw=%s", cd.replacements["REQUEST_ID"], cd.replacements["PASSWORD_FILENAME"], cd.replacements["DATA_PASSWORD"])
err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho)
if err == nil {
var dcs string
dcs, err = tui.DCSToKitty("ssh", rq)
if err == nil {
err = term.WriteAllString(dcs)
}
}
if err != nil {
c.Process.Kill()
c.Wait()
return 1, err
}
}
err = c.Wait()
drain_potential_tty_garbage(term)
if err != nil {
var exit_err *exec.ExitError
if errors.As(err, &exit_err) {
return exit_err.ExitCode(), nil
}
return 1, err
}
return 0, nil
}
func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
if len(args) > 0 {
switch args[0] {
case "use-python":
args = args[1:] // backwards compat from when we had a python implementation
case "-h", "--help":
cmd.ShowHelp()
return
}
}
ssh_args, server_args, passthrough, found_extra_args, err := ParseSSHArgs(args, "--kitten")
if err != nil {
var invargs *ErrInvalidSSHArgs
switch {
case errors.As(err, &invargs):
if invargs.Msg != "" {
fmt.Fprintln(os.Stderr, invargs.Msg)
}
return 1, unix.Exec(SSHExe(), []string{"ssh"}, os.Environ())
}
return 1, err
}
if passthrough {
if len(found_extra_args) > 0 {
return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " "))
}
return 1, unix.Exec(SSHExe(), append([]string{"ssh"}, args...), os.Environ())
}
if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" {
return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window")
}
if !tty.IsTerminal(os.Stdin.Fd()) {
return 1, fmt.Errorf("The SSH kitten is meant for interactive use only, STDIN must be a terminal")
}
return run_ssh(ssh_args, server_args, found_extra_args)
}
func EntryPoint(parent *cli.Command) {
create_cmd(parent, main)
}
func specialize_command(ssh *cli.Command) {
ssh.Usage = "arguments for the ssh command"
ssh.ShortDescription = "Truly convenient SSH"
ssh.HelpText = "The ssh kitten is a thin wrapper around the ssh command. It automatically enables shell integration on the remote host, re-uses existing connections to reduce latency, makes the kitty terminfo database available, etc. It's invocation is identical to the ssh command. For details on its usage, see :doc:`/kittens/ssh`."
ssh.IgnoreAllArgs = true
ssh.OnlyArgsAllowed = true
ssh.ArgCompleter = cli.CompletionForWrapper("ssh")
}
func test_integration_with_python(args []string) (rc int, err error) {
f, err := os.CreateTemp("", "*.conf")
if err != nil {
return 1, err
}
defer func() {
f.Close()
os.Remove(f.Name())
}()
_, err = io.Copy(f, os.Stdin)
if err != nil {
return 1, err
}
cd := &connection_data{
request_id: "testing", remote_args: []string{},
username: "testuser", hostname_for_match: "host.test", request_data: true,
test_script: args[0], echo_on: true,
}
opts, bad_lines, err := load_config(cd.hostname_for_match, cd.username, nil, f.Name())
if err == nil {
if len(bad_lines) > 0 {
return 1, fmt.Errorf("Bad config lines: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err)
}
cd.host_opts = opts
err = get_remote_command(cd)
}
if err != nil {
return 1, err
}
data, err := json.Marshal(map[string]any{"cmd": cd.rcmd, "shm_name": cd.shm_name})
if err == nil {
_, err = os.Stdout.Write(data)
os.Stdout.Close()
}
if err != nil {
return 1, err
}
return
}
func TestEntryPoint(root *cli.Command) {
root.AddSubCommand(&cli.Command{
Name: "ssh",
OnlyArgsAllowed: true,
Run: func(cmd *cli.Command, args []string) (rc int, err error) {
return test_integration_with_python(args)
},
})
}

154
tools/cmd/ssh/main_test.go Normal file
View File

@ -0,0 +1,154 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"encoding/binary"
"encoding/json"
"fmt"
"io/fs"
"kitty/tools/utils/shm"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
func TestCloneEnv(t *testing.T) {
env := map[string]string{"a": "1", "b": "2"}
data, err := json.Marshal(env)
if err != nil {
t.Fatal(err)
}
mmap, err := shm.CreateTemp("", 128)
if err != nil {
t.Fatal(err)
}
defer mmap.Unlink()
copy(mmap.Slice()[4:], data)
binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data)))
mmap.Close()
x, err := add_cloned_env(mmap.Name())
if err != nil {
t.Fatal(err)
}
diff := cmp.Diff(env, x)
if diff != "" {
t.Fatalf("Failed to deserialize env\n%s", diff)
}
}
func basic_connection_data(overrides ...string) *connection_data {
ans := &connection_data{
script_type: "sh", request_id: "123-123", remote_args: []string{},
username: "testuser", hostname_for_match: "host.test",
}
opts, bad_lines, err := load_config(ans.hostname_for_match, ans.username, overrides, "")
if err != nil {
panic(err)
}
if len(bad_lines) != 0 {
panic(fmt.Sprintf("Bad config lines: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err))
}
ans.host_opts = opts
return ans
}
func TestSSHBootstrapScriptLimit(t *testing.T) {
cd := basic_connection_data()
err := get_remote_command(cd)
if err != nil {
t.Fatal(err)
}
total := 0
for _, x := range cd.rcmd {
total += len(x)
}
if total > 9000 {
t.Fatalf("Bootstrap script too large: %d bytes", total)
}
}
func TestSSHTarfile(t *testing.T) {
tdir := t.TempDir()
cd := basic_connection_data()
data, err := make_tarfile(cd, func(key string) (val string, found bool) { return })
if err != nil {
t.Fatal(err)
}
cmd := exec.Command("tar", "xpzf", "-", "-C", tdir)
cmd.Stderr = os.Stderr
inp, err := cmd.StdinPipe()
if err != nil {
t.Fatal(err)
}
err = cmd.Start()
if err != nil {
t.Fatal(err)
}
_, err = inp.Write(data)
if err != nil {
t.Fatal(err)
}
inp.Close()
err = cmd.Wait()
if err != nil {
t.Fatal(err)
}
seen := map[string]bool{}
err = filepath.WalkDir(tdir, func(name string, d fs.DirEntry, werr error) error {
if werr != nil {
return werr
}
rname, werr := filepath.Rel(tdir, name)
if werr != nil {
return werr
}
rname = strings.ReplaceAll(rname, "\\", "/")
if rname == "." {
return nil
}
fi, werr := d.Info()
if werr != nil {
return werr
}
if fi.Mode().Perm()&0o600 == 0 {
return fmt.Errorf("%s is not rw for its owner. Actual permissions: %s", rname, fi.Mode().String())
}
seen[rname] = true
return nil
})
if err != nil {
t.Fatal(err)
}
if !seen["data.sh"] {
t.Fatalf("data.sh missing")
}
for _, x := range []string{".terminfo/kitty.terminfo", ".terminfo/x/xterm-kitty"} {
if !seen["home/"+x] {
t.Fatalf("%s missing", x)
}
}
for _, x := range []string{"shell-integration/bash/kitty.bash", "shell-integration/fish/vendor_completions.d/kitty.fish"} {
if !seen[path.Join("home", cd.host_opts.Remote_dir, x)] {
t.Fatalf("%s missing", x)
}
}
for _, x := range []string{"kitty", "kitten"} {
p := filepath.Join(tdir, "home", cd.host_opts.Remote_dir, "kitty", "bin", x)
if err = unix.Access(p, unix.X_OK); err != nil {
t.Fatalf("Cannot execute %s with error: %s", x, err)
}
}
if seen[path.Join("home", cd.host_opts.Remote_dir, "shell-integration", "ssh", "kitten")] {
t.Fatalf("Contents of shell-integration/ssh not excluded")
}
}

248
tools/cmd/ssh/utils.go Normal file
View File

@ -0,0 +1,248 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"io"
"kitty"
"kitty/tools/config"
"kitty/tools/utils"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
)
var _ = fmt.Print
var SSHExe = (&utils.Once[string]{Run: func() string {
ans := utils.Which("ssh")
if ans != "" {
return ans
}
ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin")
if ans == "" {
ans = "ssh"
}
return ans
}}).Get
var SSHOptions = (&utils.Once[map[string]string]{Run: func() (ssh_options map[string]string) {
defer func() {
if ssh_options == nil {
ssh_options = map[string]string{
"4": "", "6": "", "A": "", "a": "", "C": "", "f": "", "G": "", "g": "", "K": "", "k": "",
"M": "", "N": "", "n": "", "q": "", "s": "", "T": "", "t": "", "V": "", "v": "", "X": "",
"x": "", "Y": "", "y": "", "B": "bind_interface", "b": "bind_address", "c": "cipher_spec",
"D": "[bind_address:]port", "E": "log_file", "e": "escape_char", "F": "configfile", "I": "pkcs11",
"i": "identity_file", "J": "[user@]host[:port]", "L": "address", "l": "login_name", "m": "mac_spec",
"O": "ctl_cmd", "o": "option", "p": "port", "Q": "query_option", "R": "address",
"S": "ctl_path", "W": "host:port", "w": "local_tun[:remote_tun]",
}
}
}()
cmd := exec.Command(SSHExe())
stderr, err := cmd.StderrPipe()
if err != nil {
return
}
if err := cmd.Start(); err != nil {
return
}
raw, err := io.ReadAll(stderr)
if err != nil {
return
}
text := utils.UnsafeBytesToString(raw)
ssh_options = make(map[string]string, 32)
for {
pos := strings.IndexByte(text, '[')
if pos < 0 {
break
}
num := 1
epos := pos
for num > 0 {
epos++
switch text[epos] {
case '[':
num += 1
case ']':
num -= 1
}
}
q := text[pos+1 : epos]
text = text[epos:]
if len(q) < 2 || !strings.HasPrefix(q, "-") {
continue
}
opt, desc, found := strings.Cut(q, " ")
if found {
ssh_options[opt[1:]] = desc
} else {
for _, ch := range opt[1:] {
ssh_options[string(ch)] = ""
}
}
}
return
}}).Get
func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) {
other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32)
for k, v := range SSHOptions() {
k = "-" + k
if v == "" {
boolean_ssh_args.Add(k)
} else {
other_ssh_args.Add(k)
}
}
return
}
func is_extra_arg(arg string, extra_args []string) string {
for _, x := range extra_args {
if arg == x || strings.HasPrefix(arg, x+"=") {
return x
}
}
return ""
}
type ErrInvalidSSHArgs struct {
Msg string
}
func (self *ErrInvalidSSHArgs) Error() string {
return self.Msg
}
func PassthroughArgs() map[string]bool {
return map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true}
}
func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, server_args []string, passthrough bool, found_extra_args []string, err error) {
if extra_args == nil {
extra_args = []string{}
}
if len(args) == 0 {
passthrough = true
return
}
passthrough_args := PassthroughArgs()
boolean_ssh_args, other_ssh_args := GetSSHCLI()
ssh_args, server_args, found_extra_args = make([]string, 0, 16), make([]string, 0, 16), make([]string, 0, 16)
expecting_option_val := false
stop_option_processing := false
expecting_extra_val := ""
for _, argument := range args {
if len(server_args) > 1 || stop_option_processing {
server_args = append(server_args, argument)
continue
}
if strings.HasPrefix(argument, "-") && !expecting_option_val {
if argument == "--" {
stop_option_processing = true
continue
}
if len(extra_args) > 0 {
matching_ex := is_extra_arg(argument, extra_args)
if matching_ex != "" {
_, exval, found := strings.Cut(argument, "=")
if found {
found_extra_args = append(found_extra_args, matching_ex, exval)
} else {
expecting_extra_val = matching_ex
expecting_option_val = true
}
continue
}
}
// could be a multi-character option
all_args := []rune(argument[1:])
for i, ch := range all_args {
arg := "-" + string(ch)
if passthrough_args[arg] {
passthrough = true
}
if boolean_ssh_args.Has(arg) {
ssh_args = append(ssh_args, arg)
continue
}
if other_ssh_args.Has(arg) {
ssh_args = append(ssh_args, arg)
if i+1 < len(all_args) {
ssh_args = append(ssh_args, string(all_args[i+1:]))
} else {
expecting_option_val = true
}
break
}
err = &ErrInvalidSSHArgs{Msg: "unknown option -- " + arg[1:]}
return
}
continue
}
if expecting_option_val {
if expecting_extra_val != "" {
found_extra_args = append(found_extra_args, expecting_extra_val, argument)
} else {
ssh_args = append(ssh_args, argument)
}
expecting_option_val = false
continue
}
server_args = append(server_args, argument)
}
if len(server_args) == 0 && !passthrough {
err = &ErrInvalidSSHArgs{Msg: ""}
}
return
}
type SSHVersion struct{ Major, Minor int }
func (self SSHVersion) SupportsAskpassRequire() bool {
return self.Major > 8 || (self.Major == 8 && self.Minor >= 4)
}
var GetSSHVersion = (&utils.Once[SSHVersion]{Run: func() SSHVersion {
b, err := exec.Command(SSHExe(), "-V").CombinedOutput()
if err != nil {
return SSHVersion{}
}
m := regexp.MustCompile(`OpenSSH_(\d+).(\d+)`).FindSubmatch(b)
if len(m) == 3 {
maj, _ := strconv.Atoi(utils.UnsafeBytesToString(m[1]))
min, _ := strconv.Atoi(utils.UnsafeBytesToString(m[2]))
return SSHVersion{Major: maj, Minor: min}
}
return SSHVersion{}
}}).Get
type KittyOpts struct {
Term, Shell_integration string
}
func read_relevant_kitty_opts(path string) KittyOpts {
ans := KittyOpts{Term: kitty.KittyConfigDefaults.Term, Shell_integration: kitty.KittyConfigDefaults.Shell_integration}
handle_line := func(key, val string) error {
switch key {
case "term":
ans.Term = strings.TrimSpace(val)
case "shell_integration":
ans.Shell_integration = strings.TrimSpace(val)
}
return nil
}
cp := config.ConfigParser{LineHandler: handle_line}
cp.ParseFiles(path)
return ans
}
var RelevantKittyOpts = (&utils.Once[KittyOpts]{Run: func() KittyOpts {
return read_relevant_kitty_opts(filepath.Join(utils.ConfigDir(), "kitty.conf"))
}}).Get

View File

@ -0,0 +1,68 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"os"
"path/filepath"
"testing"
"kitty/tools/utils/shlex"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestGetSSHOptions(t *testing.T) {
m := SSHOptions()
if m["w"] != "local_tun[:remote_tun]" {
t.Fatalf("Unexpected set of SSH options: %#v", m)
}
}
func TestParseSSHArgs(t *testing.T) {
split := func(x string) []string {
ans, err := shlex.Split(x)
if err != nil {
t.Fatal(err)
}
return ans
}
p := func(args, expected_ssh_args, expected_server_args, expected_extra_args string, expected_passthrough bool) {
ssh_args, server_args, passthrough, extra_args, err := ParseSSHArgs(split(args), "--kitten")
if err != nil {
t.Fatal(err)
}
check := func(a, b any) {
diff := cmp.Diff(a, b)
if diff != "" {
t.Fatalf("Unexpected value for args: %s\n%s", args, diff)
}
}
check(split(expected_ssh_args), ssh_args)
check(split(expected_server_args), server_args)
check(split(expected_extra_args), extra_args)
check(expected_passthrough, passthrough)
}
p(`localhost`, ``, `localhost`, ``, false)
p(`-- localhost`, ``, `localhost`, ``, false)
p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false)
p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false)
p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true)
}
func TestRelevantKittyOpts(t *testing.T) {
tdir := t.TempDir()
path := filepath.Join(tdir, "kitty.conf")
os.WriteFile(path, []byte("term XXX\nshell_integration changed\nterm abcd"), 0o600)
rko := read_relevant_kitty_opts(path)
if rko.Term != "abcd" {
t.Fatalf("Unexpected TERM: %s", RelevantKittyOpts().Term)
}
if rko.Shell_integration != "changed" {
t.Fatalf("Unexpected shell_integration: %s", RelevantKittyOpts().Shell_integration)
}
}

View File

@ -10,6 +10,8 @@ import (
"kitty/tools/cmd/clipboard"
"kitty/tools/cmd/edit_in_kitty"
"kitty/tools/cmd/icat"
"kitty/tools/cmd/pytest"
"kitty/tools/cmd/ssh"
"kitty/tools/cmd/unicode_input"
"kitty/tools/cmd/update_self"
"kitty/tools/tui"
@ -30,8 +32,12 @@ func KittyToolEntryPoints(root *cli.Command) {
clipboard.EntryPoint(root)
// icat
icat.EntryPoint(root)
// ssh
ssh.EntryPoint(root)
// unicode_input
unicode_input.EntryPoint(root)
// __pytest__
pytest.EntryPoint(root)
// __hold_till_enter__
root.AddSubCommand(&cli.Command{
Name: "__hold_till_enter__",

192
tools/config/api.go Normal file
View File

@ -0,0 +1,192 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package config
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"kitty/tools/utils"
"os"
"path/filepath"
"strings"
)
var _ = fmt.Print
func StringToBool(x string) bool {
x = strings.ToLower(x)
return x == "y" || x == "yes" || x == "true"
}
type ConfigLine struct {
Src_file, Line string
Line_number int
Err error
}
type ConfigParser struct {
LineHandler func(key, val string) error
CommentsHandler func(line string) error
SourceHandler func(text, path string)
bad_lines []ConfigLine
seen_includes map[string]bool
override_env []string
}
type Scanner interface {
Scan() bool
Text() string
Err() error
}
func (self *ConfigParser) BadLines() []ConfigLine {
return self.bad_lines
}
func (self *ConfigParser) parse(scanner Scanner, name, base_path_for_includes string, depth int) error {
if self.seen_includes[name] { // avoid include loops
return nil
}
self.seen_includes[name] = true
recurse := func(r io.Reader, nname, base_path_for_includes string) error {
if depth > 32 {
return fmt.Errorf("Too many nested include directives while processing config file: %s", name)
}
escanner := bufio.NewScanner(r)
return self.parse(escanner, nname, base_path_for_includes, depth+1)
}
lnum := 0
make_absolute := func(path string) (string, error) {
if path == "" {
return "", fmt.Errorf("Empty include paths not allowed")
}
if !filepath.IsAbs(path) {
path = filepath.Join(base_path_for_includes, path)
}
return path, nil
}
for scanner.Scan() {
line := strings.TrimLeft(scanner.Text(), " ")
lnum++
if line == "" {
continue
}
if line[0] == '#' {
if self.CommentsHandler != nil {
err := self.CommentsHandler(line)
if err != nil {
self.bad_lines = append(self.bad_lines, ConfigLine{Src_file: name, Line: line, Line_number: lnum, Err: err})
}
}
continue
}
key, val, _ := strings.Cut(line, " ")
switch key {
default:
err := self.LineHandler(key, val)
if err != nil {
self.bad_lines = append(self.bad_lines, ConfigLine{Src_file: name, Line: line, Line_number: lnum, Err: err})
}
case "include", "globinclude", "envinclude":
var includes []string
switch key {
case "include":
aval, err := make_absolute(val)
if err == nil {
includes = []string{aval}
}
case "globinclude":
aval, err := make_absolute(val)
if err == nil {
matches, err := filepath.Glob(aval)
if err == nil {
includes = matches
}
}
case "envinclude":
env := self.override_env
if env == nil {
env = os.Environ()
}
for _, x := range env {
key, eval, _ := strings.Cut(x, "=")
is_match, err := filepath.Match(val, key)
if is_match && err == nil {
err := recurse(strings.NewReader(eval), "<env var: "+key+">", base_path_for_includes)
if err != nil {
return err
}
}
}
}
if len(includes) > 0 {
for _, incpath := range includes {
raw, err := os.ReadFile(incpath)
if err == nil {
err := recurse(bytes.NewReader(raw), incpath, filepath.Dir(incpath))
if err != nil {
return err
}
} else if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("Failed to process include %#v with error: %w", incpath, err)
}
}
}
}
}
return nil
}
func (self *ConfigParser) ParseFiles(paths ...string) error {
for _, path := range paths {
apath, err := filepath.Abs(path)
if err == nil {
path = apath
}
raw, err := os.ReadFile(path)
if err != nil {
return err
}
scanner := bufio.NewScanner(bytes.NewReader(raw))
self.seen_includes = make(map[string]bool)
err = self.parse(scanner, path, filepath.Dir(path), 0)
if err != nil {
return err
}
if self.SourceHandler != nil {
self.SourceHandler(utils.UnsafeBytesToString(raw), path)
}
}
return nil
}
type LinesScanner struct {
lines []string
}
func (self *LinesScanner) Scan() bool {
return len(self.lines) > 0
}
func (self *LinesScanner) Text() string {
ans := self.lines[0]
self.lines = self.lines[1:]
return ans
}
func (self *LinesScanner) Err() error {
return nil
}
func (self *ConfigParser) ParseOverrides(overrides ...string) error {
s := LinesScanner{lines: overrides}
self.seen_includes = make(map[string]bool)
return self.parse(&s, "<overrides>", utils.ConfigDir(), 0)
}

62
tools/config/api_test.go Normal file
View File

@ -0,0 +1,62 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package config
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestConfigParsing(t *testing.T) {
tdir := t.TempDir()
conf_file := filepath.Join(tdir, "a.conf")
os.Mkdir(filepath.Join(tdir, "sub"), 0o700)
os.WriteFile(conf_file, []byte(
`error main
# ignore me
a one
#: other
include sub/b.conf
b
include non-existent
globinclude sub/c?.conf
`), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/b.conf"), []byte("incb cool\ninclude a.conf"), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/c1.conf"), []byte("inc1 cool"), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/c2.conf"), []byte("inc2 cool\nenvinclude ENVINCLUDE"), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/c.conf"), []byte("inc notcool\nerror sub"), 0o600)
var parsed_lines []string
pl := func(key, val string) error {
if key == "error" {
return fmt.Errorf("%s", val)
}
parsed_lines = append(parsed_lines, key+" "+val)
return nil
}
p := ConfigParser{LineHandler: pl, override_env: []string{"ENVINCLUDE=env cool\ninclude c.conf"}}
err := p.ParseFiles(conf_file)
if err != nil {
t.Fatal(err)
}
err = p.ParseOverrides("over one", "over two")
diff := cmp.Diff([]string{"a one", "incb cool", "b ", "inc1 cool", "inc2 cool", "env cool", "inc notcool", "over one", "over two"}, parsed_lines)
if diff != "" {
t.Fatalf("Unexpected parsed config values:\n%s", diff)
}
bad_lines := []string{}
for _, bl := range p.BadLines() {
bad_lines = append(bad_lines, fmt.Sprintf("%s: %d", filepath.Base(bl.Src_file), bl.Line_number))
}
diff = cmp.Diff([]string{"a.conf: 1", "c.conf: 2"}, bad_lines)
if diff != "" {
t.Fatalf("Unexpected bad lines:\n%s", diff)
}
}

430
tools/themes/collection.go Normal file
View File

@ -0,0 +1,430 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package themes
import (
"archive/zip"
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"kitty/tools/config"
"kitty/tools/utils"
"kitty/tools/utils/style"
"net/http"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/exp/maps"
)
var _ = fmt.Print
type JSONMetadata struct {
Etag string `json:"etag"`
Timestamp string `json:"timestamp"`
}
var ErrNoCacheFound = errors.New("No cache found and max cache age is negative")
func fetch_cached(name, url, cache_path string, max_cache_age time.Duration) (string, error) {
cache_path = filepath.Join(cache_path, name+".zip")
zf, err := zip.OpenReader(cache_path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return "", err
}
var jm JSONMetadata
if err == nil {
err = json.Unmarshal(utils.UnsafeStringToBytes(zf.Comment), &jm)
if max_cache_age < 0 {
return cache_path, nil
}
cache_age, err := utils.ISO8601Parse(jm.Timestamp)
if err == nil {
if time.Now().Before(cache_age.Add(max_cache_age)) {
return cache_path, nil
}
}
}
if max_cache_age < 0 {
return "", ErrNoCacheFound
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", err
}
if jm.Etag != "" {
req.Header.Add("If-None-Match", jm.Etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("Failed to download %s with error: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotModified {
return cache_path, nil
}
return "", fmt.Errorf("Failed to download %s with HTTP error: %s", url, resp.Status)
}
var tf, tf2 *os.File
tf, err = os.CreateTemp(filepath.Dir(cache_path), name+".temp-*")
if err == nil {
tf2, err = os.CreateTemp(filepath.Dir(cache_path), name+".temp-*")
}
defer func() {
if tf != nil {
tf.Close()
os.Remove(tf.Name())
tf = nil
}
if tf2 != nil {
tf2.Close()
os.Remove(tf2.Name())
tf2 = nil
}
}()
if err != nil {
return "", fmt.Errorf("Failed to create temp file in %s with error: %w", filepath.Dir(cache_path), err)
}
_, err = io.Copy(tf, resp.Body)
if err != nil {
return "", fmt.Errorf("Failed to download %s with error: %w", url, err)
}
r, err := zip.OpenReader(tf.Name())
if err != nil {
return "", fmt.Errorf("Failed to open downloaded zip file with error: %w", err)
}
w := zip.NewWriter(tf2)
jm.Etag = resp.Header.Get("ETag")
jm.Timestamp = utils.ISO8601Format(time.Now())
comment, _ := json.Marshal(jm)
w.SetComment(utils.UnsafeBytesToString(comment))
for _, file := range r.File {
err = w.Copy(file)
if err != nil {
return "", fmt.Errorf("Failed to copy zip file from source to destination archive")
}
}
err = w.Close()
if err != nil {
return "", err
}
tf2.Close()
err = os.Rename(tf2.Name(), cache_path)
if err != nil {
return "", fmt.Errorf("Failed to atomic rename temp file to %s with error: %w", cache_path, err)
}
tf2 = nil
return cache_path, nil
}
func FetchCached(max_cache_age time.Duration) (string, error) {
return fetch_cached("kitty-themes", "https://codeload.github.com/kovidgoyal/kitty-themes/zip/master", utils.CacheDir(), max_cache_age)
}
type ThemeMetadata struct {
Name string `json:"name"`
Filepath string `json:"file"`
Is_dark bool `json:"is_dark"`
Num_settings int `json:"num_settings"`
Blurb string `json:"blurb"`
License string `json:"license"`
Upstream string `json:"upstream"`
Author string `json:"author"`
}
func parse_theme_metadata(path string) (*ThemeMetadata, map[string]string, error) {
var in_metadata, in_blurb, finished_metadata bool
ans := ThemeMetadata{}
settings := map[string]string{}
read_is_dark := func(key, val string) (err error) {
settings[key] = val
if key == "background" {
val = strings.TrimSpace(val)
if val != "" {
bg, err := style.ParseColor(val)
if err == nil {
ans.Is_dark = utils.Max(bg.Red, bg.Green, bg.Green) < 115
}
}
}
return
}
read_metadata := func(line string) (err error) {
is_block := strings.HasPrefix(line, "## ")
if in_metadata && !is_block {
finished_metadata = true
}
if finished_metadata {
return
}
if !in_metadata && is_block {
in_metadata = true
}
if !in_metadata {
return
}
line = line[3:]
if in_blurb {
ans.Blurb += " " + line
return
}
key, val, found := strings.Cut(line, ":")
if !found {
return
}
key = strings.TrimSpace(strings.ToLower(key))
val = strings.TrimSpace(val)
switch key {
case "name":
ans.Name = val
case "author":
ans.Author = val
case "upstream":
ans.Upstream = val
case "blurb":
ans.Blurb = val
in_blurb = true
case "license":
ans.License = val
}
return
}
cp := config.ConfigParser{LineHandler: read_is_dark, CommentsHandler: read_metadata}
err := cp.ParseFiles(path)
if err != nil {
return nil, nil, err
}
ans.Num_settings = len(settings)
return &ans, settings, nil
}
type Theme struct {
metadata *ThemeMetadata
code string
settings map[string]string
zip_reader *zip.File
is_user_defined bool
}
func (self *Theme) load_code() (string, error) {
if self.zip_reader != nil {
f, err := self.zip_reader.Open()
self.zip_reader = nil
if err != nil {
return "", err
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return "", err
}
self.code = utils.UnsafeBytesToString(data)
}
return self.code, nil
}
func (self *Theme) Settings() (map[string]string, error) {
if self.zip_reader != nil {
code, err := self.load_code()
if err != nil {
return nil, err
}
self.settings = make(map[string]string, 64)
scanner := bufio.NewScanner(strings.NewReader(code))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line != "" && line[0] != '#' {
key, val, found := strings.Cut(line, " ")
if found {
self.settings[key] = val
}
}
}
}
return self.settings, nil
}
func (self *Theme) AsEscapeCodes() (string, error) {
settings, err := self.Settings()
if err != nil {
return "", err
}
w := strings.Builder{}
w.Grow(4096)
set_color := func(i int, sharp string) {
w.WriteByte(';')
w.WriteString(strconv.Itoa(i))
w.WriteByte(';')
w.WriteString(sharp)
}
set_default_color := func(name, defval string, num int) {
w.WriteString("\033]")
defer func() { w.WriteString("\033\\") }()
val, found := settings[name]
if !found {
val = defval
}
if val != "" {
rgba, err := style.ParseColor(val)
if err == nil {
w.WriteString(strconv.Itoa(num))
w.WriteByte(';')
w.WriteString(rgba.AsRGBSharp())
return
}
}
w.WriteByte('1')
w.WriteString(strconv.Itoa(num))
}
set_default_color("foreground", style.DefaultColors.Foreground, 10)
set_default_color("background", style.DefaultColors.Background, 11)
set_default_color("cursor", style.DefaultColors.Cursor, 12)
set_default_color("selection_background", style.DefaultColors.SelectionBg, 17)
set_default_color("selection_foreground", style.DefaultColors.SelectionFg, 19)
w.WriteString("\033]4")
for i := 0; i < 256; i++ {
key := "color" + strconv.Itoa(i)
val := settings[key]
if val != "" {
rgba, err := style.ParseColor(val)
if err == nil {
set_color(i, rgba.AsRGBSharp())
continue
}
}
rgba := style.RGBA{}
rgba.FromRGB(style.ColorTable[i])
set_color(i, rgba.AsRGBSharp())
}
w.WriteString("\033\\")
return w.String(), nil
}
type Themes struct {
name_map map[string]*Theme
index_map []string
}
var camel_case_pat = (&utils.Once[*regexp.Regexp]{Run: func() *regexp.Regexp {
return regexp.MustCompile(`([a-z])([A-Z])`)
}}).Get
func theme_name_from_file_name(fname string) string {
fname = fname[:len(fname)-len(path.Ext(fname))]
fname = strings.ReplaceAll(fname, "_", " ")
fname = camel_case_pat().ReplaceAllString(fname, "$1 $2")
return strings.Join(utils.Map(strings.Split(fname, " "), strings.Title), " ")
}
func (self *Themes) AddFromFile(path string) (*Theme, error) {
m, conf, err := parse_theme_metadata(path)
if err != nil {
return nil, err
}
if m.Name == "" {
m.Name = theme_name_from_file_name(filepath.Base(path))
}
t := Theme{metadata: m, is_user_defined: true, settings: conf}
self.name_map[m.Name] = &t
return &t, nil
}
func (self *Themes) add_from_dir(dirpath string) error {
entries, err := os.ReadDir(dirpath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
err = nil
}
return err
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(e.Name(), ".conf") {
if _, err = self.AddFromFile(filepath.Join(dirpath, e.Name())); err != nil {
return err
}
}
}
return nil
}
func (self *Themes) add_from_zip_file(zippath string) (io.Closer, error) {
r, err := zip.OpenReader(zippath)
if err != nil {
return nil, err
}
name_map := make(map[string]*zip.File, len(r.File))
var themes []ThemeMetadata
theme_dir := ""
for _, file := range r.File {
name_map[file.Name] = file
if path.Base(file.Name) == "themes.json" {
theme_dir = path.Dir(file.Name)
fr, err := file.Open()
if err != nil {
return nil, fmt.Errorf("Error while opening %s from the ZIP file: %w", file.Name, err)
}
defer fr.Close()
raw, err := io.ReadAll(fr)
if err != nil {
return nil, fmt.Errorf("Error while reading %s from the ZIP file: %w", file.Name, err)
}
err = json.Unmarshal(raw, &themes)
if err != nil {
return nil, fmt.Errorf("Error while decoding %s: %w", file.Name, err)
}
}
}
if theme_dir == "" {
return nil, fmt.Errorf("No themes.json found in ZIP file")
}
for _, theme := range themes {
key := path.Join(theme_dir, theme.Filepath)
f := name_map[key]
if f != nil {
t := Theme{metadata: &theme, zip_reader: f}
self.name_map[theme.Name] = &t
}
}
return r, nil
}
func (self *Themes) ThemeByName(name string) *Theme {
return self.name_map[name]
}
func LoadThemes(cache_age time.Duration) (ans *Themes, closer io.Closer, err error) {
zip_path, err := FetchCached(cache_age)
ans = &Themes{name_map: make(map[string]*Theme)}
if err != nil {
return nil, nil, err
}
if closer, err = ans.add_from_zip_file(zip_path); err != nil {
return nil, nil, err
}
if err = ans.add_from_dir(filepath.Join(utils.ConfigDir(), "themes")); err != nil {
return nil, nil, err
}
ans.index_map = maps.Keys(ans.name_map)
ans.index_map = utils.StableSortWithKey(ans.index_map, strings.ToLower)
return ans, closer, nil
}
func ThemeFromFile(path string) (*Theme, error) {
ans := &Themes{name_map: make(map[string]*Theme)}
return ans.AddFromFile(path)
}

View File

@ -0,0 +1,157 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package themes
import (
"archive/zip"
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestThemeCollections(t *testing.T) {
for fname, expected := range map[string]string{
"moose": "Moose",
"mooseCat": "Moose Cat",
"a_bC": "A B C",
} {
actual := theme_name_from_file_name(fname)
if diff := cmp.Diff(expected, actual); diff != "" {
t.Fatalf("Unexpected theme name for %s:\n%s", fname, diff)
}
}
tdir := t.TempDir()
pt := func(expected ThemeMetadata, lines ...string) {
os.WriteFile(filepath.Join(tdir, "temp.conf"), []byte(strings.Join(lines, "\n")), 0o600)
actual, _, err := parse_theme_metadata(filepath.Join(tdir, "temp.conf"))
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(&expected, actual); diff != "" {
t.Fatalf("Failed to parse:\n%s\n\n%s", strings.Join(lines, "\n"), diff)
}
}
pt(ThemeMetadata{Name: "XYZ", Blurb: "a b", Author: "A", Is_dark: true, Num_settings: 2},
"# some crap", " ", "## ", "## author: A", "## name: XYZ", "## blurb: a", "## b", "", "color red", "background black", "include inc.conf")
os.WriteFile(filepath.Join(tdir, "inc.conf"), []byte("background white"), 0o600)
pt(ThemeMetadata{Name: "XYZ", Blurb: "a b", Author: "A", Num_settings: 2},
"# some crap", " ", "## ", "## author: A", "## name: XYZ", "## blurb: a", "## b", "", "color red", "background black", "include inc.conf")
buf := bytes.Buffer{}
zw := zip.NewWriter(&buf)
fw, _ := zw.Create("x/themes.json")
fw.Write([]byte(`[
{
"author": "X Y",
"blurb": "A dark color scheme for the kitty terminal.",
"file": "themes/Alabaster_Dark.conf",
"is_dark": true,
"license": "MIT",
"name": "Alabaster Dark",
"num_settings": 30,
"upstream": "https://xxx.com"
},
{
"name": "Empty", "file": "empty.conf"
}
]`))
fw, _ = zw.Create("x/empty.conf")
fw.Write([]byte("empty"))
fw, _ = zw.Create("x/themes/Alabaster_Dark.conf")
fw.Write([]byte("alabaster"))
zw.Close()
received_etag := ""
request_count := 0
check_etag := true
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request_count++
received_etag = r.Header.Get("If-None-Match")
if check_etag && received_etag == `"xxx"` {
w.WriteHeader(http.StatusNotModified)
return
}
w.Header().Add("ETag", `"xxx"`)
w.Write(buf.Bytes())
}))
defer ts.Close()
_, err := fetch_cached("test", ts.URL, tdir, 0)
if err != nil {
t.Fatal(err)
}
r, err := zip.OpenReader(filepath.Join(tdir, "test.zip"))
if err != nil {
t.Fatal(err)
}
var jm JSONMetadata
err = json.Unmarshal([]byte(r.Comment), &jm)
if err != nil {
t.Fatal(err)
}
if jm.Etag != `"xxx"` {
t.Fatalf("Unexpected ETag: %#v", jm.Etag)
}
_, err = fetch_cached("test", ts.URL, tdir, time.Hour)
if err != nil {
t.Fatal(err)
}
if request_count != 1 {
t.Fatal("Cached zip file was not used")
}
before, _ := os.Stat(filepath.Join(tdir, "test.zip"))
_, err = fetch_cached("test", ts.URL, tdir, 0)
if err != nil {
t.Fatal(err)
}
if request_count != 2 {
t.Fatal("Cached zip file was incorrectly used")
}
if received_etag != `"xxx"` {
t.Fatalf("Got invalid ETag: %#v", received_etag)
}
after, _ := os.Stat(filepath.Join(tdir, "test.zip"))
if before.ModTime() != after.ModTime() {
t.Fatal("Cached zip file was incorrectly re-downloaded")
}
check_etag = false
_, err = fetch_cached("test", ts.URL, tdir, 0)
if err != nil {
t.Fatal(err)
}
after2, _ := os.Stat(filepath.Join(tdir, "test.zip"))
if after2.ModTime() != after.ModTime() {
t.Fatal("Cached zip file was incorrectly not re-downloaded")
}
coll := Themes{name_map: map[string]*Theme{}}
closer, err := coll.add_from_zip_file(filepath.Join(tdir, "test.zip"))
if err != nil {
t.Fatal(err)
}
defer closer.Close()
if code, err := coll.ThemeByName("Empty").load_code(); code != "empty" {
if err != nil {
t.Fatal(err)
}
t.Fatal("failed to load code for empty theme")
}
if code, err := coll.ThemeByName("Alabaster Dark").load_code(); code != "alabaster" {
if err != nil {
t.Fatal(err)
}
t.Fatal("failed to load code for alabaster theme")
}
}

View File

@ -148,6 +148,13 @@ func (self *Term) Close() error {
return err
}
func (self *Term) WasEchoOnOriginally() bool {
if len(self.states) > 0 {
return self.states[0].Lflag&unix.ECHO != 0
}
return false
}
func (self *Term) Tcgetattr(ans *unix.Termios) error {
return eintr_retry_noret(func() error { return Tcgetattr(self.Fd(), ans) })
}
@ -256,13 +263,19 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error)
}
num_ready, err := pselect()
if err != nil {
return
return 0, err
}
if num_ready == 0 {
err = os.ErrDeadlineExceeded
return
return 0, err
}
for {
n, err = self.Read(b)
if errors.Is(err, unix.EINTR) {
continue
}
return n, err
}
return self.Read(b)
}
func (self *Term) Read(b []byte) (int, error) {
@ -273,10 +286,14 @@ func (self *Term) Write(b []byte) (int, error) {
return self.os_file.Write(b)
}
func is_temporary_error(err error) bool {
return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite)
}
func (self *Term) WriteAll(b []byte) error {
for len(b) > 0 {
n, err := self.os_file.Write(b)
if err != nil && !errors.Is(err, io.ErrShortWrite) {
if err != nil && !is_temporary_error(err) {
return err
}
b = b[n:]
@ -284,6 +301,10 @@ func (self *Term) WriteAll(b []byte) error {
return nil
}
func (self *Term) WriteAllString(s string) error {
return self.WriteAll(utils.UnsafeStringToBytes(s))
}
func (self *Term) WriteString(b string) (int, error) {
return self.os_file.WriteString(b)
}

28
tools/tui/dcs_to_kitty.go Normal file
View File

@ -0,0 +1,28 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package tui
import (
"encoding/base64"
"fmt"
"kitty/tools/utils"
)
var _ = fmt.Print
func DCSToKitty(msgtype, payload string) (string, error) {
data := base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(payload))
ans := "\x1bP@kitty-" + msgtype + "|" + data
tmux := TmuxSocketAddress()
if tmux != "" {
err := TmuxAllowPassthrough()
if err != nil {
return "", err
}
ans = "\033Ptmux;\033" + ans + "\033\033\\\033\\"
} else {
ans += "\033\\"
}
return ans, nil
}

56
tools/tui/tmux.go Normal file
View File

@ -0,0 +1,56 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package tui
import (
"fmt"
"kitty/tools/utils"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"github.com/shirou/gopsutil/v3/process"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
func tmux_socket_address() (socket string) {
socket = os.Getenv("TMUX")
if socket == "" {
return ""
}
addr, pid_str, found := strings.Cut(socket, ",")
if !found {
return ""
}
if unix.Access(addr, unix.R_OK|unix.W_OK) != nil {
return ""
}
pid, err := strconv.ParseInt(pid_str, 10, 32)
if err != nil {
return ""
}
p, err := process.NewProcess(int32(pid))
if err != nil {
return ""
}
cmd, err := p.CmdlineSlice()
if err != nil {
return ""
}
if len(cmd) > 0 && strings.ToLower(filepath.Base(cmd[0])) != "tmux" {
return ""
}
return socket
}
var TmuxSocketAddress = (&utils.Once[string]{Run: tmux_socket_address}).Get
func tmux_allow_passthrough() error {
return exec.Command("tmux", "set", "-p", "allow-passthrough", "on").Run()
}
var TmuxAllowPassthrough = (&utils.Once[error]{Run: tmux_allow_passthrough}).Get

View File

@ -4,11 +4,9 @@ package unicode_names
import (
"bytes"
"compress/zlib"
_ "embed"
"encoding/binary"
"fmt"
"io"
"strings"
"sync"
"time"
@ -64,33 +62,8 @@ func parse_record(record []byte, mark uint16) {
var parse_once sync.Once
func read_all(r io.Reader, expected_size int) ([]byte, error) {
b := make([]byte, 0, expected_size)
for {
if len(b) == cap(b) {
// Add more capacity (let append pick how much).
b = append(b, 0)[:len(b)]
}
n, err := r.Read(b[len(b):cap(b)])
b = b[:len(b)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return b, err
}
}
}
func parse_data() {
compressed := utils.UnsafeStringToBytes(unicode_name_data)
uncompressed_size := binary.LittleEndian.Uint32(compressed)
r, _ := zlib.NewReader(bytes.NewReader(compressed[4:]))
defer r.Close()
raw, err := read_all(r, int(uncompressed_size))
if err != nil {
panic(err)
}
raw := utils.ReadCompressedEmbeddedData(unicode_name_data)
num_of_lines := binary.LittleEndian.Uint32(raw)
raw = raw[4:]
num_of_words := binary.LittleEndian.Uint32(raw)

View File

@ -12,6 +12,33 @@ import (
var _ = fmt.Print
func AtomicCreateSymlink(oldname, newname string) (err error) {
err = os.Symlink(oldname, newname)
if err == nil {
return nil
}
if !errors.Is(err, fs.ErrExist) {
return err
}
if et, err := os.Readlink(newname); err == nil && et == oldname {
return nil
}
for {
tempname := newname + RandomFilename()
err = os.Symlink(oldname, tempname)
if err == nil {
err = os.Rename(tempname, newname)
if err != nil {
os.Remove(tempname)
}
return err
}
if !errors.Is(err, fs.ErrExist) {
return err
}
}
}
func AtomicWriteFile(path string, data []byte, perm os.FileMode) (err error) {
npath, err := filepath.EvalSymlinks(path)
if errors.Is(err, fs.ErrNotExist) {

46
tools/utils/embed.go Normal file
View File

@ -0,0 +1,46 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"bytes"
"compress/bzip2"
"encoding/binary"
"fmt"
"io"
)
var _ = fmt.Print
func ReadAll(r io.Reader, expected_size int) ([]byte, error) {
b := make([]byte, 0, expected_size)
for {
if len(b) == cap(b) {
// Add more capacity (let append pick how much).
b = append(b, 0)[:len(b)]
}
n, err := r.Read(b[len(b):cap(b)])
b = b[:len(b)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return b, err
}
}
}
func ReadCompressedEmbeddedData(raw string) []byte {
compressed := UnsafeStringToBytes(raw)
uncompressed_size := binary.LittleEndian.Uint32(compressed)
r := bzip2.NewReader(bytes.NewReader(compressed[4:]))
ans, err := ReadAll(r, int(uncompressed_size))
if err != nil {
panic(err)
}
return ans
}
func ReaderForCompressedEmbeddedData(raw string) io.Reader {
return bzip2.NewReader(bytes.NewReader(UnsafeStringToBytes(raw)[4:]))
}

166
tools/utils/iso8601.go Normal file
View File

@ -0,0 +1,166 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"strconv"
"strings"
"time"
)
var _ = fmt.Print
func is_digit(x byte) bool {
return '0' <= x && x <= '9'
}
// The following is copied from the Go standard library to implement date range validation logic
// equivalent to the behaviour of Go's time.Parse.
func isLeap(year int) bool {
return year%4 == 0 && (year%100 != 0 || year%400 == 0)
}
// daysInMonth is the number of days for non-leap years in each calendar month starting at 1
var daysInMonth = [13]int{0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}
func daysIn(m time.Month, year int) int {
if m == time.February && isLeap(year) {
return 29
}
return daysInMonth[int(m)]
}
func ISO8601Parse(raw string) (time.Time, error) {
orig := raw
raw = strings.TrimSpace(raw)
required_number := func(num_digits int) (int, error) {
if len(raw) < num_digits {
return 0, fmt.Errorf("Insufficient digits")
}
text := raw[:num_digits]
raw = raw[num_digits:]
ans, err := strconv.ParseUint(text, 10, 32)
return int(ans), err
}
optional_separator := func(x byte) bool {
if len(raw) > 0 && raw[0] == x {
raw = raw[1:]
}
return len(raw) > 0 && is_digit(raw[0])
}
errf := func(msg string) (time.Time, error) {
return time.Time{}, fmt.Errorf("Invalid ISO8601 timestamp: %#v. %s", orig, msg)
}
optional_separator('+')
year, err := required_number(4)
if err != nil {
return errf("timestamp does not start with a 4 digit year")
}
var month int = 1
var day int = 1
if optional_separator('-') {
month, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit month")
}
if optional_separator('-') {
day, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit day")
}
}
}
var hour, minute, second, nsec int
if len(raw) > 0 && (raw[0] == 'T' || raw[0] == ' ') {
raw = raw[1:]
hour, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit hour")
}
if optional_separator(':') {
minute, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit minute")
}
if optional_separator(':') {
second, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit second")
}
}
}
if len(raw) > 0 && (raw[0] == '.' || raw[0] == ',') {
raw = raw[1:]
num_digits := 0
for len(raw) > num_digits && is_digit(raw[num_digits]) {
num_digits++
}
text := raw[:num_digits]
raw = raw[num_digits:]
extra := 9 - len(text)
if extra < 0 {
text = text[:9]
}
if text != "" {
n, err := strconv.ParseUint(text, 10, 64)
if err != nil {
return errf("timestamp does not have a valid nanosecond field")
}
nsec = int(n)
for ; extra > 0; extra-- {
nsec *= 10
}
}
}
}
switch {
case month < 1 || month > 12:
return errf("timestamp has invalid month value")
case day < 1 || day > 31 || day > daysIn(time.Month(month), year):
return errf("timestamp has invalid day value")
case hour < 0 || hour > 23:
return errf("timestamp has invalid hour value")
case minute < 0 || minute > 59:
return errf("timestamp has invalid minute value")
case second < 0 || second > 59:
return errf("timestamp has invalid second value")
}
loc := time.UTC
tzsign, tzhour, tzminute := 0, 0, 0
if len(raw) > 0 {
switch raw[0] {
case '+':
tzsign = 1
case '-':
tzsign = -1
}
}
if tzsign != 0 {
raw = raw[1:]
tzhour, err = required_number(2)
if err != nil {
return errf("timestamp has invalid timezone hour")
}
optional_separator(':')
tzminute, err = required_number(2)
if err != nil {
tzminute = 0
}
seconds := tzhour*3600 + tzminute*60
loc = time.FixedZone("", tzsign*seconds)
}
return time.Date(year, time.Month(month), day, hour, minute, second, nsec, loc), err
}
func ISO8601Format(x time.Time) string {
return x.Format(time.RFC3339Nano)
}

View File

@ -0,0 +1,40 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestISO8601(t *testing.T) {
now := time.Now()
tt := func(raw string, expected time.Time) {
actual, err := ISO8601Parse(raw)
if err != nil {
t.Fatalf("Parsing: %#v failed with error: %s", raw, err)
}
if diff := cmp.Diff(expected, actual); diff != "" {
t.Fatalf("Parsing: %#v failed:\n%s", raw, diff)
}
}
tt(ISO8601Format(now), now)
tt("2023-02-08T07:24:09.551975+00:00", time.Date(2023, 2, 8, 7, 24, 9, 551975000, time.UTC))
tt("2023-02-08T07:24:09.551975Z", time.Date(2023, 2, 8, 7, 24, 9, 551975000, time.UTC))
tt("2023", time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
tt("2023-11-13", time.Date(2023, 11, 13, 0, 0, 0, 0, time.UTC))
tt("2023-11-13 07:23", time.Date(2023, 11, 13, 7, 23, 0, 0, time.UTC))
tt("2023-11-13 07:23:01", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC))
tt("2023-11-13 07:23:01.", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC))
tt("2023-11-13 07:23:01.0", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC))
tt("2023-11-13 07:23:01.1", time.Date(2023, 11, 13, 7, 23, 1, 100000000, time.UTC))
tt("202311-13 07", time.Date(2023, 11, 13, 7, 0, 0, 0, time.UTC))
tt("20231113 0705", time.Date(2023, 11, 13, 7, 5, 0, 0, time.UTC))
}

View File

@ -11,12 +11,9 @@ import (
"os"
"path/filepath"
"strings"
"sync"
)
var _ = fmt.Print
var user_mime_only_once sync.Once
var user_defined_mime_map map[string]string
func load_mime_file(filename string, mime_map map[string]string) error {
f, err := os.Open(filename)
@ -45,18 +42,19 @@ func load_mime_file(filename string, mime_map map[string]string) error {
return nil
}
func load_user_mime_maps() {
var UserMimeMap = (&Once[map[string]string]{Run: func() map[string]string {
conf_path := filepath.Join(ConfigDir(), "mime.types")
err := load_mime_file(conf_path, user_defined_mime_map)
ans := make(map[string]string, 32)
err := load_mime_file(conf_path, ans)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
fmt.Fprintln(os.Stderr, "Failed to parse", conf_path, "for MIME types with error:", err)
}
}
return ans
}}).Get
func GuessMimeType(filename string) string {
user_mime_only_once.Do(load_user_mime_maps)
ext := filepath.Ext(filename)
mime_with_parameters := user_defined_mime_map[ext]
mime_with_parameters := UserMimeMap()[ext]
if mime_with_parameters == "" {
mime_with_parameters = mime.TypeByExtension(ext)
}

View File

@ -57,6 +57,14 @@ func Filter[T any](s []T, f func(x T) bool) []T {
return ans
}
func Map[T any](s []T, f func(x T) T) []T {
ans := make([]T, 0, len(s))
for _, x := range s {
ans = append(ans, f(x))
}
return ans
}
func Sort[T any](s []T, less func(a, b T) bool) []T {
sort.Slice(s, func(i, j int) bool { return less(s[i], s[j]) })
return s

35
tools/utils/once.go Normal file
View File

@ -0,0 +1,35 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"sync"
"sync/atomic"
)
var _ = fmt.Print
type Once[T any] struct {
done uint32
mutex sync.Mutex
cached_val T
Run func() T
}
func (self *Once[T]) Get() T {
if atomic.LoadUint32(&self.done) == 0 {
self.do_slow()
}
return self.cached_val
}
func (self *Once[T]) do_slow() {
self.mutex.Lock()
defer self.mutex.Unlock()
if atomic.LoadUint32(&self.done) == 0 {
defer atomic.StoreUint32(&self.done, 1)
self.cached_val = self.Run()
}
}

24
tools/utils/once_test.go Normal file
View File

@ -0,0 +1,24 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"testing"
)
var _ = fmt.Print
func TestOnce(t *testing.T) {
num := 0
var G = (&Once[string]{Run: func() string {
num++
return fmt.Sprintf("%d", num)
}}).Get
G()
G()
G()
if num != 1 {
t.Fatalf("num unexpectedly: %d", num)
}
}

View File

@ -3,13 +3,18 @@
package utils
import (
"crypto/rand"
"encoding/base32"
"fmt"
"io/fs"
not_rand "math/rand"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"golang.org/x/sys/unix"
)
@ -57,61 +62,57 @@ func Abspath(path string) string {
return path
}
var config_dir, kitty_exe, cache_dir string
var kitty_exe_err error
var config_dir_once, kitty_exe_once, cache_dir_once sync.Once
func find_kitty_exe() {
var KittyExe = (&Once[string]{Run: func() string {
exe, err := os.Executable()
if err == nil {
kitty_exe = filepath.Join(filepath.Dir(exe), "kitty")
kitty_exe_err = unix.Access(kitty_exe, unix.X_OK)
} else {
kitty_exe_err = err
return filepath.Join(filepath.Dir(exe), "kitty")
}
}
return ""
}}).Get
func KittyExe() (string, error) {
kitty_exe_once.Do(find_kitty_exe)
return kitty_exe, kitty_exe_err
}
func find_config_dir() {
if os.Getenv("KITTY_CONFIG_DIRECTORY") != "" {
config_dir = Abspath(Expanduser(os.Getenv("KITTY_CONFIG_DIRECTORY")))
} else {
var locations []string
if os.Getenv("XDG_CONFIG_HOME") != "" {
locations = append(locations, os.Getenv("XDG_CACHE_HOME"))
var ConfigDir = (&Once[string]{Run: func() (config_dir string) {
if kcd := os.Getenv("KITTY_CONFIG_DIRECTORY"); kcd != "" {
return Abspath(Expanduser(kcd))
}
var locations []string
seen := NewSet[string]()
add := func(x string) {
x = Abspath(Expanduser(x))
if !seen.Has(x) {
seen.Add(x)
locations = append(locations, x)
}
locations = append(locations, Expanduser("~/.config"))
if runtime.GOOS == "darwin" {
locations = append(locations, Expanduser("~/Library/Preferences"))
}
if xh := os.Getenv("XDG_CONFIG_HOME"); xh != "" {
add(xh)
}
if dirs := os.Getenv("XDG_CONFIG_DIRS"); dirs != "" {
for _, candidate := range strings.Split(dirs, ":") {
add(candidate)
}
for _, loc := range locations {
if loc != "" {
q := filepath.Join(loc, "kitty")
if _, err := os.Stat(filepath.Join(q, "kitty.conf")); err == nil {
config_dir = q
break
}
}
}
for _, loc := range locations {
if loc != "" {
config_dir = filepath.Join(loc, "kitty")
break
}
add("~/.config")
if runtime.GOOS == "darwin" {
add("~/Library/Preferences")
}
for _, loc := range locations {
if loc != "" {
q := filepath.Join(loc, "kitty")
if _, err := os.Stat(filepath.Join(q, "kitty.conf")); err == nil {
config_dir = q
return
}
}
}
}
config_dir = os.Getenv("XDG_CONFIG_HOME")
if config_dir == "" {
config_dir = "~/.config"
}
config_dir = filepath.Join(Expanduser(config_dir), "kitty")
return
}}).Get
func ConfigDir() string {
config_dir_once.Do(find_config_dir)
return config_dir
}
func find_cache_dir() {
var CacheDir = (&Once[string]{Run: func() (cache_dir string) {
candidate := ""
if edir := os.Getenv("KITTY_CACHE_DIRECTORY"); edir != "" {
candidate = Abspath(Expanduser(edir))
@ -125,13 +126,71 @@ func find_cache_dir() {
candidate = filepath.Join(Expanduser(candidate), "kitty")
}
os.MkdirAll(candidate, 0o755)
cache_dir = candidate
return candidate
}}).Get
func macos_user_cache_dir() string {
// Sadly Go does not provide confstr() so we use this hack.
// Note that given a user generateduid and uid we can derive this by using
// the algorithm at https://github.com/ydkhatri/MacForensics/blob/master/darwin_path_generator.py
// but I cant find a good way to get the generateduid. Requires calling dscl in which case we might as well call getconf
// The data is in /var/db/dslocal/nodes/Default/users/<username>.plist but it needs root
// So instead we use various hacks to get it quickly, falling back to running /usr/bin/getconf
is_ok := func(m string) bool {
s, err := os.Stat(m)
if err != nil {
return false
}
stat, ok := s.Sys().(unix.Stat_t)
return ok && s.IsDir() && int(stat.Uid) == os.Geteuid() && s.Mode().Perm() == 0o700 && unix.Access(m, unix.X_OK|unix.W_OK|unix.R_OK) == nil
}
if tdir := strings.TrimRight(os.Getenv("TMPDIR"), "/"); filepath.Base(tdir) == "T" {
if m := filepath.Join(filepath.Dir(tdir), "C"); is_ok(m) {
return m
}
}
matches, err := filepath.Glob("/private/var/folders/*/*/C")
if err == nil {
for _, m := range matches {
if is_ok(m) {
return m
}
}
}
out, err := exec.Command("/usr/bin/getconf", "DARWIN_USER_CACHE_DIR").Output()
if err == nil {
return strings.TrimRight(strings.TrimSpace(UnsafeBytesToString(out)), "/")
}
return ""
}
func CacheDir() string {
cache_dir_once.Do(find_cache_dir)
return cache_dir
}
var RuntimeDir = (&Once[string]{Run: func() (runtime_dir string) {
var candidate string
if q := os.Getenv("KITTY_RUNTIME_DIRECTORY"); q != "" {
candidate = q
} else if runtime.GOOS == "darwin" {
candidate = macos_user_cache_dir()
} else if q := os.Getenv("XDG_RUNTIME_DIR"); q != "" {
candidate = q
}
candidate = strings.TrimRight(candidate, "/")
if candidate == "" {
q := fmt.Sprintf("/run/user/%d", os.Geteuid())
if s, err := os.Stat(q); err == nil && s.IsDir() && unix.Access(q, unix.X_OK|unix.R_OK|unix.W_OK) == nil {
candidate = q
} else {
candidate = filepath.Join(CacheDir(), "run")
}
}
os.MkdirAll(candidate, 0o700)
if s, err := os.Stat(candidate); err == nil && s.Mode().Perm() != 0o700 {
os.Chmod(candidate, 0o700)
}
return candidate
}}).Get
type Walk_callback func(path, abspath string, d fs.DirEntry, err error) error
@ -205,3 +264,21 @@ func WalkWithSymlink(dirpath string, callback Walk_callback, transformers ...fun
seen: make(map[string]bool), real_callback: callback, transform_func: transform, needs_recurse_func: needs_symlink_recurse}
return sw.walk(dirpath)
}
func RandomFilename() string {
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
_, err := rand.Read(b)
if err != nil {
return strconv.FormatUint(uint64(not_rand.Uint32()), 16)
}
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b)
}
func ResolveConfPath(path string) string {
cs := os.ExpandEnv(Expanduser(path))
if !filepath.IsAbs(cs) {
cs = filepath.Join(ConfigDir(), cs)
}
return cs
}

View File

@ -0,0 +1,65 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package paths
import (
"fmt"
"os"
"path/filepath"
"strings"
"kitty/tools/utils"
)
var _ = fmt.Print
type Ctx struct {
home, cwd string
}
func (ctx *Ctx) SetHome(val string) {
ctx.home = val
}
func (ctx *Ctx) SetCwd(val string) {
ctx.cwd = val
}
func (ctx *Ctx) HomePath() (ans string) {
ans = ctx.home
if ans == "" {
ans = utils.Expanduser("~")
}
return
}
func (ctx *Ctx) CwdPath() (ans string) {
ans = ctx.cwd
if ans == "" {
var err error
ans, err = os.Getwd()
if err != nil {
ans = "."
}
}
return
}
func abspath(path, base string) (ans string) {
return filepath.Join(base, path)
}
func (ctx *Ctx) Abspath(path string) (ans string) {
return abspath(path, ctx.CwdPath())
}
func (ctx *Ctx) AbspathFromHome(path string) (ans string) {
return abspath(path, ctx.HomePath())
}
func (ctx *Ctx) ExpandHome(path string) (ans string) {
if strings.HasPrefix(path, "~/") {
return ctx.AbspathFromHome(path)
}
return path
}

View File

@ -0,0 +1,42 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package secrets
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
)
var _ = fmt.Print
const DEFAULT_NUM_OF_BYTES_FOR_TOKEN = 32
func TokenBytes(nbytes ...int) ([]byte, error) {
if len(nbytes) == 0 {
nbytes = []int{DEFAULT_NUM_OF_BYTES_FOR_TOKEN}
}
buf := make([]byte, nbytes[0])
_, err := rand.Read(buf)
if err != nil {
return nil, err
}
return buf, nil
}
func TokenHex(nbytes ...int) (string, error) {
b, err := TokenBytes(nbytes...)
if err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
func TokenBase64(nbytes ...int) (string, error) {
b, err := TokenBytes(nbytes...)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(b), nil
}

View File

@ -4,6 +4,8 @@ package utils
import (
"fmt"
"golang.org/x/exp/maps"
)
var _ = fmt.Print
@ -22,6 +24,10 @@ func (self *Set[T]) AddItems(val ...T) {
}
}
func (self *Set[T]) String() string {
return fmt.Sprintf("%#v", maps.Keys(self.items))
}
func (self *Set[T]) Remove(val T) {
delete(self.items, val)
}
@ -68,6 +74,25 @@ func (self *Set[T]) Intersect(other *Set[T]) (ans *Set[T]) {
return
}
func (self *Set[T]) Subtract(other *Set[T]) (ans *Set[T]) {
ans = NewSet[T](self.Len())
for x := range self.items {
if !other.Has(x) {
ans.items[x] = struct{}{}
}
}
return ans
}
func (self *Set[T]) IsSubsetOf(other *Set[T]) bool {
for x := range self.items {
if !other.Has(x) {
return false
}
}
return true
}
func NewSet[T comparable](capacity ...int) (ans *Set[T]) {
if len(capacity) == 0 {
ans = &Set[T]{items: make(map[T]struct{}, 8)}
@ -76,3 +101,9 @@ func NewSet[T comparable](capacity ...int) (ans *Set[T]) {
}
return
}
func NewSetWithItems[T comparable](items ...T) (ans *Set[T]) {
ans = NewSet[T](len(items))
ans.AddItems(items...)
return ans
}

View File

@ -3,15 +3,16 @@
package shm
import (
"crypto/rand"
"encoding/base32"
"encoding/binary"
"errors"
"fmt"
not_rand "math/rand"
"io"
"io/fs"
"os"
"strconv"
"strings"
"kitty/tools/cli"
"golang.org/x/sys/unix"
)
@ -43,15 +44,6 @@ func prefix_and_suffix(pattern string) (prefix, suffix string, err error) {
return prefix, suffix, nil
}
func next_random() string {
b := make([]byte, 8)
_, err := rand.Read(b)
if err != nil {
return strconv.FormatUint(uint64(not_rand.Uint32()), 16)
}
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b)
}
type MMap interface {
Close() error
Unlink() error
@ -59,6 +51,13 @@ type MMap interface {
Name() string
IsFileSystemBacked() bool
FileSystemName() string
Stat() (fs.FileInfo, error)
Flush() error
Seek(offset int64, whence int) (int64, error)
Read(b []byte) (int, error)
ReadWithSize() ([]byte, error)
Write(p []byte) (n int, err error)
WriteWithSize([]byte) error
}
type AccessFlags int
@ -109,3 +108,78 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) {
}
return
}
func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) {
p := buf
for len(p) > 0 {
n, err := f.Read(p)
p = p[n:]
if err != nil {
if len(p) == 0 && errors.Is(err, io.EOF) {
err = nil
}
return buf[:len(buf)-len(p)], err
}
}
return buf, nil
}
func read_with_size(f *os.File) ([]byte, error) {
szbuf := []byte{0, 0, 0, 0}
szbuf, err := read_till_buf_full(f, szbuf)
if err != nil {
return nil, err
}
size := int(binary.BigEndian.Uint32(szbuf))
return read_till_buf_full(f, make([]byte, size))
}
func write_with_size(f *os.File, b []byte) error {
szbuf := []byte{0, 0, 0, 0}
binary.BigEndian.PutUint32(szbuf, uint32(len(b)))
_, err := f.Write(szbuf)
if err == nil {
_, err = f.Write(b)
}
return err
}
func test_integration_with_python(args []string) (rc int, err error) {
switch args[0] {
default:
return 1, fmt.Errorf("Unknown test type: %s", args[0])
case "read":
data, err := ReadWithSizeAndUnlink(args[1])
if err != nil {
return 1, err
}
_, err = os.Stdout.Write(data)
if err != nil {
return 1, err
}
case "write":
data, err := io.ReadAll(os.Stdin)
if err != nil {
return 1, err
}
mmap, err := CreateTemp("shmtest-", uint64(len(data)+4))
if err != nil {
return 1, err
}
mmap.WriteWithSize(data)
mmap.Close()
fmt.Println(mmap.Name())
}
return 0, nil
}
func TestEntryPoint(root *cli.Command) {
root.AddSubCommand(&cli.Command{
Name: "shm",
OnlyArgsAllowed: true,
Run: func(cmd *cli.Command, args []string) (rc int, err error) {
return test_integration_with_python(args)
},
})
}

View File

@ -8,10 +8,13 @@ import (
"errors"
"fmt"
"io/fs"
"kitty/tools/utils"
"os"
"path/filepath"
"runtime"
"kitty/tools/utils"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
@ -39,6 +42,10 @@ func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, speci
return &file_based_mmap{f: f, region: region, special_name: special_name}, nil
}
func (self *file_based_mmap) Stat() (fs.FileInfo, error) {
return self.f.Stat()
}
func (self *file_based_mmap) Name() string {
if self.special_name != "" {
return self.special_name
@ -46,6 +53,30 @@ func (self *file_based_mmap) Name() string {
return filepath.Base(self.f.Name())
}
func (self *file_based_mmap) Flush() error {
return unix.Msync(self.region, unix.MS_SYNC)
}
func (self *file_based_mmap) Seek(offset int64, whence int) (int64, error) {
return self.f.Seek(offset, whence)
}
func (self *file_based_mmap) Read(b []byte) (int, error) {
return self.f.Read(b)
}
func (self *file_based_mmap) Write(b []byte) (int, error) {
return self.f.Write(b)
}
func (self *file_based_mmap) WriteWithSize(b []byte) error {
return write_with_size(self.f, b)
}
func (self *file_based_mmap) ReadWithSize() ([]byte, error) {
return read_with_size(self.f)
}
func (self *file_based_mmap) FileSystemName() string {
return self.f.Name()
}
@ -92,7 +123,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
var f *os.File
try := 0
for {
name := prefix + next_random() + suffix
name := prefix + utils.RandomFilename() + suffix
path := file_path_from_name(name)
f, err = os.OpenFile(path, os.O_EXCL|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
@ -113,7 +144,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
return file_mmap(f, size, WRITE, true, special_name)
}
func Open(name string, size uint64) (MMap, error) {
func open(name string) (*os.File, error) {
ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
@ -123,5 +154,29 @@ func Open(name string, size uint64) (MMap, error) {
}
return nil, err
}
return ans, nil
}
func Open(name string, size uint64) (MMap, error) {
ans, err := open(name)
if err != nil {
return nil, err
}
return file_mmap(ans, size, READ, false, name)
}
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
f, err := open(name)
if err != nil {
return nil, err
}
defer f.Close()
defer os.Remove(f.Name())
for _, cb := range file_callback {
err = cb(f)
if err != nil {
return nil, err
}
}
return read_with_size(f)
}

View File

@ -11,6 +11,8 @@ import (
"strings"
"unsafe"
"kitty/tools/utils"
"golang.org/x/sys/unix"
)
@ -86,11 +88,38 @@ func syscall_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (M
func (self *syscall_based_mmap) Name() string {
return self.f.Name()
}
func (self *syscall_based_mmap) Stat() (fs.FileInfo, error) {
return self.f.Stat()
}
func (self *syscall_based_mmap) Flush() error {
return unix.Msync(self.region, unix.MS_SYNC)
}
func (self *syscall_based_mmap) Slice() []byte {
return self.region
}
func (self *syscall_based_mmap) Seek(offset int64, whence int) (int64, error) {
return self.f.Seek(offset, whence)
}
func (self *syscall_based_mmap) Read(b []byte) (int, error) {
return self.f.Read(b)
}
func (self *syscall_based_mmap) Write(b []byte) (int, error) {
return self.f.Write(b)
}
func (self *syscall_based_mmap) WriteWithSize(b []byte) error {
return write_with_size(self.f, b)
}
func (self *syscall_based_mmap) ReadWithSize() ([]byte, error) {
return read_with_size(self.f)
}
func (self *syscall_based_mmap) Close() (err error) {
if self.region != nil {
self.f.Close()
@ -124,7 +153,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
var f *os.File
try := 0
for {
name := prefix + next_random() + suffix
name := prefix + utils.RandomFilename() + suffix
if len(name) > SHM_NAME_MAX {
return nil, ErrPatternTooLong
}
@ -151,3 +180,19 @@ func Open(name string, size uint64) (MMap, error) {
}
return syscall_mmap(ans, size, READ, false)
}
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
f, err := shm_open(name, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
defer f.Close()
defer shm_unlink(f.Name())
for _, cb := range file_callback {
err = cb(f)
if err != nil {
return nil, err
}
}
return read_with_size(f)
}

View File

@ -23,6 +23,10 @@ func TestSHM(t *testing.T) {
}
copy(mm.Slice(), data)
err = mm.Flush()
if err != nil {
t.Fatalf("Failed to msync() with error: %v", err)
}
err = mm.Close()
if err != nil {
t.Fatalf("Failed to close with error: %v", err)

View File

@ -57,6 +57,10 @@ type RGBA struct {
Red, Green, Blue, Inverse_alpha uint8
}
func (self RGBA) AsRGBSharp() string {
return fmt.Sprintf("#%02x%02x%02x", self.Red, self.Green, self.Blue)
}
func (self *RGBA) parse_rgb_strings(r string, g string, b string) bool {
var rv, gv, bv uint64
var err error
@ -77,6 +81,12 @@ func (self *RGBA) AsRGB() uint32 {
return uint32(self.Blue) | (uint32(self.Green) << 8) | (uint32(self.Red) << 16)
}
func (self *RGBA) FromRGB(col uint32) {
self.Red = uint8((col >> 16) & 0xff)
self.Green = uint8((col >> 8) & 0xff)
self.Blue = uint8((col) & 0xff)
}
type color_type struct {
is_numbered bool
val RGBA

View File

@ -13,15 +13,18 @@ import (
var _ = fmt.Print
func Which(cmd string) string {
func Which(cmd string, paths ...string) string {
if strings.Contains(cmd, string(os.PathSeparator)) {
return ""
}
path := os.Getenv("PATH")
if path == "" {
return ""
if len(paths) == 0 {
path := os.Getenv("PATH")
if path == "" {
return ""
}
paths = strings.Split(path, string(os.PathListSeparator))
}
for _, dir := range strings.Split(path, string(os.PathListSeparator)) {
for _, dir := range paths {
q := filepath.Join(dir, cmd)
if unix.Access(q, unix.X_OK) == nil {
s, err := os.Stat(q)