diff --git a/docs/changelog.rst b/docs/changelog.rst index cfefc9f2c..e84d1b28a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -62,6 +62,9 @@ Detailed list of changes - macOS: Fix the maximized window not taking up full space when the title bar is hidden or when :opt:`resize_in_steps` is configured (:iss:`6021`) +- ssh kitten: Change the syntax of glob patterns slightly to match common usage + elsewhere. Now the syntax is the same a "extendedglob" in most shells. + 0.27.1 [2023-02-07] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/conf.py b/docs/conf.py index 2755b60fa..552c73dd7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -226,15 +226,16 @@ def commit_role( # CLI docs {{{ def write_cli_docs(all_kitten_names: Iterable[str]) -> None: - from kittens.ssh.copy import option_text - from kittens.ssh.options.definition import copy_message + from kittens.ssh.main import copy_message, option_text from kitty.cli import option_spec_as_rst - from kitty.launch import options_spec as launch_options_spec with open('generated/ssh-copy.rst', 'w') as f: f.write(option_spec_as_rst( appname='copy', ospec=option_text, heading_char='^', usage='file-or-dir-to-copy ...', message=copy_message )) + del sys.modules['kittens.ssh.main'] + + from kitty.launch import options_spec as launch_options_spec with open('generated/launch.rst', 'w') as f: f.write(option_spec_as_rst( appname='launch', ospec=launch_options_spec, heading_char='_', @@ -525,9 +526,9 @@ def write_conf_docs(app: Any, all_kitten_names: Iterable[str]) -> None: from kittens.runner import get_kitten_conf_docs for kitten in all_kitten_names: - definition = get_kitten_conf_docs(kitten) - if definition: - generate_default_config(definition, f'kitten-{kitten}') + defn = get_kitten_conf_docs(kitten) + if defn is not None: + generate_default_config(defn, f'kitten-{kitten}') from kitty.actions import as_rst with open('generated/actions.rst', 'w', encoding='utf-8') as f: diff --git a/gen-config.py b/gen-config.py index 8cc3e4350..a59e5a1e1 100755 --- a/gen-config.py +++ b/gen-config.py @@ -51,8 +51,6 @@ def main() -> None: from kittens.diff.options.definition import definition as kd write_output('kittens.diff', kd) - from kittens.ssh.options.definition import definition as sd - write_output('kittens.ssh', sd) if __name__ == '__main__': diff --git a/gen-go-code.py b/gen-go-code.py index 057430751..4c5097b20 100755 --- a/gen-go-code.py +++ b/gen-go-code.py @@ -1,15 +1,17 @@ #!./kitty/launcher/kitty +launch # License: GPLv3 Copyright: 2022, Kovid Goyal +import bz2 import io import json import os import struct import subprocess import sys -import zlib +import tarfile from contextlib import contextmanager, suppress from functools import lru_cache +from itertools import chain from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Union import kitty.constants as kc @@ -22,6 +24,8 @@ from kitty.cli import ( parse_option_spec, serialize_as_go_string, ) +from kitty.conf.generate import gen_go_code +from kitty.conf.types import Definition from kitty.guess_mime_type import text_mimes from kitty.key_encoding import config_mod_map from kitty.key_names import character_key_name_aliases, functional_key_name_aliases @@ -38,7 +42,7 @@ def newer(dest: str, *sources: str) -> bool: dtime = os.path.getmtime(dest) except OSError: return True - for s in sources: + for s in chain(sources, (__file__,)): with suppress(FileNotFoundError): if os.path.getmtime(s) >= dtime: return True @@ -318,8 +322,45 @@ def wrapped_kittens() -> Sequence[str]: raise Exception('Failed to read wrapped kittens from kitty wrapper script') +def generate_conf_parser(kitten: str, defn: Definition) -> None: + with replace_if_needed(f'tools/cmd/{kitten}/conf_generated.go'): + print(f'package {kitten}') + print(gen_go_code(defn)) + + +def generate_extra_cli_parser(name: str, spec: str) -> None: + print('import "kitty/tools/cli"') + go_opts = tuple(go_options_for_seq(parse_option_spec(spec)[0])) + print(f'type {name}_options struct ''{') + for opt in go_opts: + print(opt.struct_declaration()) + print('}') + print(f'func parse_{name}_args(args []string) (*{name}_options, []string, error) ''{') + print(f'root := cli.Command{{Name: `{name}` }}') + for opt in go_opts: + print(opt.as_option('root')) + print('cmd, err := root.ParseArgs(args)') + print('if err != nil { return nil, nil, err }') + print(f'var opts {name}_options') + print('err = cmd.GetOptionValues(&opts)') + print('if err != nil { return nil, nil, err }') + print('return &opts, cmd.Args, nil') + print('}') + + def kitten_clis() -> None: + from kittens.runner import get_kitten_conf_docs, get_kitten_extra_cli_parsers for kitten in wrapped_kittens(): + defn = get_kitten_conf_docs(kitten) + if defn is not None: + generate_conf_parser(kitten, defn) + ecp = get_kitten_extra_cli_parsers(kitten) + if ecp: + for name, spec in ecp.items(): + with replace_if_needed(f'tools/cmd/{kitten}/{name}_cli_generated.go'): + print(f'package {kitten}') + generate_extra_cli_parser(name, spec) + with replace_if_needed(f'tools/cmd/{kitten}/cli_generated.go'): od = [] kcd = kitten_cli_docs(kitten) @@ -329,10 +370,11 @@ def kitten_clis() -> None: print('func create_cmd(root *cli.Command, run_func func(*cli.Command, *Options, []string)(int, error)) {') print('ans := root.AddSubCommand(&cli.Command{') print(f'Name: "{kitten}",') - print(f'ShortDescription: "{serialize_as_go_string(kcd["short_desc"])}",') - if kcd['usage']: - print(f'Usage: "[options] {serialize_as_go_string(kcd["usage"])}",') - print(f'HelpText: "{serialize_as_go_string(kcd["help_text"])}",') + if kcd: + print(f'ShortDescription: "{serialize_as_go_string(kcd["short_desc"])}",') + if kcd['usage']: + print(f'Usage: "[options] {serialize_as_go_string(kcd["usage"])}",') + print(f'HelpText: "{serialize_as_go_string(kcd["help_text"])}",') print('Run: func(cmd *cli.Command, args []string) (int, error) {') print('opts := Options{}') print('err := cmd.GetOptionValues(&opts)') @@ -351,6 +393,8 @@ def kitten_clis() -> None: print("clone := root.AddClone(ans.Group, ans)") print('clone.Hidden = false') print(f'clone.Name = "{serialize_as_go_string(kitten.replace("_", "-"))}"') + if not kcd: + print('specialize_command(ans)') print('}') print('type Options struct {') print('\n'.join(od)) @@ -383,11 +427,24 @@ def generate_spinners() -> str: def generate_color_names() -> str: + selfg = "" if Options.selection_foreground is None else Options.selection_foreground.as_sharp + selbg = "" if Options.selection_background is None else Options.selection_background.as_sharp + cursor = "" if Options.cursor is None else Options.cursor.as_sharp return 'package style\n\nvar ColorNames = map[string]RGBA{' + '\n'.join( f'\t"{name}": RGBA{{ Red:{val.red}, Green:{val.green}, Blue:{val.blue} }},' for name, val in color_names.items() ) + '\n}' + '\n\nvar ColorTable = [256]uint32{' + ', '.join( - f'{x}' for x in Options.color_table) + '}\n' + f'{x}' for x in Options.color_table) + '}\n' + f''' +var DefaultColors = struct {{ +Foreground, Background, Cursor, SelectionFg, SelectionBg string +}}{{ +Foreground: "{Options.foreground.as_sharp}", +Background: "{Options.background.as_sharp}", +Cursor: "{cursor}", +SelectionFg: "{selfg}", +SelectionBg: "{selbg}", +}} +''' def load_ref_map() -> Dict[str, Dict[str, str]]: @@ -399,6 +456,8 @@ def load_ref_map() -> Dict[str, Dict[str, str]]: def generate_constants() -> str: + from kitty.options.types import Options + from kitty.options.utils import allowed_shell_integration_values ref_map = load_ref_map() dp = ", ".join(map(lambda x: f'"{serialize_as_go_string(x)}"', kc.default_pager_for_help)) return f'''\ @@ -410,6 +469,7 @@ type VersionType struct {{ const VersionString string = "{kc.str_version}" const WebsiteBaseURL string = "{kc.website_base_url}" const VCSRevision string = "" +const SSHControlMasterTemplate = "{kc.ssh_control_master_template}" const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}" const IsFrozenBuild bool = false const IsStandaloneBuild bool = false @@ -421,6 +481,12 @@ var CharacterKeyNameAliases = map[string]string{serialize_go_dict(character_key_ var ConfigModMap = map[string]uint16{serialize_go_dict(config_mod_map)} var RefMap = map[string]string{serialize_go_dict(ref_map['ref'])} var DocTitleMap = map[string]string{serialize_go_dict(ref_map['doc'])} +var AllowedShellIntegrationValues = []string{{ {str(sorted(allowed_shell_integration_values))[1:-1].replace("'", '"')} }} +var KittyConfigDefaults = struct {{ +Term, Shell_integration string +}}{{ +Term: "{Options.term}", Shell_integration: "{' '.join(Options.shell_integration)}", +}} ''' # }}} @@ -598,6 +664,11 @@ def generate_textual_mimetypes() -> str: return '\n'.join(ans) +def write_compressed_data(data: bytes, d: BinaryIO) -> None: + d.write(struct.pack(' None: num_names, num_of_words = map(int, next(src).split()) gob = io.BytesIO() @@ -612,9 +683,31 @@ def generate_unicode_names(src: TextIO, dest: BinaryIO) -> None: if aliases: record += aliases.encode() gob.write(struct.pack(' None: + files = { + 'terminfo/kitty.terminfo', 'terminfo/x/xterm-kitty', + } + for dirpath, dirnames, filenames in os.walk('shell-integration'): + for f in filenames: + path = os.path.join(dirpath, f) + files.add(path.replace(os.sep, '/')) + dest = 'tools/cmd/ssh/data_generated.bin' + + def normalize(t: tarfile.TarInfo) -> tarfile.TarInfo: + t.uid = t.gid = 0 + t.uname = t.gname = '' + return t + + if newer(dest, *files): + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode='w') as tf: + for f in sorted(files): + tf.add(f, filter=normalize) + with open(dest, 'wb') as d: + write_compressed_data(buf.getvalue(), d) def main() -> None: @@ -633,6 +726,7 @@ def main() -> None: if newer('tools/unicode_names/data_generated.bin', 'tools/unicode_names/names.txt'): with open('tools/unicode_names/data_generated.bin', 'wb') as dest, open('tools/unicode_names/names.txt') as src: generate_unicode_names(src, dest) + generate_ssh_kitten_data() update_completion() update_at_commands() diff --git a/go.mod b/go.mod index e1106df31..3ea0d23dc 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,24 @@ go 1.20 require ( github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 + github.com/bmatcuk/doublestar v1.3.4 github.com/disintegration/imaging v1.6.2 - github.com/google/go-cmp v0.5.8 + github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.0 github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f github.com/seancfoley/ipaddress-go v1.5.3 + github.com/shirou/gopsutil/v3 v3.23.1 golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b golang.org/x/image v0.5.0 golang.org/x/sys v0.4.0 ) -require github.com/seancfoley/bintree v1.2.1 // indirect +require ( + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/seancfoley/bintree v1.2.1 // indirect + github.com/tklauser/go-sysconf v0.3.11 // indirect + github.com/tklauser/numcpus v0.6.0 // indirect + github.com/yusufpapurcu/wmi v1.2.2 // indirect +) diff --git a/go.sum b/go.sum index 43c0e463f..8abe1fc47 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,47 @@ github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 h1:DG4UyTVIujioxwJc8Zj8Nabz1L1wTgQ/xNBSQDfdP3I= github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924/go.mod h1:+NaH2gLeY6RPBPPQf4aRotPPStg+eXc8f9ZaE4vRfD4= +github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= +github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f h1:Ko4+g6K16vSyUrtd/pPXuQnWsiHe5BYptEtTxfwYwCc= github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f/go.mod h1:eHzfhOKbTGJEGPSdMHzU6jft192tHHt2Bu2vIZArvC0= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/seancfoley/bintree v1.2.1 h1:Z/iNjRKkXnn0CTW7jDQYtjW5fz2GH1yWvOTJ4MrMvdo= github.com/seancfoley/bintree v1.2.1/go.mod h1:hIUabL8OFYyFVTQ6azeajbopogQc2l5C/hiXMcemWNU= github.com/seancfoley/ipaddress-go v1.5.3 h1:fLnn4nsatd2rp3IJsVWriXv5gXn2Qiy8uxjxe4iZtTg= github.com/seancfoley/ipaddress-go v1.5.3/go.mod h1:fpvVPC+Jso+YEhNcNiww8HQmBgKP8T4T6BTp1SLxxIo= +github.com/shirou/gopsutil/v3 v3.23.1 h1:a9KKO+kGLKEvcPIs4W62v0nu3sciVDOOOPUD0Hz7z/4= +github.com/shirou/gopsutil/v3 v3.23.1/go.mod h1:NN6mnm5/0k8jw4cBfCnJtr5L7ErOTg18tMNpgFkn0hA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM= +github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI= +github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms= +github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= +github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b h1:EqBVA+nNsObCwQoBEHy4wLU0pi7i8a4AL3pbItPdPkE= @@ -27,10 +56,13 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -43,3 +75,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/kittens/runner.py b/kittens/runner.py index be56c5335..3cee69fbd 100644 --- a/kittens/runner.py +++ b/kittens/runner.py @@ -7,7 +7,7 @@ import os import sys from contextlib import contextmanager from functools import partial -from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, cast +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, Optional, cast from kitty.constants import list_kitty_resources from kitty.types import run_once @@ -171,7 +171,7 @@ def get_kitten_completer(kitten: str) -> Any: return ans -def get_kitten_conf_docs(kitten: str) -> Definition: +def get_kitten_conf_docs(kitten: str) -> Optional[Definition]: setattr(sys, 'options_definition', None) run_kitten(kitten, run_name='__conf__') ans = getattr(sys, 'options_definition') @@ -179,6 +179,14 @@ def get_kitten_conf_docs(kitten: str) -> Definition: return cast(Definition, ans) +def get_kitten_extra_cli_parsers(kitten: str) -> Dict[str,str]: + setattr(sys, 'extra_cli_parsers', {}) + run_kitten(kitten, run_name='__extra_cli_parsers__') + ans = getattr(sys, 'extra_cli_parsers') + delattr(sys, 'extra_cli_parsers') + return cast(Dict[str, str], ans) + + def main() -> None: try: args = sys.argv[1:] diff --git a/kittens/ssh/config.py b/kittens/ssh/config.py deleted file mode 100644 index 0183df7b5..000000000 --- a/kittens/ssh/config.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python -# License: GPLv3 Copyright: 2022, Kovid Goyal - - -import fnmatch -import os -from typing import Any, Dict, Iterable, Optional - -from kitty.conf.utils import load_config as _load_config -from kitty.conf.utils import parse_config_base, resolve_config -from kitty.constants import config_dir - -from .options.types import Options as SSHOptions -from .options.types import defaults - -SYSTEM_CONF = '/etc/xdg/kitty/ssh.conf' -defconf = os.path.join(config_dir, 'ssh.conf') - - -def host_matches(mpat: str, hostname: str, username: str) -> bool: - for pat in mpat.split(): - upat = '*' - if '@' in pat: - upat, pat = pat.split('@', 1) - if fnmatch.fnmatchcase(hostname, pat) and fnmatch.fnmatchcase(username, upat): - return True - return False - - -def load_config(*paths: str, overrides: Optional[Iterable[str]] = None, hostname: str = '!', username: str = '') -> SSHOptions: - from .options.parse import create_result_dict, merge_result_dicts, parse_conf_item - from .options.utils import first_seen_positions, get_per_hosts_dict, init_results_dict - - def merge_dicts(base: Dict[str, Any], vals: Dict[str, Any]) -> Dict[str, Any]: - base_phd = get_per_hosts_dict(base) - vals_phd = get_per_hosts_dict(vals) - for hostname in base_phd: - vals_phd[hostname] = merge_result_dicts(base_phd[hostname], vals_phd.get(hostname, {})) - ans: Dict[str, Any] = vals_phd.pop(vals['hostname']) - ans['per_host_dicts'] = vals_phd - return ans - - def parse_config(lines: Iterable[str]) -> Dict[str, Any]: - ans: Dict[str, Any] = init_results_dict(create_result_dict()) - parse_config_base(lines, parse_conf_item, ans) - return ans - - overrides = tuple(overrides) if overrides is not None else () - first_seen_positions.clear() - first_seen_positions['*'] = 0 - opts_dict, paths = _load_config( - defaults, parse_config, merge_dicts, *paths, overrides=overrides, initialize_defaults=init_results_dict) - phd = get_per_hosts_dict(opts_dict) - final_dict: Dict[str, Any] = {} - for hostname_pat in sorted(phd, key=first_seen_positions.__getitem__): - if host_matches(hostname_pat, hostname, username): - od = phd[hostname_pat] - for k, v in od.items(): - if isinstance(v, dict): - bv = final_dict.setdefault(k, {}) - bv.update(v) - else: - final_dict[k] = v - first_seen_positions.clear() - return SSHOptions(final_dict) - - -def init_config(hostname: str, username: str, overrides: Optional[Iterable[str]] = None) -> SSHOptions: - config = tuple(resolve_config(SYSTEM_CONF, defconf)) - return load_config(*config, overrides=overrides, hostname=hostname, username=username) diff --git a/kittens/ssh/copy.py b/kittens/ssh/copy.py deleted file mode 100644 index 0ced93899..000000000 --- a/kittens/ssh/copy.py +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env python -# License: GPLv3 Copyright: 2022, Kovid Goyal - - -import glob -import os -import shlex -import uuid -from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Sequence, Tuple - -from kitty.cli import parse_args -from kitty.cli_stub import CopyCLIOptions -from kitty.types import run_once - -from ..transfer.utils import expand_home, home_path - - -@run_once -def option_text() -> str: - return ''' ---glob -type=bool-set -Interpret file arguments as glob patterns. - - ---dest -The destination on the remote host to copy to. Relative paths are resolved -relative to HOME on the remote host. When this option is not specified, the -local file path is used as the remote destination (with the HOME directory -getting automatically replaced by the remote HOME). Note that environment -variables and ~ are not expanded. - - ---exclude -type=list -A glob pattern. Files with names matching this pattern are excluded from being -transferred. Useful when adding directories. Can -be specified multiple times, if any of the patterns match the file will be -excluded. To exclude a directory use a pattern like */directory_name/*. - - ---symlink-strategy -default=preserve -choices=preserve,resolve,keep-path -Control what happens if the specified path is a symlink. The default is to preserve -the symlink, re-creating it on the remote machine. Setting this to :code:`resolve` -will cause the symlink to be followed and its target used as the file/directory to copy. -The value of :code:`keep-path` is the same as :code:`resolve` except that the remote -file path is derived from the symlink's path instead of the path of the symlink's target. -Note that this option does not apply to symlinks encountered while recursively copying directories. -''' - - -def parse_copy_args(args: Optional[Sequence[str]] = None) -> Tuple[CopyCLIOptions, List[str]]: - args = list(args or ()) - try: - opts, args = parse_args(result_class=CopyCLIOptions, args=args, ospec=option_text) - except SystemExit as e: - raise CopyCLIError from e - return opts, args - - -def resolve_file_spec(spec: str, is_glob: bool) -> Iterator[str]: - ans = os.path.expandvars(expand_home(spec)) - if not os.path.isabs(ans): - ans = expand_home(f'~/{ans}') - if is_glob: - files = glob.glob(ans) - if not files: - raise CopyCLIError(f'{spec} does not exist') - else: - if not os.path.exists(ans): - raise CopyCLIError(f'{spec} does not exist') - files = [ans] - for x in files: - yield os.path.normpath(x).replace(os.sep, '/') - - -class CopyCLIError(ValueError): - pass - - -def get_arcname(loc: str, dest: Optional[str], home: str) -> str: - if dest: - arcname = dest - else: - arcname = os.path.normpath(loc) - if arcname.startswith(home): - arcname = os.path.relpath(arcname, home) - arcname = os.path.normpath(arcname).replace(os.sep, '/') - prefix = 'root' if arcname.startswith('/') else 'home/' - return prefix + arcname - - -class CopyInstruction(NamedTuple): - local_path: str - arcname: str - exclude_patterns: Tuple[str, ...] - - -def parse_copy_instructions(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]: - opts, args = parse_copy_args(shlex.split(val)) - locations: List[str] = [] - for a in args: - locations.extend(resolve_file_spec(a, opts.glob)) - if not locations: - raise CopyCLIError('No files to copy specified') - if len(locations) > 1 and opts.dest: - raise CopyCLIError('Specifying a remote location with more than one file is not supported') - home = home_path() - for loc in locations: - if opts.symlink_strategy != 'preserve': - rp = os.path.realpath(loc) - else: - rp = loc - arcname = get_arcname(rp if opts.symlink_strategy == 'resolve' else loc, opts.dest, home) - yield str(uuid.uuid4()), CopyInstruction(rp, arcname, tuple(opts.exclude)) diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 0e270acf9..cb541c15f 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -1,728 +1,214 @@ #!/usr/bin/env python3 # License: GPL v3 Copyright: 2018, Kovid Goyal -import fnmatch -import glob -import io -import json -import os -import re -import secrets -import shlex -import shutil -import stat -import subprocess import sys -import tarfile -import tempfile -import termios -import time -import traceback -from base64 import standard_b64decode, standard_b64encode -from contextlib import contextmanager, suppress -from getpass import getuser -from select import select -from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Tuple, Union, cast +from typing import List, Optional -from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_control_master_template, str_version, terminfo_dir -from kitty.shell_integration import as_str_literal -from kitty.shm import SharedMemory +from kitty.conf.types import Definition from kitty.types import run_once -from kitty.utils import SSHConnectionData, expandvars, resolve_abs_or_config_path -from kitty.utils import set_echo as turn_off_echo -from ..tui.operations import RESTORE_PRIVATE_MODE_VALUES, SAVE_PRIVATE_MODE_VALUES, Mode, restore_colors, save_colors, set_mode -from ..tui.utils import kitty_opts, running_in_tmux -from .config import init_config -from .copy import CopyInstruction -from .options.types import Options as SSHOptions -from .options.utils import DELETE_ENV_VAR -from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args +copy_message = '''\ +Copy files and directories from local to remote hosts. The specified files are +assumed to be relative to the HOME directory and copied to the HOME on the +remote. Directories are copied recursively. If absolute paths are used, they are +copied as is.''' @run_once -def ssh_exe() -> str: - return shutil.which('ssh') or 'ssh' - - -def read_data_from_shared_memory(shm_name: str) -> Any: - with SharedMemory(shm_name, readonly=True) as shm: - shm.unlink() - if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid(): - raise ValueError('Incorrect owner on pwfile') - mode = stat.S_IMODE(shm.stats.st_mode) - if mode != stat.S_IREAD | stat.S_IWRITE: - raise ValueError('Incorrect permissions on pwfile') - return json.loads(shm.read_data_with_size()) - - -# See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html -quote_pat = re.compile('([\\`"])') - - -def quote_env_val(x: str, literal_quote: bool = False) -> str: - if literal_quote: - return as_str_literal(x) - x = quote_pat.sub(r'\\\1', x) - x = x.replace('$(', r'\$(') # prevent execution with $() - return f'"{x}"' - - -def serialize_env(literal_env: Dict[str, str], env: Dict[str, str], base_env: Dict[str, str], for_python: bool = False) -> bytes: - lines = [] - literal_quote = True - - if for_python: - def a(k: str, val: str = '', prefix: str = 'export') -> None: - if val: - lines.append(f'{prefix} {json.dumps((k, val, literal_quote))}') - else: - lines.append(f'{prefix} {json.dumps((k,))}') - else: - def a(k: str, val: str = '', prefix: str = 'export') -> None: - if val: - lines.append(f'{prefix} {shlex.quote(k)}={quote_env_val(val, literal_quote)}') - else: - lines.append(f'{prefix} {shlex.quote(k)}') - - for k, v in literal_env.items(): - a(k, v) - - literal_quote = False - for k in sorted(env): - v = env[k] - if v == DELETE_ENV_VAR: - a(k, prefix='unset') - elif v == '_kitty_copy_env_var_': - q = base_env.get(k) - if q is not None: - a(k, q) - else: - a(k, v) - return '\n'.join(lines).encode('utf-8') - - -def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz', literal_env: Dict[str, str] = {}) -> bytes: - - def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo: - tarinfo.uname = tarinfo.gname = '' - tarinfo.uid = tarinfo.gid = 0 - # some distro's like nix mess with installed file permissions so ensure - # files are at least readable and writable by owning user - tarinfo.mode |= stat.S_IWUSR | stat.S_IRUSR - return tarinfo - - def add_data_as_file(tf: tarfile.TarFile, arcname: str, data: Union[str, bytes]) -> tarfile.TarInfo: - ans = tarfile.TarInfo(arcname) - ans.mtime = 0 - ans.type = tarfile.REGTYPE - if isinstance(data, str): - data = data.encode('utf-8') - ans.size = len(data) - normalize_tarinfo(ans) - tf.addfile(ans, io.BytesIO(data)) - return ans - - def filter_from_globs(*pats: str) -> Callable[[tarfile.TarInfo], Optional[tarfile.TarInfo]]: - def filter(tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]: - for junk_dir in ('.DS_Store', '__pycache__'): - for pat in (f'*/{junk_dir}', f'*/{junk_dir}/*'): - if fnmatch.fnmatch(tarinfo.name, pat): - return None - for pat in pats: - if fnmatch.fnmatch(tarinfo.name, pat): - return None - return normalize_tarinfo(tarinfo) - return filter - - from kitty.shell_integration import get_effective_ksi_env_var - if ssh_opts.shell_integration == 'inherited': - ksi = get_effective_ksi_env_var(kitty_opts()) - else: - from kitty.options.types import Options - from kitty.options.utils import shell_integration - ksi = get_effective_ksi_env_var(Options({'shell_integration': shell_integration(ssh_opts.shell_integration)})) - - env = { - 'TERM': os.environ.get('TERM') or kitty_opts().term, - 'COLORTERM': 'truecolor', - } - env.update(ssh_opts.env) - for q in ('KITTY_WINDOW_ID', 'WINDOWID'): - val = os.environ.get(q) - if val is not None: - env[q] = val - env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR - env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir - if ssh_opts.login_shell: - env['KITTY_LOGIN_SHELL'] = ssh_opts.login_shell - if ssh_opts.cwd: - env['KITTY_LOGIN_CWD'] = ssh_opts.cwd - if ssh_opts.remote_kitty != 'no': - env['KITTY_REMOTE'] = ssh_opts.remote_kitty - if os.environ.get('KITTY_PUBLIC_KEY'): - env.pop('KITTY_PUBLIC_KEY', None) - literal_env['KITTY_PUBLIC_KEY'] = os.environ['KITTY_PUBLIC_KEY'] - env_script = serialize_env(literal_env, env, base_env, for_python=compression != 'gz') - buf = io.BytesIO() - with tarfile.open(mode=f'w:{compression}', fileobj=buf, encoding='utf-8') as tf: - rd = ssh_opts.remote_dir.rstrip('/') - for ci in ssh_opts.copy.values(): - tf.add(ci.local_path, arcname=ci.arcname, filter=filter_from_globs(*ci.exclude_patterns)) - add_data_as_file(tf, 'data.sh', env_script) - if compression == 'gz': - tf.add(f'{shell_integration_dir}/ssh/bootstrap-utils.sh', arcname='bootstrap-utils.sh', filter=normalize_tarinfo) - if ksi: - arcname = 'home/' + rd + '/shell-integration' - tf.add(shell_integration_dir, arcname=arcname, filter=filter_from_globs( - f'{arcname}/ssh/*', # bootstrap files are sent as command line args - f'{arcname}/zsh/kitty.zsh', # present for legacy compat not needed by ssh kitten - )) - if ssh_opts.remote_kitty != 'no': - arcname = 'home/' + rd + '/kitty' - add_data_as_file(tf, arcname + '/version', str_version.encode('ascii')) - tf.add(shell_integration_dir + '/ssh/kitty', arcname=arcname + '/bin/kitty', filter=normalize_tarinfo) - tf.add(shell_integration_dir + '/ssh/kitten', arcname=arcname + '/bin/kitten', filter=normalize_tarinfo) - tf.add(f'{terminfo_dir}/kitty.terminfo', arcname='home/.terminfo/kitty.terminfo', filter=normalize_tarinfo) - tf.add(glob.glob(f'{terminfo_dir}/*/xterm-kitty')[0], arcname='home/.terminfo/x/xterm-kitty', filter=normalize_tarinfo) - return buf.getvalue() - - -def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: - yield b'\nKITTY_DATA_START\n' # to discard leading data - try: - msg = standard_b64decode(msg).decode('utf-8') - md = dict(x.split('=', 1) for x in msg.split(':')) - pw = md['pw'] - pwfilename = md['pwfile'] - rq_id = md['id'] - except Exception: - traceback.print_exc() - yield b'invalid ssh data request message\n' - else: - try: - env_data = read_data_from_shared_memory(pwfilename) - if pw != env_data['pw']: - raise ValueError('Incorrect password') - if rq_id != request_id: - raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window') - except Exception as e: - traceback.print_exc() - yield f'{e}\n'.encode('utf-8') - else: - yield b'OK\n' - ssh_opts = SSHOptions(env_data['opts']) - ssh_opts.copy = {k: CopyInstruction(*v) for k, v in ssh_opts.copy.items()} - encoded_data = memoryview(env_data['tarfile'].encode('ascii')) - # macOS has a 255 byte limit on its input queue as per man stty. - # Not clear if that applies to canonical mode input as well, but - # better to be safe. - line_sz = 254 - while encoded_data: - yield encoded_data[:line_sz] - yield b'\n' - encoded_data = encoded_data[line_sz:] - yield b'KITTY_DATA_END\n' - - -def safe_remove(x: str) -> None: - with suppress(OSError): - os.remove(x) - - -def prepare_script(ans: str, replacements: Dict[str, str], script_type: str) -> str: - for k in ('EXEC_CMD', 'EXPORT_HOME_CMD'): - replacements[k] = replacements.get(k, '') - - def sub(m: 're.Match[str]') -> str: - return replacements[m.group()] - - return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans) - - -def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str: - # ssh simply concatenates multiple commands using a space see - # line 1129 of ssh.c and on the remote side sshd.c runs the - # concatenated command as shell -c cmd - if is_python: - return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii') - args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args) - return f"""unset KITTY_SHELL_INTEGRATION; exec "$login_shell" -c '{args}'""" - - -def prepare_export_home_cmd(ssh_opts: SSHOptions, is_python: bool) -> str: - home = ssh_opts.env.get('HOME') - if home == '_kitty_copy_env_var_': - home = os.environ.get('HOME') - if home: - if is_python: - return standard_b64encode(home.encode('utf-8')).decode('ascii') - else: - return f'export HOME={quote_env_val(home)}; cd "$HOME"' - return '' - - -def bootstrap_script( - ssh_opts: SSHOptions, script_type: str = 'sh', remote_args: Sequence[str] = (), - test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = '', - request_data: bool = False, echo_on: bool = True, literal_env: Dict[str, str] = {} -) -> Tuple[str, Dict[str, str], str]: - if request_id is None: - request_id = os.environ['KITTY_PID'] + '-' + os.environ['KITTY_WINDOW_ID'] - is_python = script_type == 'py' - export_home_cmd = prepare_export_home_cmd(ssh_opts, is_python) if 'HOME' in ssh_opts.env else '' - exec_cmd = prepare_exec_cmd(remote_args, is_python) if remote_args else '' - with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f: - ans = f.read() - pw = secrets.token_hex() - tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2', literal_env=literal_env)).decode('ascii') - data = {'pw': pw, 'opts': ssh_opts._asdict(), 'hostname': cli_hostname, 'uname': cli_uname, 'tarfile': tfd} - shm_name = create_shared_memory(data, prefix=f'kssh-{os.getpid()}-') - sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm_name} - replacements = { - 'EXPORT_HOME_CMD': export_home_cmd, - 'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script, - 'REQUEST_DATA': '1' if request_data else '0', 'ECHO_ON': '1' if echo_on else '0', - } - sd = replacements.copy() - if request_data: - sd.update(sensitive_data) - replacements.update(sensitive_data) - return prepare_script(ans, sd, script_type), replacements, shm_name - - -def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]: - boolean_ssh_args, other_ssh_args = get_ssh_cli() - port: Optional[int] = None - expecting_port = expecting_identity = False - expecting_option_val = False - expecting_hostname = False - expecting_extra_val = '' - host_name = identity_file = found_ssh = '' - found_extra_args: List[Tuple[str, str]] = [] - - for i, arg in enumerate(args): - if not found_ssh: - if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'): - found_ssh = arg - continue - if expecting_hostname: - host_name = arg - continue - if arg.startswith('-') and not expecting_option_val: - if arg in boolean_ssh_args: - continue - if arg == '--': - expecting_hostname = True - if arg.startswith('-p'): - if arg[2:].isdigit(): - with suppress(Exception): - port = int(arg[2:]) - continue - elif arg == '-p': - expecting_port = True - elif arg.startswith('-i'): - if arg == '-i': - expecting_identity = True - else: - identity_file = arg[2:] - continue - if arg.startswith('--') and extra_args: - matching_ex = is_extra_arg(arg, extra_args) - if matching_ex: - if '=' in arg: - exval = arg.partition('=')[-1] - found_extra_args.append((matching_ex, exval)) - continue - expecting_extra_val = matching_ex - - expecting_option_val = True - continue - - if expecting_option_val: - if expecting_port: - with suppress(Exception): - port = int(arg) - expecting_port = False - elif expecting_identity: - identity_file = arg - elif expecting_extra_val: - found_extra_args.append((expecting_extra_val, arg)) - expecting_extra_val = '' - expecting_option_val = False - continue - - if not host_name: - host_name = arg - if not host_name: - return None - if host_name.startswith('ssh://'): - from urllib.parse import urlparse - purl = urlparse(host_name) - if purl.hostname: - host_name = purl.hostname - if purl.username: - host_name = f'{purl.username}@{host_name}' - if port is None and purl.port: - port = purl.port - if identity_file: - if not os.path.isabs(identity_file): - identity_file = os.path.expanduser(identity_file) - if not os.path.isabs(identity_file): - identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file)) - - return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args)) - - -class InvalidSSHArgs(ValueError): - - def __init__(self, msg: str = ''): - super().__init__(msg) - self.err_msg = msg - - def system_exit(self) -> None: - if self.err_msg: - print(self.err_msg, file=sys.stderr) - os.execlp(ssh_exe(), 'ssh') - - -def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]: - boolean_ssh_args, other_ssh_args = get_ssh_cli() - ssh_args = [] - server_args: List[str] = [] - expecting_option_val = False - passthrough = False - stop_option_processing = False - found_extra_args: List[str] = [] - expecting_extra_val = '' - for argument in args: - if len(server_args) > 1 or stop_option_processing: - server_args.append(argument) - continue - if argument.startswith('-') and not expecting_option_val: - if argument == '--': - stop_option_processing = True - continue - if extra_args: - matching_ex = is_extra_arg(argument, extra_args) - if matching_ex: - if '=' in argument: - exval = argument.partition('=')[-1] - found_extra_args.extend((matching_ex, exval)) - else: - expecting_extra_val = matching_ex - expecting_option_val = True - continue - # could be a multi-character option - all_args = argument[1:] - for i, arg in enumerate(all_args): - arg = f'-{arg}' - if arg in passthrough_args: - passthrough = True - if arg in boolean_ssh_args: - ssh_args.append(arg) - continue - if arg in other_ssh_args: - ssh_args.append(arg) - rest = all_args[i+1:] - if rest: - ssh_args.append(rest) - else: - expecting_option_val = True - break - raise InvalidSSHArgs(f'unknown option -- {arg[1:]}') - continue - if expecting_option_val: - if expecting_extra_val: - found_extra_args.extend((expecting_extra_val, argument)) - expecting_extra_val = '' - else: - ssh_args.append(argument) - expecting_option_val = False - continue - server_args.append(argument) - if not server_args: - raise InvalidSSHArgs() - return ssh_args, server_args, passthrough, tuple(found_extra_args) - - -def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]: - # sshd will execute the command we pass it by join all command line - # arguments with a space and passing it as a single argument to the users - # login shell with -c. If the user has a non POSIX login shell it might - # have different escaping semantics and syntax, so the command it should - # execute has to be as simple as possible, basically of the form - # interpreter -c unwrap_script escaped_bootstrap_script - # The unwrap_script is responsible for unescaping the bootstrap script and - # executing it. - q = os.path.basename(interpreter).lower() - is_python = 'python' in q - if is_python: - es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii') - unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"''' - else: - # We cant rely on base64 being available on the remote system, so instead - # we quote the bootstrap script by replacing ' and \ with \v and \f - # also replacing \n and ! with \r and \b for tcsh - # finally surrounding with ' - es = "'" + sh_script.replace("'", '\v').replace('\\', '\f').replace('\n', '\r').replace('!', '\b') + "'" - unwrap_script = r"""'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' """ - # exec is supported by all sh like shells, and fish and csh - return ['exec', interpreter, '-c', unwrap_script, es] - - -def get_remote_command( - remote_args: List[str], ssh_opts: SSHOptions, cli_hostname: str = '', cli_uname: str = '', - echo_on: bool = True, request_data: bool = False, literal_env: Dict[str, str] = {} -) -> Tuple[List[str], Dict[str, str], str]: - interpreter = ssh_opts.interpreter - q = os.path.basename(interpreter).lower() - is_python = 'python' in q - sh_script, replacements, shm_name = bootstrap_script( - ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args, literal_env=literal_env, - cli_hostname=cli_hostname, cli_uname=cli_uname, echo_on=echo_on, request_data=request_data) - return wrap_bootstrap_script(sh_script, interpreter), replacements, shm_name - - -def connection_sharing_args(kitty_pid: int) -> List[str]: - rd = runtime_dir() - # Bloody OpenSSH generates a 40 char hash and in creating the socket - # appends a 27 char temp suffix to it. Socket max path length is approx - # ~104 chars. macOS has no system runtime dir so we use a cache dir in - # /Users/WHY_DOES_ANYONE_USE_MACOS/Library/Caches/APPLE_ARE_IDIOTIC - if len(rd) > 35 and os.path.isdir('/tmp'): - idiotic_design = f'/tmp/kssh-rdir-{os.getuid()}' - try: - os.symlink(rd, idiotic_design) - except FileExistsError: - try: - dest = os.readlink(idiotic_design) - except OSError as e: - raise ValueError(f'The {idiotic_design} symlink could not be created as something with that name exists already') from e - else: - if dest != rd: - with tempfile.TemporaryDirectory(dir='/tmp') as tdir: - tlink = os.path.join(tdir, 'sigh') - os.symlink(rd, tlink) - os.rename(tlink, idiotic_design) - rd = idiotic_design - - cp = os.path.join(rd, ssh_control_master_template.format(kitty_pid=kitty_pid, ssh_placeholder='%C')) - ans: List[str] = [ - '-o', 'ControlMaster=auto', - '-o', f'ControlPath={cp}', - '-o', 'ControlPersist=yes', - '-o', 'ServerAliveInterval=60', - '-o', 'ServerAliveCountMax=5', - '-o', 'TCPKeepAlive=no', - ] - return ans - - -@contextmanager -def restore_terminal_state() -> Iterator[bool]: - with open(os.ctermid()) as f: - val = termios.tcgetattr(f.fileno()) - print(end=SAVE_PRIVATE_MODE_VALUES) - print(end=set_mode(Mode.HANDLE_TERMIOS_SIGNALS), flush=True) - try: - yield bool(val[3] & termios.ECHO) - finally: - termios.tcsetattr(f.fileno(), termios.TCSAFLUSH, val) - print(end=RESTORE_PRIVATE_MODE_VALUES, flush=True) - - -def dcs_to_kitty(payload: Union[bytes, str], type: str = 'ssh') -> bytes: - if isinstance(payload, str): - payload = payload.encode('utf-8') - payload = standard_b64encode(payload) - ans = b'\033P@kitty-' + type.encode('ascii') + b'|' + payload - tmux = running_in_tmux() - if tmux: - cp = subprocess.run([tmux, 'set', '-p', 'allow-passthrough', 'on']) - if cp.returncode != 0: - raise SystemExit(cp.returncode) - ans = b'\033Ptmux;\033' + ans + b'\033\033\\\033\\' - else: - ans += b'\033\\' - return ans - - -@run_once -def ssh_version() -> Tuple[int, int]: - o = subprocess.check_output([ssh_exe(), '-V'], stderr=subprocess.STDOUT).decode() - m = re.match(r'OpenSSH_(\d+).(\d+)', o) - if m is None: - raise ValueError(f'Invalid version string for OpenSSH: {o}') - return int(m.group(1)), int(m.group(2)) - - -@contextmanager -def drain_potential_tty_garbage(p: 'subprocess.Popen[bytes]', data_request: str) -> Iterator[None]: - with open(os.open(os.ctermid(), os.O_CLOEXEC | os.O_RDWR | os.O_NOCTTY), 'wb') as tty: - if data_request: - turn_off_echo(tty.fileno()) - tty.write(dcs_to_kitty(data_request)) - tty.flush() - try: - yield - finally: - # discard queued input data on tty in case data transmission was - # interrupted due to SSH failure, avoids spewing garbage to screen - from uuid import uuid4 - canary = uuid4().hex.encode('ascii') - turn_off_echo(tty.fileno()) - tty.write(dcs_to_kitty(canary + b'\n\r', type='echo')) - tty.flush() - data = b'' - give_up_at = time.monotonic() + 2 - tty_fd = tty.fileno() - while time.monotonic() < give_up_at and canary not in data: - with suppress(KeyboardInterrupt): - rd, wr, err = select([tty_fd], [], [tty_fd], max(0, give_up_at - time.monotonic())) - if err or not rd: - break - q = os.read(tty_fd, io.DEFAULT_BUFFER_SIZE) - if not q: - break - data += q - - -def change_colors(color_scheme: str) -> bool: - if not color_scheme: - return False - from kittens.themes.collection import NoCacheFound, load_themes, text_as_opts - from kittens.themes.main import colors_as_escape_codes - if color_scheme.endswith('.conf'): - conf_file = resolve_abs_or_config_path(color_scheme) - try: - with open(conf_file) as f: - opts = text_as_opts(f.read()) - except FileNotFoundError: - raise SystemExit(f'Failed to find the color conf file: {expandvars(conf_file)}') - else: - try: - themes = load_themes(-1) - except NoCacheFound: - themes = load_themes() - cs = expandvars(color_scheme) - try: - theme = themes[cs] - except KeyError: - raise SystemExit(f'Failed to find the color theme: {cs}') - opts = theme.kitty_opts - raw = colors_as_escape_codes(opts) - print(save_colors(), sep='', end=raw, flush=True) - return True - - -def add_cloned_env(shm_name: str) -> Dict[str, str]: - try: - return cast(Dict[str, str], read_data_from_shared_memory(shm_name)) - except FileNotFoundError: - pass - return {} - - -def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple[str, ...]) -> NoReturn: - cmd = [ssh_exe()] + ssh_args - hostname, remote_args = server_args[0], server_args[1:] - if not remote_args: - cmd.append('-t') - insertion_point = len(cmd) - cmd.append('--') - cmd.append(hostname) - uname = getuser() - if hostname.startswith('ssh://'): - from urllib.parse import urlparse - purl = urlparse(hostname) - hostname_for_match = purl.hostname or hostname[6:].split('/', 1)[0] - uname = purl.username or uname - elif '@' in hostname and hostname[0] != '@': - uname, hostname_for_match = hostname.split('@', 1) - else: - hostname_for_match = hostname - hostname_for_match = hostname_for_match.split('@', 1)[-1].split(':', 1)[0] - overrides: List[str] = [] - literal_env: Dict[str, str] = {} - pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=') - for i, a in enumerate(found_extra_args): - if i % 2 == 1: - aq = pat.sub(r'\1 ', a.lstrip()) - key = aq.split(maxsplit=1)[0] - if key == 'clone_env': - literal_env = add_cloned_env(aq.split(maxsplit=1)[1]) - elif key != 'hostname': - overrides.append(aq) - if overrides: - overrides.insert(0, f'hostname {uname}@{hostname_for_match}') - host_opts = init_config(hostname_for_match, uname, overrides) - if host_opts.share_connections: - cmd[insertion_point:insertion_point] = connection_sharing_args(int(os.environ['KITTY_PID'])) - use_kitty_askpass = host_opts.askpass == 'native' or (host_opts.askpass == 'unless-set' and 'SSH_ASKPASS' not in os.environ) - need_to_request_data = True - if use_kitty_askpass: - sentinel = os.path.join(cache_dir(), 'openssh-is-new-enough-for-askpass') - sentinel_exists = os.path.exists(sentinel) - if sentinel_exists or ssh_version() >= (8, 4): - if not sentinel_exists: - open(sentinel, 'w').close() - # SSH_ASKPASS_REQUIRE was introduced in 8.4 release on 2020-09-27 - need_to_request_data = False - os.environ['SSH_ASKPASS_REQUIRE'] = 'force' - os.environ['SSH_ASKPASS'] = os.path.join(shell_integration_dir, 'ssh', 'askpass.py') - if need_to_request_data and host_opts.share_connections: - cp = subprocess.run(cmd[:1] + ['-O', 'check'] + cmd[1:], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - if cp.returncode == 0: - # we will use the master connection so SSH does not need to use the tty - need_to_request_data = False - with restore_terminal_state() as echo_on: - rcmd, replacements, shm_name = get_remote_command( - remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data, literal_env=literal_env) - cmd += rcmd - colors_changed = change_colors(host_opts.color_scheme) - try: - p = subprocess.Popen(cmd) - except FileNotFoundError: - raise SystemExit('Could not find the ssh executable, is it in your PATH?') - else: - rq = '' if need_to_request_data else 'id={REQUEST_ID}:pwfile={PASSWORD_FILENAME}:pw={DATA_PASSWORD}'.format(**replacements) - with drain_potential_tty_garbage(p, rq): - raise SystemExit(p.wait()) - finally: - if colors_changed: - print(end=restore_colors(), flush=True) - - -def main(args: List[str]) -> None: - args = args[1:] - if args and args[0] == 'use-python': - args = args[1:] # backwards compat from when we had a python implementation - try: - ssh_args, server_args, passthrough, found_extra_args = parse_ssh_args(args, extra_args=('--kitten',)) - except InvalidSSHArgs as e: - e.system_exit() - if passthrough: - if found_extra_args: - raise SystemExit(f'The SSH kitten cannot work with the options: {", ".join(passthrough_args)}') - os.execlp(ssh_exe(), 'ssh', *args) - - if not os.environ.get('KITTY_WINDOW_ID') or not os.environ.get('KITTY_PID'): - raise SystemExit('The SSH kitten is meant to run inside a kitty window') - if not sys.stdin.isatty(): - raise SystemExit('The SSH kitten is meant for interactive use only, STDIN must be a terminal') - try: - run_ssh(ssh_args, server_args, found_extra_args) - except KeyboardInterrupt: - sys.excepthook = lambda *a: None - raise - +def option_text() -> str: + return ''' +--glob +type=bool-set +Interpret file arguments as glob patterns. Globbing is based on +Based on standard wildcards with the addition that ``/**/`` matches any number of directories. +See the :link:`detailed syntax `. + + +--dest +The destination on the remote host to copy to. Relative paths are resolved +relative to HOME on the remote host. When this option is not specified, the +local file path is used as the remote destination (with the HOME directory +getting automatically replaced by the remote HOME). Note that environment +variables and ~ are not expanded. + + +--exclude +type=list +A glob pattern. Files with names matching this pattern are excluded from being +transferred. Useful when adding directories. Can +be specified multiple times, if any of the patterns match the file will be +excluded. If the pattern includes a :code:`/` then it will match against the full +path, not just the filename. In such patterns you can use :code:`/**/` to match zero +or more directories. For example, to exclude a directory and everything under it use +:code:`**/directory_name`. +See the :link:`detailed syntax ` for +how wildcards match. + + +--symlink-strategy +default=preserve +choices=preserve,resolve,keep-path +Control what happens if the specified path is a symlink. The default is to preserve +the symlink, re-creating it on the remote machine. Setting this to :code:`resolve` +will cause the symlink to be followed and its target used as the file/directory to copy. +The value of :code:`keep-path` is the same as :code:`resolve` except that the remote +file path is derived from the symlink's path instead of the path of the symlink's target. +Note that this option does not apply to symlinks encountered while recursively copying directories, +those are always preserved. +''' + + + +definition = Definition( + '!kittens.ssh', +) + +agr = definition.add_group +egr = definition.end_group +opt = definition.add_option + +agr('bootstrap', 'Host bootstrap configuration') # {{{ + +opt('hostname', '*', long_text=''' +The hostname that the following options apply to. A glob pattern to match +multiple hosts can be used. Multiple hostnames can also be specified, separated +by spaces. The hostname can include an optional username in the form +:code:`user@host`. When not specified options apply to all hosts, until the +first hostname specification is found. Note that matching of hostname is done +against the name you specify on the command line to connect to the remote host. +If you wish to include the same basic configuration for many different hosts, +you can do so with the :ref:`include ` directive. +''') + +opt('interpreter', 'sh', long_text=''' +The interpreter to use on the remote host. Must be either a POSIX complaint +shell or a :program:`python` executable. If the default :program:`sh` is not +available or broken, using an alternate interpreter can be useful. +''') + +opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text=''' +The location on the remote host where the files needed for this kitten are +installed. Relative paths are resolved with respect to :code:`$HOME`. +''') + +opt('+copy', '', add_to_default=False, ctype='CopyInstruction', long_text=f''' +{copy_message} For example:: + + copy .vimrc .zshrc .config/some-dir + +Use :code:`--dest` to copy a file to some other destination on the remote host:: + + copy --dest some-other-name some-file + +Glob patterns can be specified to copy multiple files, with :code:`--glob`:: + + copy --glob images/*.png + +Files can be excluded when copying with :code:`--exclude`:: + + copy --glob --exclude *.jpg --exclude *.bmp images/* + +Files whose remote name matches the exclude pattern will not be copied. +For more details, see :ref:`ssh_copy_command`. +''') +egr() # }}} + +agr('shell', 'Login shell environment') # {{{ + +opt('shell_integration', 'inherited', long_text=''' +Control the shell integration on the remote host. See :ref:`shell_integration` +for details on how this setting works. The special value :code:`inherited` means +use the setting from :file:`kitty.conf`. This setting is useful for overriding +integration on a per-host basis. +''') + +opt('login_shell', '', long_text=''' +The login shell to execute on the remote host. By default, the remote user +account's login shell is used. +''') + +opt('+env', '', add_to_default=False, ctype='EnvInstruction', long_text=''' +Specify the environment variables to be set on the remote host. Using the +name with an equal sign (e.g. :code:`env VAR=`) will set it to the empty string. +Specifying only the name (e.g. :code:`env VAR`) will remove the variable from +the remote shell environment. The special value :code:`_kitty_copy_env_var_` +will cause the value of the variable to be copied from the local environment. +The definitions are processed alphabetically. Note that environment variables +are expanded recursively, for example:: + + env VAR1=a + env VAR2=${HOME}/${VAR1}/b + +The value of :code:`VAR2` will be :code:`/a/b`. +''') + +opt('cwd', '', long_text=''' +The working directory on the remote host to change to. Environment variables in +this value are expanded. The default is empty so no changing is done, which +usually means the HOME directory is used. +''') + +opt('color_scheme', '', long_text=''' +Specify a color scheme to use when connecting to the remote host. If this option +ends with :code:`.conf`, it is assumed to be the name of a config file to load +from the kitty config directory, otherwise it is assumed to be the name of a +color theme to load via the :doc:`themes kitten `. Note that +only colors applying to the text/background are changed, other config settings +in the .conf files/themes are ignored. +''') + +opt('remote_kitty', 'if-needed', choices=('if-needed', 'no', 'yes'), long_text=''' +Make :program:`kitty` available on the remote host. Useful to run kittens such +as the :doc:`icat kitten ` to display images or the +:doc:`transfer file kitten ` to transfer files. Only works if +the remote host has an architecture for which :link:`pre-compiled kitty binaries +` are available. Note that kitty +is not actually copied to the remote host, instead a small bootstrap script is +copied which will download and run kitty when kitty is first executed on the +remote host. A value of :code:`if-needed` means kitty is installed only if not +already present in the system-wide PATH. A value of :code:`yes` means that kitty +is installed even if already present, and the installed kitty takes precedence. +Finally, :code:`no` means no kitty is installed on the remote host. The +installed kitty can be updated by running: :code:`kitty +update-kitty` on the +remote host. +''') +egr() # }}} + +agr('ssh', 'SSH configuration') # {{{ + +opt('share_connections', 'yes', option_type='to_bool', long_text=''' +Within a single kitty instance, all connections to a particular server can be +shared. This reduces startup latency for subsequent connections and means that +you have to enter the password only once. Under the hood, it uses SSH +ControlMasters and these are automatically cleaned up by kitty when it quits. +You can map a shortcut to :ac:`close_shared_ssh_connections` to disconnect all +active shared connections. +''') + +opt('askpass', 'unless-set', choices=('unless-set', 'ssh', 'native'), long_text=''' +Control the program SSH uses to ask for passwords or confirmation of host keys +etc. The default is to use kitty's native :program:`askpass`, unless the +:envvar:`SSH_ASKPASS` environment variable is set. Set this option to +:code:`ssh` to not interfere with the normal ssh askpass mechanism at all, which +typically means that ssh will prompt at the terminal. Set it to :code:`native` +to always use kitty's native, built-in askpass implementation. Note that not +using the kitty askpass implementation means that SSH might need to use the +terminal before the connection is established, so the kitten cannot use the +terminal to send data without an extra roundtrip, adding to initial connection +latency. +''') +egr() # }}} + + +def main(args: List[str]) -> Optional[str]: + raise SystemExit('This should be run as kitten unicode_input') if __name__ == '__main__': - main(sys.argv) + main([]) elif __name__ == '__wrapper_of__': - cd = sys.cli_docs # type: ignore + cd = getattr(sys, 'cli_docs') cd['wrapper_of'] = 'ssh' elif __name__ == '__conf__': - from .options.definition import definition - sys.options_definition = definition # type: ignore + setattr(sys, 'options_definition', definition) +elif __name__ == '__extra_cli_parsers__': + setattr(sys, 'extra_cli_parsers', {'copy': option_text()}) diff --git a/kittens/ssh/options/__init__.py b/kittens/ssh/options/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/kittens/ssh/options/definition.py b/kittens/ssh/options/definition.py deleted file mode 100644 index 14fe00ca7..000000000 --- a/kittens/ssh/options/definition.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# vim:fileencoding=utf-8 -# License: GPLv3 Copyright: 2021, Kovid Goyal - -# After editing this file run ./gen-config.py to apply the changes - -from kitty.conf.types import Definition - -copy_message = '''\ -Copy files and directories from local to remote hosts. The specified files are -assumed to be relative to the HOME directory and copied to the HOME on the -remote. Directories are copied recursively. If absolute paths are used, they are -copied as is.''' - -definition = Definition( - 'kittens.ssh', -) - -agr = definition.add_group -egr = definition.end_group -opt = definition.add_option - -agr('bootstrap', 'Host bootstrap configuration') # {{{ - -opt('hostname', '*', option_type='hostname', long_text=''' -The hostname that the following options apply to. A glob pattern to match -multiple hosts can be used. Multiple hostnames can also be specified, separated -by spaces. The hostname can include an optional username in the form -:code:`user@host`. When not specified options apply to all hosts, until the -first hostname specification is found. Note that matching of hostname is done -against the name you specify on the command line to connect to the remote host. -If you wish to include the same basic configuration for many different hosts, -you can do so with the :ref:`include ` directive. -''') - -opt('interpreter', 'sh', long_text=''' -The interpreter to use on the remote host. Must be either a POSIX complaint -shell or a :program:`python` executable. If the default :program:`sh` is not -available or broken, using an alternate interpreter can be useful. -''') - -opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text=''' -The location on the remote host where the files needed for this kitten are -installed. Relative paths are resolved with respect to :code:`$HOME`. -''') - -opt('+copy', '', option_type='copy', add_to_default=False, long_text=f''' -{copy_message} For example:: - - copy .vimrc .zshrc .config/some-dir - -Use :code:`--dest` to copy a file to some other destination on the remote host:: - - copy --dest some-other-name some-file - -Glob patterns can be specified to copy multiple files, with :code:`--glob`:: - - copy --glob images/*.png - -Files can be excluded when copying with :code:`--exclude`:: - - copy --glob --exclude *.jpg --exclude *.bmp images/* - -Files whose remote name matches the exclude pattern will not be copied. -For more details, see :ref:`ssh_copy_command`. -''') -egr() # }}} - -agr('shell', 'Login shell environment') # {{{ - -opt('shell_integration', 'inherited', long_text=''' -Control the shell integration on the remote host. See :ref:`shell_integration` -for details on how this setting works. The special value :code:`inherited` means -use the setting from :file:`kitty.conf`. This setting is useful for overriding -integration on a per-host basis. -''') - -opt('login_shell', '', long_text=''' -The login shell to execute on the remote host. By default, the remote user -account's login shell is used. -''') - -opt('+env', '', option_type='env', add_to_default=False, long_text=''' -Specify the environment variables to be set on the remote host. Using the -name with an equal sign (e.g. :code:`env VAR=`) will set it to the empty string. -Specifying only the name (e.g. :code:`env VAR`) will remove the variable from -the remote shell environment. The special value :code:`_kitty_copy_env_var_` -will cause the value of the variable to be copied from the local environment. -The definitions are processed alphabetically. Note that environment variables -are expanded recursively, for example:: - - env VAR1=a - env VAR2=${HOME}/${VAR1}/b - -The value of :code:`VAR2` will be :code:`/a/b`. -''') - -opt('cwd', '', long_text=''' -The working directory on the remote host to change to. Environment variables in -this value are expanded. The default is empty so no changing is done, which -usually means the HOME directory is used. -''') - -opt('color_scheme', '', long_text=''' -Specify a color scheme to use when connecting to the remote host. If this option -ends with :code:`.conf`, it is assumed to be the name of a config file to load -from the kitty config directory, otherwise it is assumed to be the name of a -color theme to load via the :doc:`themes kitten `. Note that -only colors applying to the text/background are changed, other config settings -in the .conf files/themes are ignored. -''') - -opt('remote_kitty', 'if-needed', choices=('if-needed', 'no', 'yes'), long_text=''' -Make :program:`kitty` available on the remote host. Useful to run kittens such -as the :doc:`icat kitten ` to display images or the -:doc:`transfer file kitten ` to transfer files. Only works if -the remote host has an architecture for which :link:`pre-compiled kitty binaries -` are available. Note that kitty -is not actually copied to the remote host, instead a small bootstrap script is -copied which will download and run kitty when kitty is first executed on the -remote host. A value of :code:`if-needed` means kitty is installed only if not -already present in the system-wide PATH. A value of :code:`yes` means that kitty -is installed even if already present, and the installed kitty takes precedence. -Finally, :code:`no` means no kitty is installed on the remote host. The -installed kitty can be updated by running: :code:`kitty +update-kitty` on the -remote host. -''') -egr() # }}} - -agr('ssh', 'SSH configuration') # {{{ - -opt('share_connections', 'yes', option_type='to_bool', long_text=''' -Within a single kitty instance, all connections to a particular server can be -shared. This reduces startup latency for subsequent connections and means that -you have to enter the password only once. Under the hood, it uses SSH -ControlMasters and these are automatically cleaned up by kitty when it quits. -You can map a shortcut to :ac:`close_shared_ssh_connections` to disconnect all -active shared connections. -''') - -opt('askpass', 'unless-set', choices=('unless-set', 'ssh', 'native'), long_text=''' -Control the program SSH uses to ask for passwords or confirmation of host keys -etc. The default is to use kitty's native :program:`askpass`, unless the -:envvar:`SSH_ASKPASS` environment variable is set. Set this option to -:code:`ssh` to not interfere with the normal ssh askpass mechanism at all, which -typically means that ssh will prompt at the terminal. Set it to :code:`native` -to always use kitty's native, built-in askpass implementation. Note that not -using the kitty askpass implementation means that SSH might need to use the -terminal before the connection is established, so the kitten cannot use the -terminal to send data without an extra roundtrip, adding to initial connection -latency. -''') -egr() # }}} diff --git a/kittens/ssh/options/parse.py b/kittens/ssh/options/parse.py deleted file mode 100644 index aa380f007..000000000 --- a/kittens/ssh/options/parse.py +++ /dev/null @@ -1,90 +0,0 @@ -# generated by gen-config.py DO NOT edit - -# isort: skip_file -import typing -from kittens.ssh.options.utils import copy, env, hostname -from kitty.conf.utils import merge_dicts, to_bool - - -class Parser: - - def askpass(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - val = val.lower() - if val not in self.choices_for_askpass: - raise ValueError(f"The value {val} is not a valid choice for askpass") - ans["askpass"] = val - - choices_for_askpass = frozenset(('unless-set', 'ssh', 'native')) - - def color_scheme(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['color_scheme'] = str(val) - - def copy(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - for k, v in copy(val, ans["copy"]): - ans["copy"][k] = v - - def cwd(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['cwd'] = str(val) - - def env(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - for k, v in env(val, ans["env"]): - ans["env"][k] = v - - def hostname(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - hostname(val, ans) - - def interpreter(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['interpreter'] = str(val) - - def login_shell(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['login_shell'] = str(val) - - def remote_dir(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['remote_dir'] = str(val) - - def remote_kitty(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - val = val.lower() - if val not in self.choices_for_remote_kitty: - raise ValueError(f"The value {val} is not a valid choice for remote_kitty") - ans["remote_kitty"] = val - - choices_for_remote_kitty = frozenset(('if-needed', 'no', 'yes')) - - def share_connections(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['share_connections'] = to_bool(val) - - def shell_integration(self, val: str, ans: typing.Dict[str, typing.Any]) -> None: - ans['shell_integration'] = str(val) - - -def create_result_dict() -> typing.Dict[str, typing.Any]: - return { - 'copy': {}, - 'env': {}, - } - - -actions: typing.FrozenSet[str] = frozenset(()) - - -def merge_result_dicts(defaults: typing.Dict[str, typing.Any], vals: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: - ans = {} - for k, v in defaults.items(): - if isinstance(v, dict): - ans[k] = merge_dicts(v, vals.get(k, {})) - elif k in actions: - ans[k] = v + vals.get(k, []) - else: - ans[k] = vals.get(k, v) - return ans - - -parser = Parser() - - -def parse_conf_item(key: str, val: str, ans: typing.Dict[str, typing.Any]) -> bool: - func = getattr(parser, key, None) - if func is not None: - func(val, ans) - return True - return False diff --git a/kittens/ssh/options/types.py b/kittens/ssh/options/types.py deleted file mode 100644 index 90ce49fc5..000000000 --- a/kittens/ssh/options/types.py +++ /dev/null @@ -1,93 +0,0 @@ -# generated by gen-config.py DO NOT edit - -# isort: skip_file -import typing -import kittens.ssh.copy - -if typing.TYPE_CHECKING: - choices_for_askpass = typing.Literal['unless-set', 'ssh', 'native'] - choices_for_remote_kitty = typing.Literal['if-needed', 'no', 'yes'] -else: - choices_for_askpass = str - choices_for_remote_kitty = str - -option_names = ( # {{{ - 'askpass', - 'color_scheme', - 'copy', - 'cwd', - 'env', - 'hostname', - 'interpreter', - 'login_shell', - 'remote_dir', - 'remote_kitty', - 'share_connections', - 'shell_integration') # }}} - - -class Options: - askpass: choices_for_askpass = 'unless-set' - color_scheme: str = '' - cwd: str = '' - hostname: str = '*' - interpreter: str = 'sh' - login_shell: str = '' - remote_dir: str = '.local/share/kitty-ssh-kitten' - remote_kitty: choices_for_remote_kitty = 'if-needed' - share_connections: bool = True - shell_integration: str = 'inherited' - copy: typing.Dict[str, kittens.ssh.copy.CopyInstruction] = {} - env: typing.Dict[str, str] = {} - config_paths: typing.Tuple[str, ...] = () - config_overrides: typing.Tuple[str, ...] = () - - def __init__(self, options_dict: typing.Optional[typing.Dict[str, typing.Any]] = None) -> None: - if options_dict is not None: - null = object() - for key in option_names: - val = options_dict.get(key, null) - if val is not null: - setattr(self, key, val) - - @property - def _fields(self) -> typing.Tuple[str, ...]: - return option_names - - def __iter__(self) -> typing.Iterator[str]: - return iter(self._fields) - - def __len__(self) -> int: - return len(self._fields) - - def _copy_of_val(self, name: str) -> typing.Any: - ans = getattr(self, name) - if isinstance(ans, dict): - ans = ans.copy() - elif isinstance(ans, list): - ans = ans[:] - return ans - - def _asdict(self) -> typing.Dict[str, typing.Any]: - return {k: self._copy_of_val(k) for k in self} - - def _replace(self, **kw: typing.Any) -> "Options": - ans = Options() - for name in self: - setattr(ans, name, self._copy_of_val(name)) - for name, val in kw.items(): - setattr(ans, name, val) - return ans - - def __getitem__(self, key: typing.Union[int, str]) -> typing.Any: - k = option_names[key] if isinstance(key, int) else key - try: - return getattr(self, k) - except AttributeError: - pass - raise KeyError(f"No option named: {k}") - - -defaults = Options() -defaults.copy = {} -defaults.env = {} diff --git a/kittens/ssh/options/utils.py b/kittens/ssh/options/utils.py deleted file mode 100644 index 135bb2704..000000000 --- a/kittens/ssh/options/utils.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# License: GPLv3 Copyright: 2022, Kovid Goyal - -from typing import Any, Dict, Iterable, Optional, Tuple - -from ..copy import CopyInstruction, parse_copy_instructions - -DELETE_ENV_VAR = '_delete_this_env_var_' - - -def env(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, str]]: - val = val.strip() - if val: - if '=' in val: - key, v = val.split('=', 1) - key, v = key.strip(), v.strip() - if key: - yield key, v - else: - yield val, DELETE_ENV_VAR - - -def copy(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]: - yield from parse_copy_instructions(val, current_val) - - -def init_results_dict(ans: Dict[str, Any]) -> Dict[str, Any]: - ans['hostname'] = '*' - ans['per_host_dicts'] = {} - return ans - - -def get_per_hosts_dict(results_dict: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: - ans: Dict[str, Dict[str, Any]] = results_dict.get('per_host_dicts', {}).copy() - h = results_dict['hostname'] - hd = {k: v for k, v in results_dict.items() if k != 'per_host_dicts'} - ans[h] = hd - return ans - - -first_seen_positions: Dict[str, int] = {} - - -def hostname(val: str, dict_with_parse_results: Optional[Dict[str, Any]] = None) -> str: - if dict_with_parse_results is not None: - ch = dict_with_parse_results['hostname'] - if val != ch: - from .parse import create_result_dict - phd = get_per_hosts_dict(dict_with_parse_results) - dict_with_parse_results.clear() - dict_with_parse_results.update(phd.pop(val, create_result_dict())) - dict_with_parse_results['per_host_dicts'] = phd - dict_with_parse_results['hostname'] = val - if val not in first_seen_positions: - first_seen_positions[val] = len(first_seen_positions) - return val diff --git a/kittens/ssh/utils.py b/kittens/ssh/utils.py index 540ec5d09..66b58242d 100644 --- a/kittens/ssh/utils.py +++ b/kittens/ssh/utils.py @@ -4,9 +4,12 @@ import os import subprocess -from typing import Any, Dict, List, Sequence, Set, Tuple +import traceback +from contextlib import suppress +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple from kitty.types import run_once +from kitty.utils import SSHConnectionData @run_once @@ -94,6 +97,57 @@ def create_shared_memory(data: Any, prefix: str) -> str: return shm.name +def read_data_from_shared_memory(shm_name: str) -> Any: + import json + import stat + + from kitty.shm import SharedMemory + with SharedMemory(shm_name, readonly=True) as shm: + shm.unlink() + if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid(): + raise ValueError('Incorrect owner on pwfile') + mode = stat.S_IMODE(shm.stats.st_mode) + if mode != stat.S_IREAD | stat.S_IWRITE: + raise ValueError('Incorrect permissions on pwfile') + return json.loads(shm.read_data_with_size()) + + +def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: + from base64 import standard_b64decode + yield b'\nKITTY_DATA_START\n' # to discard leading data + try: + msg = standard_b64decode(msg).decode('utf-8') + md = dict(x.split('=', 1) for x in msg.split(':')) + pw = md['pw'] + pwfilename = md['pwfile'] + rq_id = md['id'] + except Exception: + traceback.print_exc() + yield b'invalid ssh data request message\n' + else: + try: + env_data = read_data_from_shared_memory(pwfilename) + if pw != env_data['pw']: + raise ValueError('Incorrect password') + if rq_id != request_id: + raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window') + except Exception as e: + traceback.print_exc() + yield f'{e}\n'.encode('utf-8') + else: + yield b'OK\n' + encoded_data = memoryview(env_data['tarfile'].encode('ascii')) + # macOS has a 255 byte limit on its input queue as per man stty. + # Not clear if that applies to canonical mode input as well, but + # better to be safe. + line_sz = 254 + while encoded_data: + yield encoded_data[:line_sz] + yield b'\n' + encoded_data = encoded_data[line_sz:] + yield b'KITTY_DATA_END\n' + + def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None: patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv) @@ -183,3 +237,86 @@ def set_server_args_in_cmdline( ans.insert(i, '-t') break argv[:] = ans + server_args + + +def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]: + boolean_ssh_args, other_ssh_args = get_ssh_cli() + port: Optional[int] = None + expecting_port = expecting_identity = False + expecting_option_val = False + expecting_hostname = False + expecting_extra_val = '' + host_name = identity_file = found_ssh = '' + found_extra_args: List[Tuple[str, str]] = [] + + for i, arg in enumerate(args): + if not found_ssh: + if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'): + found_ssh = arg + continue + if expecting_hostname: + host_name = arg + continue + if arg.startswith('-') and not expecting_option_val: + if arg in boolean_ssh_args: + continue + if arg == '--': + expecting_hostname = True + if arg.startswith('-p'): + if arg[2:].isdigit(): + with suppress(Exception): + port = int(arg[2:]) + continue + elif arg == '-p': + expecting_port = True + elif arg.startswith('-i'): + if arg == '-i': + expecting_identity = True + else: + identity_file = arg[2:] + continue + if arg.startswith('--') and extra_args: + matching_ex = is_extra_arg(arg, extra_args) + if matching_ex: + if '=' in arg: + exval = arg.partition('=')[-1] + found_extra_args.append((matching_ex, exval)) + continue + expecting_extra_val = matching_ex + + expecting_option_val = True + continue + + if expecting_option_val: + if expecting_port: + with suppress(Exception): + port = int(arg) + expecting_port = False + elif expecting_identity: + identity_file = arg + elif expecting_extra_val: + found_extra_args.append((expecting_extra_val, arg)) + expecting_extra_val = '' + expecting_option_val = False + continue + + if not host_name: + host_name = arg + if not host_name: + return None + if host_name.startswith('ssh://'): + from urllib.parse import urlparse + purl = urlparse(host_name) + if purl.hostname: + host_name = purl.hostname + if purl.username: + host_name = f'{purl.username}@{host_name}' + if port is None and purl.port: + port = purl.port + if identity_file: + if not os.path.isabs(identity_file): + identity_file = os.path.expanduser(identity_file) + if not os.path.isabs(identity_file): + identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file)) + + return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args)) diff --git a/kitty/cli_stub.py b/kitty/cli_stub.py index f9dfd88ba..bc1d01f13 100644 --- a/kitty/cli_stub.py +++ b/kitty/cli_stub.py @@ -13,7 +13,7 @@ LaunchCLIOptions = AskCLIOptions = ClipboardCLIOptions = DiffCLIOptions = CLIOpt HintsCLIOptions = IcatCLIOptions = PanelCLIOptions = ResizeCLIOptions = CLIOptions ErrorCLIOptions = UnicodeCLIOptions = RCOptions = RemoteFileCLIOptions = CLIOptions QueryTerminalCLIOptions = BroadcastCLIOptions = ShowKeyCLIOptions = CLIOptions -ThemesCLIOptions = TransferCLIOptions = CopyCLIOptions = CLIOptions +ThemesCLIOptions = TransferCLIOptions = CLIOptions def generate_stub() -> None: @@ -78,9 +78,6 @@ def generate_stub() -> None: from kittens.transfer.main import option_text as OPTIONS do(OPTIONS(), 'TransferCLIOptions') - from kittens.ssh.copy import option_text as OPTIONS - do(OPTIONS(), 'CopyCLIOptions') - from kitty.rc.base import all_command_names, command_for_name for cmd_name in all_command_names(): cmd = command_for_name(cmd_name) diff --git a/kitty/conf/generate.py b/kitty/conf/generate.py index 1075663da..044b8fd09 100644 --- a/kitty/conf/generate.py +++ b/kitty/conf/generate.py @@ -9,7 +9,7 @@ import re import textwrap from typing import Any, Callable, Dict, Iterator, List, Set, Tuple, Union, get_type_hints -from kitty.conf.types import Definition, MultiOption, Option, unset +from kitty.conf.types import Definition, MultiOption, Option, ParserFuncType, unset from kitty.types import _T @@ -442,6 +442,121 @@ def write_output(loc: str, defn: Definition) -> None: f.write(f'{c}\n') +def go_type_data(parser_func: ParserFuncType, ctype: str) -> Tuple[str, str]: + if ctype: + return f'*{ctype}', f'Parse{ctype}(val)' + p = parser_func.__name__ + if p == 'int': + return 'int64', 'strconv.ParseInt(val, 10, 64)' + if p == 'str': + return 'string', 'val, nil' + if p == 'float': + return 'float64', 'strconv.ParseFloat(val, 10, 64)' + if p == 'to_bool': + return 'bool', 'config.StringToBool(val), nil' + th = get_type_hints(parser_func) + rettype = th['return'] + return {int: 'int64', str: 'string', float: 'float64'}[rettype], f'{p}(val)' + + +def gen_go_code(defn: Definition) -> str: + lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"', + 'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi'] + a = lines.append + choices = {} + go_types = {} + go_parsers = {} + defaults = {} + multiopts = {''} + for option in sorted(defn.iter_all_options(), key=lambda a: natural_keys(a.name)): + name = option.name.capitalize() + if isinstance(option, MultiOption): + go_types[name], go_parsers[name] = go_type_data(option.parser_func, option.ctype) + multiopts.add(name) + else: + defaults[name] = option.parser_func(option.defval_as_string) + if option.choices: + choices[name] = option.choices + go_types[name] = f'{name}_Choice_Type' + go_parsers[name] = f'Parse_{name}(val)' + continue + go_types[name], go_parsers[name] = go_type_data(option.parser_func, option.ctype) + + for oname in choices: + a(f'type {go_types[oname]} int') + a('type Config struct {') + for name, gotype in go_types.items(): + if name in multiopts: + a(f'{name} []{gotype}') + else: + a(f'{name} {gotype}') + a('}') + + def cval(x: str) -> str: + return x.replace('-', '_') + + a('func NewConfig() *Config {') + a('return &Config{') + for name, pname in go_parsers.items(): + if name in multiopts: + continue + d = defaults[name] + if not d: + continue + if isinstance(d, str): + dval = f'{name}_{cval(d)}' if name in choices else f'`{d}`' + elif isinstance(d, bool): + dval = repr(d).lower() + else: + dval = repr(d) + a(f'{name}: {dval},') + a('}''}') + + for oname, choice_vals in choices.items(): + a('const (') + for i, c in enumerate(choice_vals): + c = cval(c) + if i == 0: + a(f'{oname}_{c} {oname}_Choice_Type = iota') + else: + a(f'{oname}_{c}') + a(')') + a(f'func (x {oname}_Choice_Type) String() string'' {') + a('switch x {') + a('default: return ""') + for c in choice_vals: + a(f'case {oname}_{cval(c)}: return "{c}"') + a('}''}') + a(f'func {go_parsers[oname].split("(")[0]}(val string) (ans {go_types[oname]}, err error) ''{') + a('switch val {') + for c in choice_vals: + a(f'case "{c}": return {oname}_{cval(c)}, nil') + vals = ', '.join(choice_vals) + a(f'default: return ans, fmt.Errorf("%#v is not a valid value for %s. Valid values are: %s", val, "{c}", "{vals}")') + a('}''}') + + a('func (c *Config) Parse(key, val string) (err error) {') + a('switch key {') + a('default: return fmt.Errorf("Unknown configuration key: %#v", key)') + for oname, pname in go_parsers.items(): + ol = oname.lower() + is_multiple = oname in multiopts + a(f'case "{ol}":') + if is_multiple: + a(f'var temp_val []{go_types[oname]}') + else: + a(f'var temp_val {go_types[oname]}') + a(f'temp_val, err = {pname}') + a(f'if err != nil {{ return fmt.Errorf("Failed to parse {ol} = %#v with error: %w", val, err) }}') + if is_multiple: + a(f'c.{oname} = append(c.{oname}, temp_val...)') + else: + a(f'c.{oname} = temp_val') + a('}') + a('return}') + return '\n'.join(lines) + + def main() -> None: # To use run it as: # kitty +runpy 'from kitty.conf.generate import main; main()' /path/to/kitten/file.py diff --git a/kitty/options/utils.py b/kitty/options/utils.py index 50adaf0cc..2254f08b9 100644 --- a/kitty/options/utils.py +++ b/kitty/options/utils.py @@ -854,12 +854,14 @@ def store_multiple(val: str, current_val: Container[str]) -> Iterable[Tuple[str, yield val, val +allowed_shell_integration_values = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'}) + + def shell_integration(x: str) -> FrozenSet[str]: - s = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'}) q = frozenset(x.lower().split()) - if not q.issubset(s): - log_error(f'Invalid shell integration options: {q - s}, ignoring') - return q & s or frozenset({'invalid'}) + if not q.issubset(allowed_shell_integration_values): + log_error(f'Invalid shell integration options: {q - allowed_shell_integration_values}, ignoring') + return q & allowed_shell_integration_values or frozenset({'invalid'}) return q diff --git a/kitty/window.py b/kitty/window.py index 76e4eaed6..e25531472 100644 --- a/kitty/window.py +++ b/kitty/window.py @@ -959,7 +959,7 @@ class Window: def handle_remote_file(self, netloc: str, remote_path: str) -> None: from kittens.remote_file.main import is_ssh_kitten_sentinel - from kittens.ssh.main import get_connection_data + from kittens.ssh.utils import get_connection_data from .utils import SSHConnectionData args = self.ssh_kitten_cmdline() @@ -1156,7 +1156,7 @@ class Window: self.write_to_child(data) def handle_remote_ssh(self, msg: str) -> None: - from kittens.ssh.main import get_ssh_data + from kittens.ssh.utils import get_ssh_data for line in get_ssh_data(msg, f'{os.getpid()}-{self.id}'): self.write_to_child(line) diff --git a/kitty_tests/__init__.py b/kitty_tests/__init__.py index 5737cad56..017c796aa 100644 --- a/kitty_tests/__init__.py +++ b/kitty_tests/__init__.py @@ -112,7 +112,7 @@ class Callbacks: self.current_clone_data += rest def handle_remote_ssh(self, msg): - from kittens.ssh.main import get_ssh_data + from kittens.ssh.utils import get_ssh_data if self.pty: for line in get_ssh_data(msg, "testing"): self.pty.write_to_child(line) diff --git a/kitty_tests/check_build.py b/kitty_tests/check_build.py index d2463e6a3..89745f3d4 100644 --- a/kitty_tests/check_build.py +++ b/kitty_tests/check_build.py @@ -67,7 +67,7 @@ class TestBuild(BaseTest): q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH return mode & q == q - for x in ('kitty', 'kitten', 'askpass.py'): + for x in ('kitty', 'kitten'): x = os.path.join(shell_integration_dir, 'ssh', x) self.assertTrue(is_executable(x), f'{x} is not executable') if getattr(sys, 'frozen', False): diff --git a/kitty_tests/shm.py b/kitty_tests/shm.py new file mode 100644 index 000000000..b2cc70406 --- /dev/null +++ b/kitty_tests/shm.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# License: GPLv3 Copyright: 2023, Kovid Goyal + + +import os +import subprocess + +from kitty.constants import kitten_exe +from kitty.fast_data_types import shm_unlink +from kitty.shm import SharedMemory + +from . import BaseTest + + +class SHMTest(BaseTest): + + def test_shm_with_kitten(self): + data = os.urandom(333) + with SharedMemory(size=363) as shm: + shm.write_data_with_size(data) + cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'read', shm.name], stdout=subprocess.PIPE) + self.assertEqual(cp.returncode, 0) + self.assertEqual(cp.stdout, data) + self.assertRaises(FileNotFoundError, shm_unlink, shm.name) + cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'write'], input=data, stdout=subprocess.PIPE) + self.assertEqual(cp.returncode, 0) + name = cp.stdout.decode().strip() + with SharedMemory(name=name, unlink_on_exit=True) as shm: + q = shm.read_data_with_size() + self.assertEqual(data, q) diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index a6b9d6046..b03668d1e 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -3,18 +3,16 @@ import glob +import json import os import shutil +import subprocess import tempfile from contextlib import suppress from functools import lru_cache -from kittens.ssh.config import load_config -from kittens.ssh.main import bootstrap_script, get_connection_data, wrap_bootstrap_script -from kittens.ssh.options.types import Options as SSHOptions -from kittens.ssh.options.utils import DELETE_ENV_VAR -from kittens.transfer.utils import set_paths -from kitty.constants import is_macos, runtime_dir +from kittens.ssh.utils import get_connection_data +from kitty.constants import is_macos, kitten_exe, runtime_dir from kitty.fast_data_types import CURSOR_BEAM, shm_unlink from kitty.utils import SSHConnectionData @@ -58,32 +56,6 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two'))) self.assertTrue(runtime_dir()) - def test_ssh_config_parsing(self): - def parse(conf, hostname='unmatched_host', username=''): - return load_config(overrides=conf.splitlines(), hostname=hostname, username=username) - - self.ae(parse('').env, {}) - self.ae(parse('env a=b').env, {'a': 'b'}) - conf = 'env a=b\nhostname 2\nenv a=c\nenv b=b' - self.ae(parse(conf).env, {'a': 'b'}) - self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'}) - self.ae(parse('env a=').env, {'a': ''}) - self.ae(parse('env a').env, {'a': '_delete_this_env_var_'}) - conf = 'env a=b\nhostname test@2\nenv a=c\nenv b=b' - self.ae(parse(conf).env, {'a': 'b'}) - self.ae(parse(conf, '2').env, {'a': 'b'}) - self.ae(parse(conf, '2', 'test').env, {'a': 'c', 'b': 'b'}) - conf = 'env a=b\nhostname 1 2\nenv a=c\nenv b=b' - self.ae(parse(conf).env, {'a': 'b'}) - self.ae(parse(conf, '1').env, {'a': 'c', 'b': 'b'}) - self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'}) - - def test_ssh_bootstrap_sh_cmd_limit(self): - # dropbear has a 9000 bytes maximum command length limit - sh_script, _, _ = bootstrap_script(SSHOptions({'interpreter': 'sh'}), script_type='sh', remote_args=[], request_id='123-123') - rcmd = wrap_bootstrap_script(sh_script, 'sh') - self.assertLessEqual(sum(len(x) for x in rcmd), 9000) - @property @lru_cache() def all_possible_sh(self): @@ -98,11 +70,13 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) f.write(simple_data) for sh in self.all_possible_sh: - with self.subTest(sh=sh), tempfile.TemporaryDirectory() as remote_home, tempfile.TemporaryDirectory() as local_home, set_paths(home=local_home): + with self.subTest(sh=sh), tempfile.TemporaryDirectory() as remote_home, tempfile.TemporaryDirectory() as local_home: tuple(map(touch, 'simple-file g.1 g.2'.split())) os.makedirs(f'{local_home}/d1/d2/d3') touch('d1/d2/x') touch('d1/d2/w.exclude') + os.mkdir(f'{local_home}/d1/r') + touch('d1/r/noooo') os.symlink('d2/x', f'{local_home}/d1/y') os.symlink('simple-file', f'{local_home}/s1') os.symlink('simple-file', f'{local_home}/s2') @@ -110,15 +84,13 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77) conf = '''\ copy simple-file copy s1 -copy --symlink-strategy=keep-name s2 +copy --symlink-strategy=keep-path s2 copy --dest=a/sfa simple-file copy --glob g.* -copy --exclude */w.* d1 +copy --exclude **/w.* --exclude **/r d1 ''' - copy = load_config(overrides=filter(None, conf.splitlines())).copy self.check_bootstrap( - sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', - ssh_opts={'copy': copy} + sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf, home=local_home, ) tname = '.terminfo' if os.path.exists('/usr/share/misc/terminfo.cdb'): @@ -148,17 +120,18 @@ copy --exclude */w.* d1 self.ae(len(glob.glob(f'{remote_home}/{tname}/*/xterm-kitty')), 2) def test_ssh_env_vars(self): - tset = '$A-$(echo no)-`echo no2` !Q5 "something\nelse"' + tset = '$A-$(echo no)-`echo no2` !Q5 "something else"' for sh in self.all_possible_sh: with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: os.mkdir(os.path.join(tdir, 'cwd')) + conf = f''' +cwd $HOME/cwd +env A=AAA +env TSET={tset} +env COLORTERM +''' pty = self.check_bootstrap( - sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', - ssh_opts={'cwd': '$HOME/cwd', 'env': { - 'A': 'AAA', - 'TSET': tset, - 'COLORTERM': DELETE_ENV_VAR, - }} + sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf ) pty.wait_till(lambda: 'TSET={}'.format(tset.replace('$A', 'AAA')) in pty.screen_contents()) self.assertNotIn('COLORTERM', pty.screen_contents()) @@ -240,34 +213,34 @@ copy --exclude */w.* d1 self.assertEqual(pty.screen.cursor.shape, 0) self.assertNotIn(b'\x1b]133;', pty.received_bytes) - def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None, launcher='sh'): - ssh_opts = ssh_opts or {} + def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', conf='', launcher='sh', home=''): if login_shell: - ssh_opts['login_shell'] = login_shell + conf += f'\nlogin_shell {login_shell}' if 'python' in sh: if test_script.startswith('env;'): test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})' test_script = f'print("UNTAR_DONE", flush=True); {test_script}' else: test_script = f'echo "UNTAR_DONE"; {test_script}' - ssh_opts['shell_integration'] = SHELL_INTEGRATION_VALUE or 'disabled' - script, replacements, shm_name = bootstrap_script( - SSHOptions(ssh_opts), script_type='py' if 'python' in sh else 'sh', request_id="testing", test_script=test_script, - request_data=True - ) + conf += '\nshell_integration ' + (SHELL_INTEGRATION_VALUE or 'disabled') + conf += '\ninterpreter ' + sh + env = os.environ.copy() + if home: + env['HOME'] = home + cp = subprocess.run([kitten_exe(), '__pytest__', 'ssh', test_script], env=env, stdout=subprocess.PIPE, input=conf.encode('utf-8')) + self.assertEqual(cp.returncode, 0) + self.rdata = json.loads(cp.stdout) + del cp try: env = basic_shell_env(home_dir) # Avoid generating unneeded completion scripts os.makedirs(os.path.join(home_dir, '.local', 'share', 'fish', 'generated_completions'), exist_ok=True) # prevent newuser-install from running open(os.path.join(home_dir, '.zshrc'), 'w').close() - cmd = wrap_bootstrap_script(script, sh) - pty = self.create_pty([launcher, '-c', ' '.join(cmd)], cwd=home_dir, env=env) + pty = self.create_pty([launcher, '-c', ' '.join(self.rdata['cmd'])], cwd=home_dir, env=env) pty.turn_off_echo() - del cmd if pre_data: pty.write_buf = pre_data.encode('utf-8') - del script def check_untar_or_fail(): q = pty.screen_contents() @@ -284,4 +257,4 @@ copy --exclude */w.* d1 return pty finally: with suppress(FileNotFoundError): - shm_unlink(shm_name) + shm_unlink(self.rdata['shm_name']) diff --git a/pyproject.toml b/pyproject.toml index 4c1505aef..90f20d628 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.mypy] -files = 'kitty,kittens,glfw,*.py,docs/conf.py,shell-integration/ssh/askpass.py' +files = 'kitty,kittens,glfw,*.py,docs/conf.py' no_implicit_optional = true sqlite_cache = true cache_fine_grained = true diff --git a/setup.py b/setup.py index 0a4250221..bdc74876b 100755 --- a/setup.py +++ b/setup.py @@ -1459,7 +1459,7 @@ def package(args: Options, bundle_type: str) -> None: if path.endswith('.so'): return True q = path.split(os.sep)[-2:] - if len(q) == 2 and q[0] == 'ssh' and q[1] in ('askpass.py', 'kitty', 'kitten'): + if len(q) == 2 and q[0] == 'ssh' and q[1] in ('kitty', 'kitten'): return True return False diff --git a/shell-integration/ssh/askpass.py b/shell-integration/ssh/askpass.py deleted file mode 100755 index 868d79b16..000000000 --- a/shell-integration/ssh/askpass.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env -S kitty +launch -# License: GPLv3 Copyright: 2022, Kovid Goyal - -import json -import os -import sys -import time - -from kitty.shm import SharedMemory - -msg = sys.argv[-1] -prompt = os.environ.get('SSH_ASKPASS_PROMPT', '') -is_confirm = prompt == 'confirm' -is_fingerprint_check = '(yes/no/[fingerprint])' in msg -q = { - 'message': msg, - 'type': 'confirm' if is_confirm else 'get_line', - 'is_password': not is_fingerprint_check, -} - -data = json.dumps(q) -with SharedMemory( - size=len(data) + 1 + SharedMemory.num_bytes_for_size, unlink_on_exit=True, prefix=f'askpass-{os.getpid()}-') as shm, \ - open(os.ctermid(), 'wb') as tty: - shm.write(b'\0') - shm.write_data_with_size(data) - shm.flush() - with open(os.ctermid(), 'wb') as f: - f.write(f'\x1bP@kitty-ask|{shm.name}\x1b\\'.encode('ascii')) - f.flush() - while True: - # TODO: Replace sleep() with a mutex and condition variable created in the shared memory - time.sleep(0.05) - shm.seek(0) - if shm.read(1) == b'\x01': - break - response = json.loads(shm.read_data_with_size()) -if is_confirm: - response = 'yes' if response else 'no' -elif is_fingerprint_check: - if response.lower() in ('y', 'yes'): - response = 'yes' - if response.lower() in ('n', 'no'): - response = 'no' -if response: - print(response, flush=True) diff --git a/shell-integration/ssh/kitty b/shell-integration/ssh/kitty index 89abab2c5..8eb0f59be 100755 --- a/shell-integration/ssh/kitty +++ b/shell-integration/ssh/kitty @@ -24,7 +24,7 @@ exec_kitty() { is_wrapped_kitten() { - wrapped_kittens="clipboard icat unicode_input" + wrapped_kittens="clipboard icat unicode_input ssh" [ -n "$1" ] && { case " $wrapped_kittens " in *" $1 "*) printf "%s" "$1" ;; diff --git a/tools/cli/command.go b/tools/cli/command.go index 13f774a1b..07e62ba64 100644 --- a/tools/cli/command.go +++ b/tools/cli/command.go @@ -33,9 +33,11 @@ type Command struct { ArgCompleter CompletionFunc // Stop completion processing at this arg num StopCompletingAtArg int - // Consider all args as non-options args + // Consider all args as non-options args when parsing for completion OnlyArgsAllowed bool - // Specialised arg aprsing + // Pass through all args, useful for wrapper commands + IgnoreAllArgs bool + // Specialised arg parsing ParseArgsForCompletion func(cmd *Command, args []string, completions *Completions) SubCommandGroups []*CommandGroup diff --git a/tools/cli/parse-args.go b/tools/cli/parse-args.go index cddb3946a..fc34f10eb 100644 --- a/tools/cli/parse-args.go +++ b/tools/cli/parse-args.go @@ -13,6 +13,10 @@ func (self *Command) parse_args(ctx *Context, args []string) error { args_to_parse := make([]string, len(args)) copy(args_to_parse, args) ctx.SeenCommands = append(ctx.SeenCommands, self) + if self.IgnoreAllArgs { + self.Args = args + return nil + } var expecting_arg_for *Option options_allowed := true diff --git a/tools/cmd/completion/kitty.go b/tools/cmd/completion/kitty.go index a6642ad14..f9284a4e9 100644 --- a/tools/cmd/completion/kitty.go +++ b/tools/cmd/completion/kitty.go @@ -66,8 +66,8 @@ func complete_plus_open(completions *cli.Completions, word string, arg_num int) } func complete_themes(completions *cli.Completions, word string, arg_num int) { - kitty, err := utils.KittyExe() - if err == nil { + kitty := utils.KittyExe() + if kitty != "" { out, err := exec.Command(kitty, "+runpy", "from kittens.themes.collection import *; print_theme_names()").Output() if err == nil { mg := completions.AddMatchGroup("Themes") diff --git a/tools/cmd/icat/magick.go b/tools/cmd/icat/magick.go index 621109b91..d3390730f 100644 --- a/tools/cmd/icat/magick.go +++ b/tools/cmd/icat/magick.go @@ -14,7 +14,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "kitty/tools/tui/graphics" "kitty/tools/utils" @@ -24,12 +23,13 @@ import ( var _ = fmt.Print -var find_exe_lock sync.Once -var magick_exe string = "" - -func find_magick_exe() { - magick_exe = utils.Which("magick") -} +var MagickExe = (&utils.Once[string]{Run: func() string { + ans := utils.Which("magick") + if ans == "" { + ans = utils.Which("magick", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin") + } + return ans +}}).Get func run_magick(path string, cmd []string) ([]byte, error) { c := exec.Command(cmd[0], cmd[1:]...) @@ -154,10 +154,9 @@ func parse_identify_record(ans *IdentifyRecord, raw *IdentifyOutput) (err error) } func Identify(path string) (ans []IdentifyRecord, err error) { - find_exe_lock.Do(find_magick_exe) cmd := []string{"identify"} - if magick_exe != "" { - cmd = []string{magick_exe, cmd[0]} + if MagickExe() != "" { + cmd = []string{MagickExe(), cmd[0]} } q := `{"fmt":"%m","canvas":"%g","transparency":"%A","gap":"%T","index":"%p","size":"%wx%h",` + `"dpi":"%xx%y","dispose":"%D","orientation":"%[EXIF:Orientation]"},` @@ -227,10 +226,9 @@ func check_resize(frame *image_frame) error { } func Render(path string, ro *RenderOptions, frames []IdentifyRecord) (ans []*image_frame, err error) { - find_exe_lock.Do(find_magick_exe) cmd := []string{"convert"} - if magick_exe != "" { - cmd = []string{magick_exe, cmd[0]} + if MagickExe() != "" { + cmd = []string{MagickExe(), cmd[0]} } ans = make([]*image_frame, 0, len(frames)) defer func() { diff --git a/tools/cmd/main.go b/tools/cmd/main.go index 08c697b7d..23c12705c 100644 --- a/tools/cmd/main.go +++ b/tools/cmd/main.go @@ -3,12 +3,22 @@ package main import ( + "os" + "kitty/tools/cli" "kitty/tools/cmd/completion" + "kitty/tools/cmd/ssh" "kitty/tools/cmd/tool" ) func main() { + krm := os.Getenv("KITTY_KITTEN_RUN_MODULE") + os.Unsetenv("KITTY_KITTEN_RUN_MODULE") + switch krm { + case "ssh_askpass": + ssh.RunSSHAskpass() + return + } root := cli.NewRootCommand() root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)" root.Usage = "command [command options] [command args]" diff --git a/tools/cmd/pytest/main.go b/tools/cmd/pytest/main.go new file mode 100644 index 000000000..bb929c705 --- /dev/null +++ b/tools/cmd/pytest/main.go @@ -0,0 +1,22 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package pytest + +import ( + "fmt" + + "kitty/tools/cli" + "kitty/tools/cmd/ssh" + "kitty/tools/utils/shm" +) + +var _ = fmt.Print + +func EntryPoint(root *cli.Command) { + root = root.AddSubCommand(&cli.Command{ + Name: "__pytest__", + Hidden: true, + }) + shm.TestEntryPoint(root) + ssh.TestEntryPoint(root) +} diff --git a/tools/cmd/ssh/askpass.go b/tools/cmd/ssh/askpass.go new file mode 100644 index 000000000..edc7bbf31 --- /dev/null +++ b/tools/cmd/ssh/askpass.go @@ -0,0 +1,118 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + "kitty/tools/cli" + "kitty/tools/tty" + "kitty/tools/utils/shm" +) + +var _ = fmt.Print + +func fatal(err error) { + cli.ShowError(err) + os.Exit(1) +} + +func trigger_ask(name string) { + term, err := tty.OpenControllingTerm() + if err != nil { + fatal(err) + } + defer term.Close() + _, err = term.WriteString("\x1bP@kitty-ask|" + name + "\x1b\\") + if err != nil { + fatal(err) + } + +} + +func RunSSHAskpass() { + msg := os.Args[len(os.Args)-1] + prompt := os.Getenv("SSH_ASKPASS_PROMPT") + is_confirm := prompt == "confirm" + q_type := "get_line" + if is_confirm { + q_type = "confirm" + } + is_fingerprint_check := strings.Contains(msg, "(yes/no/[fingerprint])") + q := map[string]any{ + "message": msg, + "type": q_type, + "is_password": !is_fingerprint_check, + } + data, err := json.Marshal(q) + if err != nil { + fatal(err) + } + shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32)) + if err != nil { + fatal(fmt.Errorf("Failed to create SHM file with error: %w", err)) + } + defer shm.Close() + defer shm.Unlink() + + shm.Slice()[0] = 0 + binary.BigEndian.PutUint32(shm.Slice()[1:], uint32(len(data))) + copy(shm.Slice()[5:], data) + err = shm.Flush() + if err != nil { + fatal(fmt.Errorf("Failed to flush SHM file with error: %w", err)) + } + trigger_ask(shm.Name()) + buf := []byte{0} + for { + time.Sleep(50 * time.Millisecond) + _, err = shm.Seek(0, os.SEEK_SET) + if err != nil { + fatal(fmt.Errorf("Failed to seek into SHM file while waiting for response with error: %w", err)) + } + _, err = shm.Read(buf) + if err != nil { + fatal(fmt.Errorf("Failed to read from SHM file while waiting for response with error: %w", err)) + } + if buf[0] == 1 { + break + } + } + data, err = shm.ReadWithSize() + if err != nil { + fatal(fmt.Errorf("Failed to read response data from SHM file with error: %w", err)) + } + response := "" + if is_confirm { + var ok bool + err = json.Unmarshal(data, &ok) + if err != nil { + fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err)) + } + response = "no" + if ok { + response = "yes" + } + } else { + err = json.Unmarshal(data, &response) + if err != nil { + fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err)) + } + if is_fingerprint_check { + response = strings.ToLower(response) + if response == "y" { + response = "yes" + } else if response == "n" { + response = "no" + } + } + } + if response != "" { + fmt.Println(response) + } +} diff --git a/tools/cmd/ssh/config.go b/tools/cmd/ssh/config.go new file mode 100644 index 000000000..a863a9cee --- /dev/null +++ b/tools/cmd/ssh/config.go @@ -0,0 +1,409 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "archive/tar" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path" + "path/filepath" + "strings" + "time" + + "kitty/tools/config" + "kitty/tools/utils" + "kitty/tools/utils/paths" + "kitty/tools/utils/shlex" + + "github.com/bmatcuk/doublestar" + "golang.org/x/sys/unix" +) + +var _ = fmt.Print + +type EnvInstruction struct { + key, val string + delete_on_remote, copy_from_local, literal_quote bool +} + +func quote_for_sh(val string, literal_quote bool) string { + if literal_quote { + return utils.QuoteStringForSH(val) + } + // See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html + b := strings.Builder{} + b.Grow(len(val) + 16) + b.WriteRune('"') + runes := []rune(val) + for i, ch := range runes { + if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') { + // special chars are escaped + // $( is escaped to prevent execution + b.WriteRune('\\') + } + b.WriteRune(ch) + } + b.WriteRune('"') + return b.String() +} + +func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string { + var unset func() string + var export func(string) string + if for_python { + dumps := func(x ...any) string { + ans, _ := json.Marshal(x) + return utils.UnsafeBytesToString(ans) + } + export = func(val string) string { + if val == "" { + return fmt.Sprintf("export %s", dumps(self.key)) + } + return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote)) + } + unset = func() string { + return fmt.Sprintf("unset %s", dumps(self.key)) + } + } else { + kq := utils.QuoteStringForSH(self.key) + unset = func() string { + return fmt.Sprintf("unset %s", kq) + } + export = func(val string) string { + return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote)) + } + } + if self.delete_on_remote { + return unset() + } + if self.copy_from_local { + val, found := get_local_env(self.key) + if !found { + return "" + } + return export(val) + } + return export(self.val) +} + +func final_env_instructions(for_python bool, get_local_env func(string) (string, bool), env ...*EnvInstruction) string { + seen := make(map[string]int, len(env)) + ans := make([]string, 0, len(env)) + for _, ei := range env { + q := ei.Serialize(for_python, get_local_env) + if q != "" { + if pos, found := seen[ei.key]; found { + ans[pos] = q + } else { + seen[ei.key] = len(ans) + ans = append(ans, q) + } + } + } + return strings.Join(ans, "\n") +} + +type CopyInstruction struct { + local_path, arcname string + exclude_patterns []string +} + +func ParseEnvInstruction(spec string) (ans []*EnvInstruction, err error) { + const COPY_FROM_LOCAL string = "_kitty_copy_env_var_" + ei := &EnvInstruction{} + found := false + ei.key, ei.val, found = strings.Cut(spec, "=") + ei.key = strings.TrimSpace(ei.key) + if found { + ei.val = strings.TrimSpace(ei.val) + if ei.val == COPY_FROM_LOCAL { + ei.val = "" + ei.copy_from_local = true + } + } else { + ei.delete_on_remote = true + } + if ei.key == "" { + err = fmt.Errorf("The env directive must not be empty") + } + ans = []*EnvInstruction{ei} + return +} + +var paths_ctx *paths.Ctx + +func resolve_file_spec(spec string, is_glob bool) ([]string, error) { + if paths_ctx == nil { + paths_ctx = &paths.Ctx{} + } + ans := os.ExpandEnv(paths_ctx.ExpandHome(spec)) + if !filepath.IsAbs(ans) { + ans = paths_ctx.AbspathFromHome(ans) + } + if is_glob { + files, err := doublestar.Glob(ans) + if err != nil { + return nil, fmt.Errorf("%s is not a valid glob pattern with error: %w", spec, err) + } + if len(files) == 0 { + return nil, fmt.Errorf("%s matches no files", spec) + } + return files, nil + } + err := unix.Access(ans, unix.R_OK) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("%s does not exist", spec) + } + return nil, fmt.Errorf("Cannot read from: %s with error: %w", spec, err) + } + return []string{ans}, nil +} + +func get_arcname(loc, dest, home string) (arcname string) { + if dest != "" { + arcname = dest + } else { + arcname = filepath.Clean(loc) + if filepath.HasPrefix(arcname, home) { + ra, err := filepath.Rel(home, arcname) + if err == nil { + arcname = ra + } + } + } + prefix := "home/" + if strings.HasPrefix(arcname, "/") { + prefix = "root" + } + return prefix + arcname +} + +func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) { + args, err := shlex.Split("copy " + spec) + if err != nil { + return nil, err + } + opts, args, err := parse_copy_args(args) + if err != nil { + return nil, err + } + locations := make([]string, 0, len(args)) + for _, arg := range args { + locs, err := resolve_file_spec(arg, opts.Glob) + if err != nil { + return nil, err + } + locations = append(locations, locs...) + } + if len(locations) == 0 { + return nil, fmt.Errorf("No files to copy specified") + } + if len(locations) > 1 && opts.Dest != "" { + return nil, fmt.Errorf("Specifying a remote location with more than one file is not supported") + } + home := paths_ctx.HomePath() + ans = make([]*CopyInstruction, 0, len(locations)) + for _, loc := range locations { + ci := CopyInstruction{local_path: loc, exclude_patterns: opts.Exclude} + if opts.SymlinkStrategy != "preserve" { + ci.local_path, err = filepath.EvalSymlinks(loc) + if err != nil { + return nil, fmt.Errorf("Failed to resolve symlinks in %#v with error: %w", loc, err) + } + } + if opts.SymlinkStrategy == "resolve" { + ci.arcname = get_arcname(ci.local_path, opts.Dest, home) + } else { + ci.arcname = get_arcname(loc, opts.Dest, home) + } + ans = append(ans, &ci) + } + return +} + +type file_unique_id struct { + dev, inode uint64 +} + +func excluded(pattern, path string) bool { + if !strings.ContainsRune(pattern, '/') { + path = filepath.Base(path) + } + if matched, err := doublestar.PathMatch(pattern, path); matched && err == nil { + return true + } + return false +} + +func get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string, local_path, arcname string, exclude_patterns []string) error { + s, err := os.Lstat(local_path) + if err != nil { + return err + } + u, ok := s.Sys().(unix.Stat_t) + cb := func(h *tar.Header, data []byte, arcname string) error { + h.Name = arcname + if h.Typeflag == tar.TypeDir { + h.Name = strings.TrimRight(h.Name, "/") + "/" + } + h.Size = int64(len(data)) + h.Mode = int64(s.Mode().Perm()) + h.ModTime = s.ModTime() + h.Format = tar.FormatPAX + if ok { + h.AccessTime = time.Unix(0, u.Atim.Nano()) + h.ChangeTime = time.Unix(0, u.Ctim.Nano()) + } + return callback(h, data) + } + // we only copy regular files, directories and symlinks + switch s.Mode().Type() { + case fs.ModeSymlink: + target, err := os.Readlink(local_path) + if err != nil { + return err + } + err = cb(&tar.Header{ + Typeflag: tar.TypeSymlink, + Linkname: target, + }, nil, arcname) + if err != nil { + return err + } + case fs.ModeDir: + local_path = filepath.Clean(local_path) + type entry struct { + path, arcname string + } + stack := []entry{{local_path, arcname}} + for len(stack) > 0 { + x := stack[0] + stack = stack[1:] + entries, err := os.ReadDir(x.path) + if err != nil { + if x.path == local_path { + return err + } + continue + } + err = cb(&tar.Header{Typeflag: tar.TypeDir}, nil, x.arcname) + if err != nil { + return err + } + for _, e := range entries { + entry_path := filepath.Join(x.path, e.Name()) + aname := path.Join(x.arcname, e.Name()) + ok := true + for _, pat := range exclude_patterns { + if excluded(pat, entry_path) { + ok = false + break + } + } + if !ok { + continue + } + if e.IsDir() { + stack = append(stack, entry{entry_path, aname}) + } else { + err = get_file_data(callback, seen, entry_path, aname, exclude_patterns) + if err != nil { + return err + } + } + } + } + case 0: // Regular file + fid := file_unique_id{dev: u.Dev, inode: u.Ino} + if prev, ok := seen[fid]; ok { // Hard link + err = cb(&tar.Header{Typeflag: tar.TypeLink, Linkname: prev}, nil, arcname) + if err != nil { + return err + } + } + seen[fid] = arcname + data, err := os.ReadFile(local_path) + if err != nil { + return err + } + err = cb(&tar.Header{Typeflag: tar.TypeReg}, data, arcname) + if err != nil { + return err + } + } + return nil +} + +func (ci *CopyInstruction) get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string) (err error) { + ep := ci.exclude_patterns + for _, folder_name := range []string{"__pycache__", ".DS_Store"} { + ep = append(ep, "**/"+folder_name, "**/"+folder_name+"/**") + } + return get_file_data(callback, seen, ci.local_path, ci.arcname, ep) +} + +type ConfigSet struct { + all_configs []*Config +} + +func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config { + matcher := func(q *Config) bool { + for _, pat := range strings.Split(q.Hostname, " ") { + upat := "*" + if strings.Contains(pat, "@") { + upat, pat, _ = strings.Cut(pat, "@") + } + var host_matched, user_matched bool + if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil { + host_matched = true + } + if matched, err := filepath.Match(upat, username_to_match); matched && err == nil { + user_matched = true + } + if host_matched && user_matched { + return true + } + } + return false + } + for _, c := range utils.Reversed(cs.all_configs) { + if matcher(c) { + return c + } + } + return cs.all_configs[0] +} + +func (self *ConfigSet) line_handler(key, val string) error { + c := self.all_configs[len(self.all_configs)-1] + if key == "hostname" { + c = NewConfig() + self.all_configs = append(self.all_configs, c) + } + return c.Parse(key, val) +} + +func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, []config.ConfigLine, error) { + ans := &ConfigSet{all_configs: []*Config{NewConfig()}} + p := config.ConfigParser{LineHandler: ans.line_handler} + if len(paths) == 0 { + paths = []string{filepath.Join(utils.ConfigDir(), "ssh.conf")} + } + paths = utils.Filter(paths, func(x string) bool { return x != "" }) + err := p.ParseFiles(paths...) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, nil, err + } + if len(overrides) > 0 { + err = p.ParseOverrides(overrides...) + if err != nil { + return nil, nil, err + } + } + return config_for_hostname(hostname_to_match, username_to_match, ans), p.BadLines(), nil +} diff --git a/tools/cmd/ssh/config_test.go b/tools/cmd/ssh/config_test.go new file mode 100644 index 000000000..ef2dd16da --- /dev/null +++ b/tools/cmd/ssh/config_test.go @@ -0,0 +1,110 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "fmt" + "kitty/tools/utils" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestSSHConfigParsing(t *testing.T) { + tdir := t.TempDir() + hostname := "unmatched" + username := "" + conf := "" + for_python := false + cf := filepath.Join(tdir, "ssh.conf") + rt := func(expected_env ...string) { + os.WriteFile(cf, []byte(conf), 0o600) + c, bad_lines, err := load_config(hostname, username, nil, cf) + if err != nil { + t.Fatal(err) + } + if len(bad_lines) != 0 { + t.Fatalf("Bad config line: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err) + } + actual := final_env_instructions(for_python, func(key string) (string, bool) { + if key == "LOCAL_ENV" { + return "LOCAL_VAL", true + } + return "", false + }, c.Env...) + if expected_env == nil { + expected_env = []string{} + } + diff := cmp.Diff(expected_env, utils.Splitlines(actual)) + if diff != "" { + t.Fatalf("Unexpected env for\nhostname: %#v\nusername: %#v\nconf: %s\n%s", hostname, username, conf, diff) + } + } + rt() + conf = "env a=b" + rt(`export 'a'="b"`) + conf = "env a=b\nhostname 2\nenv a=c\nenv b=b" + rt(`export 'a'="b"`) + hostname = "2" + rt(`export 'a'="c"`, `export 'b'="b"`) + conf = "env a=" + rt(`export 'a'=""`) + conf = "env a" + rt(`unset 'a'`) + conf = "env a=b\nhostname test@2\nenv a=c\nenv b=b" + hostname = "unmatched" + rt(`export 'a'="b"`) + hostname = "2" + rt(`export 'a'="b"`) + username = "test" + rt(`export 'a'="c"`, `export 'b'="b"`) + conf = "env a=b\nhostname 1 2\nenv a=c\nenv b=b" + username = "" + hostname = "unmatched" + rt(`export 'a'="b"`) + hostname = "1" + rt(`export 'a'="c"`, `export 'b'="b"`) + hostname = "2" + rt(`export 'a'="c"`, `export 'b'="b"`) + for_python = true + rt(`export ["a","c",false]`, `export ["b","b",false]`) + conf = "env a=" + rt(`export ["a"]`) + conf = "env a" + rt(`unset ["a"]`) + conf = "env LOCAL_ENV=_kitty_copy_env_var_" + rt(`export ["LOCAL_ENV","LOCAL_VAL",false]`) + + ci, err := ParseCopyInstruction("--exclude moose --dest=target " + cf) + if err != nil { + t.Fatal(err) + } + diff := cmp.Diff("home/target", ci[0].arcname) + if diff != "" { + t.Fatalf("Incorrect arcname:\n%s", diff) + } + diff = cmp.Diff(cf, ci[0].local_path) + if diff != "" { + t.Fatalf("Incorrect local_path:\n%s", diff) + } + diff = cmp.Diff([]string{"moose"}, ci[0].exclude_patterns) + if diff != "" { + t.Fatalf("Incorrect excludes:\n%s", diff) + } + ci, err = ParseCopyInstruction("--glob " + filepath.Join(filepath.Dir(cf), "*.conf")) + if err != nil { + t.Fatal(err) + } + diff = cmp.Diff(cf, ci[0].local_path) + if diff != "" { + t.Fatalf("Incorrect local_path:\n%s", diff) + } + if len(ci) != 1 { + t.Fatal(ci) + } + +} diff --git a/tools/cmd/ssh/data.go b/tools/cmd/ssh/data.go new file mode 100644 index 000000000..4b224f9b1 --- /dev/null +++ b/tools/cmd/ssh/data.go @@ -0,0 +1,69 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "archive/tar" + _ "embed" + "errors" + "fmt" + "io" + "kitty/tools/utils" + "regexp" + "strings" +) + +var _ = fmt.Print + +//go:embed data_generated.bin +var embedded_data string + +type Entry struct { + metadata *tar.Header + data []byte +} + +type Container map[string]Entry + +var Data = (&utils.Once[Container]{Run: func() Container { + tr := tar.NewReader(utils.ReaderForCompressedEmbeddedData(embedded_data)) + ans := make(Container, 64) + for { + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + panic(err) + } + data, err := utils.ReadAll(tr, int(hdr.Size)) + if err != nil { + panic(err) + } + ans[hdr.Name] = Entry{hdr, data} + } + return ans +}}).Get + +func (self Container) files_matching(prefix string, exclude_patterns ...string) []string { + ans := make([]string, 0, len(self)) + patterns := make([]*regexp.Regexp, len(exclude_patterns)) + for i, exp := range exclude_patterns { + patterns[i] = regexp.MustCompile(exp) + } + for name := range self { + if strings.HasPrefix(name, prefix) { + excluded := false + for _, pat := range patterns { + if matched := pat.FindString(name); matched != "" { + excluded = true + break + } + } + if !excluded { + ans = append(ans, name) + } + } + } + return ans +} diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go new file mode 100644 index 000000000..da4ffc3d4 --- /dev/null +++ b/tools/cmd/ssh/main.go @@ -0,0 +1,801 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "kitty" + "net/url" + "os" + "os/exec" + "os/user" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "kitty/tools/cli" + "kitty/tools/themes" + "kitty/tools/tty" + "kitty/tools/tui" + "kitty/tools/tui/loop" + "kitty/tools/utils" + "kitty/tools/utils/secrets" + "kitty/tools/utils/shm" + + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + "golang.org/x/sys/unix" +) + +var _ = fmt.Print + +func get_destination(hostname string) (username, hostname_for_match string) { + u, err := user.Current() + if err == nil { + username = u.Username + } + hostname_for_match = hostname + if strings.HasPrefix(hostname, "ssh://") { + p, err := url.Parse(hostname) + if err == nil { + hostname_for_match = p.Hostname() + if p.User.Username() != "" { + username = p.User.Username() + } + } + } else if strings.Contains(hostname, "@") && hostname[0] != '@' { + username, hostname_for_match, _ = strings.Cut(hostname, "@") + } + if strings.Contains(hostname, "@") && hostname[0] != '@' { + _, hostname_for_match, _ = strings.Cut(hostname_for_match, "@") + } + hostname_for_match, _, _ = strings.Cut(hostname_for_match, ":") + return +} + +func read_data_from_shared_memory(shm_name string) ([]byte, error) { + data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error { + s, err := f.Stat() + if err != nil { + return fmt.Errorf("Failed to stat SHM file with error: %w", err) + } + if stat, ok := s.Sys().(unix.Stat_t); ok { + if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) { + return fmt.Errorf("Incorrect owner on SHM file") + } + } + if s.Mode().Perm() != 0o600 { + return fmt.Errorf("Incorrect permissions on SHM file") + } + return nil + }) + return data, err +} + +func add_cloned_env(val string) (ans map[string]string, err error) { + data, err := read_data_from_shared_memory(val) + if err != nil { + return nil, err + } + err = json.Unmarshal(data, &ans) + return ans, err +} + +func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string, ferr error) { + literal_env = make(map[string]string) + overrides = make([]string, 0, 4) + for i, a := range found_extra_args { + if i%2 == 0 { + continue + } + if key, val, found := strings.Cut(a, "="); found { + if key == "clone_env" { + le, err := add_cloned_env(val) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return nil, nil, ferr + } + } else if le != nil { + literal_env = le + } + } else if key != "hostname" { + overrides = append(overrides, key+" "+val) + } + } + } + if len(overrides) > 0 { + overrides = append([]string{"hostname " + username + "@" + hostname_for_match}, overrides...) + } + return +} + +func connection_sharing_args(kitty_pid int) ([]string, error) { + rd := utils.RuntimeDir() + // Bloody OpenSSH generates a 40 char hash and in creating the socket + // appends a 27 char temp suffix to it. Socket max path length is approx + // ~104 chars. And on idiotic Apple the path length to the runtime dir + // (technically the cache dir since Apple has no runtime dir and thinks it's + // a great idea to delete files in /tmp) is ~48 chars. + if len(rd) > 35 { + idiotic_design := fmt.Sprintf("/tmp/kssh-rdir-%d", os.Geteuid()) + if err := utils.AtomicCreateSymlink(rd, idiotic_design); err != nil { + return nil, err + } + rd = idiotic_design + } + cp := strings.Replace(kitty.SSHControlMasterTemplate, "{kitty_pid}", strconv.Itoa(kitty_pid), 1) + cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1) + return []string{ + "-o", "ControlMaster=auto", + "-o", "ControlPath=" + filepath.Join(rd, cp), + "-o", "ControlPersist=yes", + "-o", "ServerAliveInterval=60", + "-o", "ServerAliveCountMax=5", + "-o", "TCPKeepAlive=no", + }, nil +} + +func set_askpass() (need_to_request_data bool) { + need_to_request_data = true + sentinel := filepath.Join(utils.CacheDir(), "openssh-is-new-enough-for-askpass") + _, err := os.Stat(sentinel) + sentinel_exists := err == nil + if sentinel_exists || GetSSHVersion().SupportsAskpassRequire() { + if !sentinel_exists { + os.WriteFile(sentinel, []byte{0}, 0o644) + } + need_to_request_data = false + } + exe, err := os.Executable() + if err == nil { + os.Setenv("SSH_ASKPASS", exe) + os.Setenv("KITTY_KITTEN_RUN_MODULE", "ssh_askpass") + if !need_to_request_data { + os.Setenv("SSH_ASKPASS_REQUIRE", "force") + } + } else { + need_to_request_data = true + } + return +} + +type connection_data struct { + remote_args []string + host_opts *Config + hostname_for_match string + username string + echo_on bool + request_data bool + literal_env map[string]string + test_script string + + shm_name string + script_type string + rcmd []string + replacements map[string]string + request_id string + bootstrap_script string +} + +func get_effective_ksi_env_var(x string) string { + parts := strings.Split(strings.TrimSpace(strings.ToLower(x)), " ") + current := utils.NewSetWithItems(parts...) + if current.Has("disabled") { + return "" + } + allowed := utils.NewSetWithItems(kitty.AllowedShellIntegrationValues...) + if !current.IsSubsetOf(allowed) { + return RelevantKittyOpts().Shell_integration + } + return x +} + +func serialize_env(cd *connection_data, get_local_env func(string) (string, bool)) (string, string) { + ksi := "" + if cd.host_opts.Shell_integration == "inherited" { + ksi = get_effective_ksi_env_var(RelevantKittyOpts().Shell_integration) + } else { + ksi = get_effective_ksi_env_var(cd.host_opts.Shell_integration) + } + env := make([]*EnvInstruction, 0, 8) + add_env := func(key, val string, fallback ...string) *EnvInstruction { + if val == "" && len(fallback) > 0 { + val = fallback[0] + } + if val != "" { + env = append(env, &EnvInstruction{key: key, val: val, literal_quote: true}) + return env[len(env)-1] + } + return nil + } + add_non_literal_env := func(key, val string, fallback ...string) *EnvInstruction { + ans := add_env(key, val, fallback...) + if ans != nil { + ans.literal_quote = false + } + return ans + } + for k, v := range cd.literal_env { + add_env(k, v) + } + add_env("TERM", os.Getenv("TERM"), RelevantKittyOpts().Term) + add_env("COLORTERM", "truecolor") + env = append(env, cd.host_opts.Env...) + add_env("KITTY_WINDOW_ID", os.Getenv("KITTY_WINDOW_ID")) + add_env("WINDOWID", os.Getenv("WINDOWID")) + if ksi != "" { + add_env("KITTY_SHELL_INTEGRATION", ksi) + } else { + env = append(env, &EnvInstruction{key: "KITTY_SHELL_INTEGRATION", delete_on_remote: true}) + } + add_non_literal_env("KITTY_SSH_KITTEN_DATA_DIR", cd.host_opts.Remote_dir) + add_non_literal_env("KITTY_LOGIN_SHELL", cd.host_opts.Login_shell) + add_non_literal_env("KITTY_LOGIN_CWD", cd.host_opts.Cwd) + if cd.host_opts.Remote_kitty != Remote_kitty_no { + add_env("KITTY_REMOTE", cd.host_opts.Remote_kitty.String()) + } + add_env("KITTY_PUBLIC_KEY", os.Getenv("KITTY_PUBLIC_KEY")) + return final_env_instructions(cd.script_type == "py", get_local_env, env...), ksi +} + +func make_tarfile(cd *connection_data, get_local_env func(string) (string, bool)) ([]byte, error) { + env_script, ksi := serialize_env(cd, get_local_env) + w := bytes.Buffer{} + w.Grow(64 * 1024) + gw, err := gzip.NewWriterLevel(&w, gzip.BestCompression) + if err != nil { + return nil, err + } + tw := tar.NewWriter(gw) + rd := strings.TrimRight(cd.host_opts.Remote_dir, "/") + seen := make(map[file_unique_id]string, 32) + add := func(h *tar.Header, data []byte) (err error) { + // some distro's like nix mess with installed file permissions so ensure + // files are at least readable and writable by owning user + h.Mode |= 0o600 + err = tw.WriteHeader(h) + if err != nil { + return + } + if data != nil { + _, err := tw.Write(data) + if err != nil { + return err + } + } + return + } + for _, ci := range cd.host_opts.Copy { + err = ci.get_file_data(add, seen) + if err != nil { + return nil, err + } + } + type fe struct { + arcname string + data []byte + } + now := time.Now() + add_data := func(items ...fe) error { + for _, item := range items { + err := add( + &tar.Header{ + Typeflag: tar.TypeReg, Name: item.arcname, Format: tar.FormatPAX, Size: int64(len(item.data)), + Mode: 0o644, ModTime: now, ChangeTime: now, AccessTime: now, + }, item.data) + if err != nil { + return err + } + } + return nil + } + add_entries := func(prefix string, items ...Entry) error { + for _, item := range items { + err := add( + &tar.Header{ + Typeflag: item.metadata.Typeflag, Name: path.Join(prefix, path.Base(item.metadata.Name)), Format: tar.FormatPAX, + Size: int64(len(item.data)), Mode: item.metadata.Mode, ModTime: item.metadata.ModTime, + AccessTime: item.metadata.AccessTime, ChangeTime: item.metadata.ChangeTime, + }, item.data) + if err != nil { + return err + } + } + return nil + + } + add_data(fe{"data.sh", utils.UnsafeStringToBytes(env_script)}) + if cd.script_type == "sh" { + add_data(fe{"bootstrap-utils.sh", Data()[path.Join("shell-integration/ssh/bootstrap-utils.sh")].data}) + } + if ksi != "" { + for _, fname := range Data().files_matching( + "shell-integration/", + "shell-integration/ssh/.+", // bootstrap files are sent as command line args + "shell-integration/zsh/kitty.zsh", // backward compat file not needed by ssh kitten + ) { + arcname := path.Join("home/", rd, "/", path.Dir(fname)) + err = add_entries(arcname, Data()[fname]) + if err != nil { + return nil, err + } + } + } + if cd.host_opts.Remote_kitty != Remote_kitty_no { + arcname := path.Join("home/", rd, "/kitty") + err = add_data(fe{arcname + "/version", utils.UnsafeStringToBytes(kitty.VersionString)}) + if err != nil { + return nil, err + } + for _, x := range []string{"kitty", "kitten"} { + err = add_entries(path.Join(arcname, "bin"), Data()[path.Join("shell-integration", "ssh", x)]) + if err != nil { + return nil, err + } + } + } + err = add_entries(path.Join("home", ".terminfo"), Data()["terminfo/kitty.terminfo"]) + if err == nil { + err = add_entries(path.Join("home", ".terminfo", "x"), Data()["terminfo/x/xterm-kitty"]) + } + if err == nil { + err = tw.Close() + if err == nil { + err = gw.Close() + } + } + return w.Bytes(), err +} + +func prepare_home_command(cd *connection_data) string { + is_python := cd.script_type == "py" + homevar := "" + for _, ei := range cd.host_opts.Env { + if ei.key == "HOME" && !ei.delete_on_remote { + if ei.copy_from_local { + homevar = os.Getenv("HOME") + } else { + homevar = ei.val + } + } + } + export_home_cmd := "" + if homevar != "" { + if is_python { + export_home_cmd = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(homevar)) + } else { + export_home_cmd = fmt.Sprintf("export HOME=%s; cd \"$HOME\"", utils.QuoteStringForSH(homevar)) + } + } + return export_home_cmd +} + +func prepare_exec_cmd(cd *connection_data) string { + // ssh simply concatenates multiple commands using a space see + // line 1129 of ssh.c and on the remote side sshd.c runs the + // concatenated command as shell -c cmd + if cd.script_type == "py" { + return base64.RawStdEncoding.EncodeToString(utils.UnsafeStringToBytes(strings.Join(cd.remote_args, " "))) + } + args := make([]string, len(cd.remote_args)) + for i, arg := range cd.remote_args { + args[i] = strings.ReplaceAll(arg, "'", "'\"'\"'") + } + return "unset KITTY_SHELL_INTEGRATION; exec \"$login_shell\" -c '" + strings.Join(args, " ") + "'" +} + +var data_shm shm.MMap + +func prepare_script(script string, replacements map[string]string) string { + if _, found := replacements["EXEC_CMD"]; !found { + replacements["EXEC_CMD"] = "" + } + if _, found := replacements["EXPORT_HOME_CMD"]; !found { + replacements["EXPORT_HOME_CMD"] = "" + } + keys := maps.Keys(replacements) + for i, key := range keys { + keys[i] = "\\b" + key + "\\b" + } + pat := regexp.MustCompile(strings.Join(keys, "|")) + return pat.ReplaceAllStringFunc(script, func(key string) string { return replacements[key] }) +} + +func bootstrap_script(cd *connection_data) (err error) { + if cd.request_id == "" { + cd.request_id = os.Getenv("KITTY_PID") + "-" + os.Getenv("KITTY_WINDOW_ID") + } + export_home_cmd := prepare_home_command(cd) + exec_cmd := "" + if len(cd.remote_args) > 0 { + exec_cmd = prepare_exec_cmd(cd) + } + pw, err := secrets.TokenHex() + if err != nil { + return err + } + tfd, err := make_tarfile(cd, os.LookupEnv) + if err != nil { + return err + } + data := map[string]string{ + "tarfile": base64.StdEncoding.EncodeToString(tfd), + "pw": pw, + "hostname": cd.hostname_for_match, "username": cd.username, + } + encoded_data, err := json.Marshal(data) + if err == nil { + data_shm, err = shm.CreateTemp(fmt.Sprintf("kssh-%d-", os.Getpid()), uint64(len(encoded_data)+8)) + if err == nil { + err = data_shm.WriteWithSize(encoded_data) + if err == nil { + err = data_shm.Flush() + } + } + } + if err != nil { + return err + } + cd.shm_name = data_shm.Name() + sensitive_data := map[string]string{"REQUEST_ID": cd.request_id, "DATA_PASSWORD": pw, "PASSWORD_FILENAME": cd.shm_name} + replacements := map[string]string{ + "EXPORT_HOME_CMD": export_home_cmd, + "EXEC_CMD": exec_cmd, + "TEST_SCRIPT": cd.test_script, + } + add_bool := func(ok bool, key string) { + if ok { + replacements[key] = "1" + } else { + replacements[key] = "0" + } + } + add_bool(cd.request_data, "REQUEST_DATA") + add_bool(cd.echo_on, "ECHO_ON") + sd := maps.Clone(replacements) + if cd.request_data { + maps.Copy(sd, sensitive_data) + } + maps.Copy(replacements, sensitive_data) + cd.replacements = replacements + cd.bootstrap_script = utils.UnsafeBytesToString(Data()["shell-integration/ssh/bootstrap."+cd.script_type].data) + cd.bootstrap_script = prepare_script(cd.bootstrap_script, sd) + return err +} + +func wrap_bootstrap_script(cd *connection_data) { + // sshd will execute the command we pass it by join all command line + // arguments with a space and passing it as a single argument to the users + // login shell with -c. If the user has a non POSIX login shell it might + // have different escaping semantics and syntax, so the command it should + // execute has to be as simple as possible, basically of the form + // interpreter -c unwrap_script escaped_bootstrap_script + // The unwrap_script is responsible for unescaping the bootstrap script and + // executing it. + encoded_script := "" + unwrap_script := "" + if cd.script_type == "py" { + encoded_script = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(cd.bootstrap_script)) + unwrap_script = `"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"` + } else { + // We cant rely on base64 being available on the remote system, so instead + // we quote the bootstrap script by replacing ' and \ with \v and \f + // also replacing \n and ! with \r and \b for tcsh + // finally surrounding with ' + encoded_script = "'" + strings.NewReplacer("'", "\v", "\\", "\f", "\n", "\r", "!", "\b").Replace(cd.bootstrap_script) + "'" + unwrap_script = `'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' ` + } + cd.rcmd = []string{"exec", cd.host_opts.Interpreter, "-c", unwrap_script, encoded_script} +} + +func get_remote_command(cd *connection_data) error { + interpreter := cd.host_opts.Interpreter + q := strings.ToLower(path.Base(interpreter)) + is_python := strings.Contains(q, "python") + cd.script_type = "sh" + if is_python { + cd.script_type = "py" + } + err := bootstrap_script(cd) + if err != nil { + return err + } + wrap_bootstrap_script(cd) + return nil +} + +func drain_potential_tty_garbage(term *tty.Term) { + err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho) + if err != nil { + return + } + canary, err := secrets.TokenBase64() + if err != nil { + return + } + dcs, err := tui.DCSToKitty("echo", canary+"\n\r") + if err != nil { + return + } + err = term.WriteAllString(dcs) + if err != nil { + return + } + q := utils.UnsafeStringToBytes(canary) + data := make([]byte, 0) + give_up_at := time.Now().Add(2 * time.Second) + buf := make([]byte, 0, 8192) + for !bytes.Contains(data, q) { + buf = buf[:cap(buf)] + timeout := give_up_at.Sub(time.Now()) + if timeout < 0 { + break + } + n, err := term.ReadWithTimeout(buf, timeout) + if err != nil { + return + } + data = append(data, buf[:n]...) + } +} + +func change_colors(color_scheme string) (ans string, err error) { + if color_scheme == "" { + return + } + var theme *themes.Theme + if !strings.HasSuffix(color_scheme, ".conf") { + cs := os.ExpandEnv(color_scheme) + tc, closer, err := themes.LoadThemes(-1) + if err != nil && errors.Is(err, themes.ErrNoCacheFound) { + tc, closer, err = themes.LoadThemes(time.Hour * 24) + } + if err != nil { + return "", err + } + defer closer.Close() + theme = tc.ThemeByName(cs) + if theme == nil { + return "", fmt.Errorf("No theme named %#v found", cs) + } + } else { + theme, err = themes.ThemeFromFile(utils.ResolveConfPath(color_scheme)) + if err != nil { + return "", err + } + } + ans, err = theme.AsEscapeCodes() + if err == nil { + ans = "\033[#P" + ans + } + return +} + +func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { + go Data() + go RelevantKittyOpts() + defer func() { + if data_shm != nil { + data_shm.Close() + data_shm.Unlink() + } + }() + cmd := append([]string{SSHExe()}, ssh_args...) + cd := connection_data{remote_args: server_args[1:]} + hostname := server_args[0] + if len(cd.remote_args) == 0 { + cmd = append(cmd, "-t") + } + insertion_point := len(cmd) + cmd = append(cmd, "--", hostname) + uname, hostname_for_match := get_destination(hostname) + overrides, literal_env, err := parse_kitten_args(found_extra_args, uname, hostname_for_match) + if err != nil { + return 1, err + } + host_opts, bad_lines, err := load_config(hostname_for_match, uname, overrides) + if err != nil { + return 1, err + } + if len(bad_lines) > 0 { + for _, x := range bad_lines { + fmt.Fprintf(os.Stderr, "Ignoring bad config line: %s:%d with error: %s", filepath.Base(x.Src_file), x.Line_number, x.Err) + } + } + if host_opts.Share_connections { + kpid, err := strconv.Atoi(os.Getenv("KITTY_PID")) + if err != nil { + return 1, fmt.Errorf("Invalid KITTY_PID env var not an integer: %#v", os.Getenv("KITTY_PID")) + } + cpargs, err := connection_sharing_args(kpid) + if err != nil { + return 1, err + } + cmd = slices.Insert(cmd, insertion_point, cpargs...) + } + use_kitty_askpass := host_opts.Askpass == Askpass_native || (host_opts.Askpass == Askpass_unless_set && os.Getenv("SSH_ASKPASS") == "") + need_to_request_data := true + if use_kitty_askpass { + need_to_request_data = set_askpass() + } + if need_to_request_data && host_opts.Share_connections { + check_cmd := slices.Insert(cmd, 1, "-O", "check") + err = exec.Command(check_cmd[0], check_cmd[1:]...).Run() + if err == nil { + need_to_request_data = false + } + } + term, err := tty.OpenControllingTerm(tty.SetNoEcho) + if err != nil { + return 1, fmt.Errorf("Failed to open controlling terminal with error: %w", err) + } + cd.echo_on = term.WasEchoOnOriginally() + cd.host_opts, cd.literal_env = host_opts, literal_env + cd.request_data = need_to_request_data + cd.hostname_for_match, cd.username = hostname_for_match, uname + escape_codes_to_set_colors, err := change_colors(cd.host_opts.Color_scheme) + if err == nil { + err = term.WriteAllString(escape_codes_to_set_colors + loop.SAVE_PRIVATE_MODE_VALUES + loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet()) + } + if err != nil { + return 1, err + } + restore_escape_codes := loop.RESTORE_PRIVATE_MODE_VALUES + if escape_codes_to_set_colors != "" { + restore_escape_codes += "\x1b[#Q" + } + defer func() { + term.WriteAllString(restore_escape_codes) + term.RestoreAndClose() + }() + err = get_remote_command(&cd) + if err != nil { + return 1, err + } + cmd = append(cmd, cd.rcmd...) + c := exec.Command(cmd[0], cmd[1:]...) + c.Stdin, c.Stdout, c.Stderr = os.Stdin, os.Stdout, os.Stderr + err = c.Start() + if err != nil { + return 1, err + } + if !cd.request_data { + rq := fmt.Sprintf("id=%s:pwfile=%s:pw=%s", cd.replacements["REQUEST_ID"], cd.replacements["PASSWORD_FILENAME"], cd.replacements["DATA_PASSWORD"]) + err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho) + if err == nil { + var dcs string + dcs, err = tui.DCSToKitty("ssh", rq) + if err == nil { + err = term.WriteAllString(dcs) + } + } + if err != nil { + c.Process.Kill() + c.Wait() + return 1, err + } + } + err = c.Wait() + drain_potential_tty_garbage(term) + if err != nil { + var exit_err *exec.ExitError + if errors.As(err, &exit_err) { + return exit_err.ExitCode(), nil + } + return 1, err + } + return 0, nil +} + +func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) { + if len(args) > 0 { + switch args[0] { + case "use-python": + args = args[1:] // backwards compat from when we had a python implementation + case "-h", "--help": + cmd.ShowHelp() + return + } + } + ssh_args, server_args, passthrough, found_extra_args, err := ParseSSHArgs(args, "--kitten") + if err != nil { + var invargs *ErrInvalidSSHArgs + switch { + case errors.As(err, &invargs): + if invargs.Msg != "" { + fmt.Fprintln(os.Stderr, invargs.Msg) + } + return 1, unix.Exec(SSHExe(), []string{"ssh"}, os.Environ()) + } + return 1, err + } + if passthrough { + if len(found_extra_args) > 0 { + return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " ")) + } + return 1, unix.Exec(SSHExe(), append([]string{"ssh"}, args...), os.Environ()) + } + if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" { + return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window") + } + if !tty.IsTerminal(os.Stdin.Fd()) { + return 1, fmt.Errorf("The SSH kitten is meant for interactive use only, STDIN must be a terminal") + } + return run_ssh(ssh_args, server_args, found_extra_args) +} + +func EntryPoint(parent *cli.Command) { + create_cmd(parent, main) +} + +func specialize_command(ssh *cli.Command) { + ssh.Usage = "arguments for the ssh command" + ssh.ShortDescription = "Truly convenient SSH" + ssh.HelpText = "The ssh kitten is a thin wrapper around the ssh command. It automatically enables shell integration on the remote host, re-uses existing connections to reduce latency, makes the kitty terminfo database available, etc. It's invocation is identical to the ssh command. For details on its usage, see :doc:`/kittens/ssh`." + ssh.IgnoreAllArgs = true + ssh.OnlyArgsAllowed = true + ssh.ArgCompleter = cli.CompletionForWrapper("ssh") +} + +func test_integration_with_python(args []string) (rc int, err error) { + f, err := os.CreateTemp("", "*.conf") + if err != nil { + return 1, err + } + defer func() { + f.Close() + os.Remove(f.Name()) + }() + _, err = io.Copy(f, os.Stdin) + if err != nil { + return 1, err + } + cd := &connection_data{ + request_id: "testing", remote_args: []string{}, + username: "testuser", hostname_for_match: "host.test", request_data: true, + test_script: args[0], echo_on: true, + } + opts, bad_lines, err := load_config(cd.hostname_for_match, cd.username, nil, f.Name()) + if err == nil { + if len(bad_lines) > 0 { + return 1, fmt.Errorf("Bad config lines: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err) + } + cd.host_opts = opts + err = get_remote_command(cd) + } + if err != nil { + return 1, err + } + data, err := json.Marshal(map[string]any{"cmd": cd.rcmd, "shm_name": cd.shm_name}) + if err == nil { + _, err = os.Stdout.Write(data) + os.Stdout.Close() + } + if err != nil { + return 1, err + } + + return +} + +func TestEntryPoint(root *cli.Command) { + root.AddSubCommand(&cli.Command{ + Name: "ssh", + OnlyArgsAllowed: true, + Run: func(cmd *cli.Command, args []string) (rc int, err error) { + return test_integration_with_python(args) + }, + }) + +} diff --git a/tools/cmd/ssh/main_test.go b/tools/cmd/ssh/main_test.go new file mode 100644 index 000000000..e3fb09054 --- /dev/null +++ b/tools/cmd/ssh/main_test.go @@ -0,0 +1,154 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io/fs" + "kitty/tools/utils/shm" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" +) + +var _ = fmt.Print + +func TestCloneEnv(t *testing.T) { + env := map[string]string{"a": "1", "b": "2"} + data, err := json.Marshal(env) + if err != nil { + t.Fatal(err) + } + mmap, err := shm.CreateTemp("", 128) + if err != nil { + t.Fatal(err) + } + defer mmap.Unlink() + copy(mmap.Slice()[4:], data) + binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data))) + mmap.Close() + x, err := add_cloned_env(mmap.Name()) + if err != nil { + t.Fatal(err) + } + diff := cmp.Diff(env, x) + if diff != "" { + t.Fatalf("Failed to deserialize env\n%s", diff) + } +} + +func basic_connection_data(overrides ...string) *connection_data { + ans := &connection_data{ + script_type: "sh", request_id: "123-123", remote_args: []string{}, + username: "testuser", hostname_for_match: "host.test", + } + opts, bad_lines, err := load_config(ans.hostname_for_match, ans.username, overrides, "") + if err != nil { + panic(err) + } + if len(bad_lines) != 0 { + panic(fmt.Sprintf("Bad config lines: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err)) + } + ans.host_opts = opts + return ans +} + +func TestSSHBootstrapScriptLimit(t *testing.T) { + cd := basic_connection_data() + err := get_remote_command(cd) + if err != nil { + t.Fatal(err) + } + total := 0 + for _, x := range cd.rcmd { + total += len(x) + } + if total > 9000 { + t.Fatalf("Bootstrap script too large: %d bytes", total) + } +} + +func TestSSHTarfile(t *testing.T) { + tdir := t.TempDir() + cd := basic_connection_data() + data, err := make_tarfile(cd, func(key string) (val string, found bool) { return }) + if err != nil { + t.Fatal(err) + } + cmd := exec.Command("tar", "xpzf", "-", "-C", tdir) + cmd.Stderr = os.Stderr + inp, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + err = cmd.Start() + if err != nil { + t.Fatal(err) + } + _, err = inp.Write(data) + if err != nil { + t.Fatal(err) + } + inp.Close() + err = cmd.Wait() + if err != nil { + t.Fatal(err) + } + + seen := map[string]bool{} + err = filepath.WalkDir(tdir, func(name string, d fs.DirEntry, werr error) error { + if werr != nil { + return werr + } + rname, werr := filepath.Rel(tdir, name) + if werr != nil { + return werr + } + rname = strings.ReplaceAll(rname, "\\", "/") + if rname == "." { + return nil + } + fi, werr := d.Info() + if werr != nil { + return werr + } + if fi.Mode().Perm()&0o600 == 0 { + return fmt.Errorf("%s is not rw for its owner. Actual permissions: %s", rname, fi.Mode().String()) + } + seen[rname] = true + return nil + }) + if err != nil { + t.Fatal(err) + } + if !seen["data.sh"] { + t.Fatalf("data.sh missing") + } + for _, x := range []string{".terminfo/kitty.terminfo", ".terminfo/x/xterm-kitty"} { + if !seen["home/"+x] { + t.Fatalf("%s missing", x) + } + } + for _, x := range []string{"shell-integration/bash/kitty.bash", "shell-integration/fish/vendor_completions.d/kitty.fish"} { + if !seen[path.Join("home", cd.host_opts.Remote_dir, x)] { + t.Fatalf("%s missing", x) + } + } + for _, x := range []string{"kitty", "kitten"} { + p := filepath.Join(tdir, "home", cd.host_opts.Remote_dir, "kitty", "bin", x) + if err = unix.Access(p, unix.X_OK); err != nil { + t.Fatalf("Cannot execute %s with error: %s", x, err) + } + } + if seen[path.Join("home", cd.host_opts.Remote_dir, "shell-integration", "ssh", "kitten")] { + t.Fatalf("Contents of shell-integration/ssh not excluded") + } +} diff --git a/tools/cmd/ssh/utils.go b/tools/cmd/ssh/utils.go new file mode 100644 index 000000000..ef2514260 --- /dev/null +++ b/tools/cmd/ssh/utils.go @@ -0,0 +1,248 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "fmt" + "io" + "kitty" + "kitty/tools/config" + "kitty/tools/utils" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" +) + +var _ = fmt.Print + +var SSHExe = (&utils.Once[string]{Run: func() string { + ans := utils.Which("ssh") + if ans != "" { + return ans + } + ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin") + if ans == "" { + ans = "ssh" + } + return ans +}}).Get + +var SSHOptions = (&utils.Once[map[string]string]{Run: func() (ssh_options map[string]string) { + defer func() { + if ssh_options == nil { + ssh_options = map[string]string{ + "4": "", "6": "", "A": "", "a": "", "C": "", "f": "", "G": "", "g": "", "K": "", "k": "", + "M": "", "N": "", "n": "", "q": "", "s": "", "T": "", "t": "", "V": "", "v": "", "X": "", + "x": "", "Y": "", "y": "", "B": "bind_interface", "b": "bind_address", "c": "cipher_spec", + "D": "[bind_address:]port", "E": "log_file", "e": "escape_char", "F": "configfile", "I": "pkcs11", + "i": "identity_file", "J": "[user@]host[:port]", "L": "address", "l": "login_name", "m": "mac_spec", + "O": "ctl_cmd", "o": "option", "p": "port", "Q": "query_option", "R": "address", + "S": "ctl_path", "W": "host:port", "w": "local_tun[:remote_tun]", + } + } + }() + cmd := exec.Command(SSHExe()) + stderr, err := cmd.StderrPipe() + if err != nil { + return + } + if err := cmd.Start(); err != nil { + return + } + raw, err := io.ReadAll(stderr) + if err != nil { + return + } + text := utils.UnsafeBytesToString(raw) + ssh_options = make(map[string]string, 32) + for { + pos := strings.IndexByte(text, '[') + if pos < 0 { + break + } + num := 1 + epos := pos + for num > 0 { + epos++ + switch text[epos] { + case '[': + num += 1 + case ']': + num -= 1 + } + } + q := text[pos+1 : epos] + text = text[epos:] + if len(q) < 2 || !strings.HasPrefix(q, "-") { + continue + } + opt, desc, found := strings.Cut(q, " ") + if found { + ssh_options[opt[1:]] = desc + } else { + for _, ch := range opt[1:] { + ssh_options[string(ch)] = "" + } + } + } + return +}}).Get + +func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) { + other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32) + for k, v := range SSHOptions() { + k = "-" + k + if v == "" { + boolean_ssh_args.Add(k) + } else { + other_ssh_args.Add(k) + } + } + return +} + +func is_extra_arg(arg string, extra_args []string) string { + for _, x := range extra_args { + if arg == x || strings.HasPrefix(arg, x+"=") { + return x + } + } + return "" +} + +type ErrInvalidSSHArgs struct { + Msg string +} + +func (self *ErrInvalidSSHArgs) Error() string { + return self.Msg +} + +func PassthroughArgs() map[string]bool { + return map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true} +} + +func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, server_args []string, passthrough bool, found_extra_args []string, err error) { + if extra_args == nil { + extra_args = []string{} + } + if len(args) == 0 { + passthrough = true + return + } + passthrough_args := PassthroughArgs() + boolean_ssh_args, other_ssh_args := GetSSHCLI() + ssh_args, server_args, found_extra_args = make([]string, 0, 16), make([]string, 0, 16), make([]string, 0, 16) + expecting_option_val := false + stop_option_processing := false + expecting_extra_val := "" + for _, argument := range args { + if len(server_args) > 1 || stop_option_processing { + server_args = append(server_args, argument) + continue + } + if strings.HasPrefix(argument, "-") && !expecting_option_val { + if argument == "--" { + stop_option_processing = true + continue + } + if len(extra_args) > 0 { + matching_ex := is_extra_arg(argument, extra_args) + if matching_ex != "" { + _, exval, found := strings.Cut(argument, "=") + if found { + found_extra_args = append(found_extra_args, matching_ex, exval) + } else { + expecting_extra_val = matching_ex + expecting_option_val = true + } + continue + } + } + // could be a multi-character option + all_args := []rune(argument[1:]) + for i, ch := range all_args { + arg := "-" + string(ch) + if passthrough_args[arg] { + passthrough = true + } + if boolean_ssh_args.Has(arg) { + ssh_args = append(ssh_args, arg) + continue + } + if other_ssh_args.Has(arg) { + ssh_args = append(ssh_args, arg) + if i+1 < len(all_args) { + ssh_args = append(ssh_args, string(all_args[i+1:])) + } else { + expecting_option_val = true + } + break + } + err = &ErrInvalidSSHArgs{Msg: "unknown option -- " + arg[1:]} + return + } + continue + } + if expecting_option_val { + if expecting_extra_val != "" { + found_extra_args = append(found_extra_args, expecting_extra_val, argument) + } else { + ssh_args = append(ssh_args, argument) + } + expecting_option_val = false + continue + } + server_args = append(server_args, argument) + } + if len(server_args) == 0 && !passthrough { + err = &ErrInvalidSSHArgs{Msg: ""} + } + return +} + +type SSHVersion struct{ Major, Minor int } + +func (self SSHVersion) SupportsAskpassRequire() bool { + return self.Major > 8 || (self.Major == 8 && self.Minor >= 4) +} + +var GetSSHVersion = (&utils.Once[SSHVersion]{Run: func() SSHVersion { + b, err := exec.Command(SSHExe(), "-V").CombinedOutput() + if err != nil { + return SSHVersion{} + } + m := regexp.MustCompile(`OpenSSH_(\d+).(\d+)`).FindSubmatch(b) + if len(m) == 3 { + maj, _ := strconv.Atoi(utils.UnsafeBytesToString(m[1])) + min, _ := strconv.Atoi(utils.UnsafeBytesToString(m[2])) + return SSHVersion{Major: maj, Minor: min} + } + return SSHVersion{} +}}).Get + +type KittyOpts struct { + Term, Shell_integration string +} + +func read_relevant_kitty_opts(path string) KittyOpts { + ans := KittyOpts{Term: kitty.KittyConfigDefaults.Term, Shell_integration: kitty.KittyConfigDefaults.Shell_integration} + handle_line := func(key, val string) error { + switch key { + case "term": + ans.Term = strings.TrimSpace(val) + case "shell_integration": + ans.Shell_integration = strings.TrimSpace(val) + } + return nil + } + cp := config.ConfigParser{LineHandler: handle_line} + cp.ParseFiles(path) + return ans +} + +var RelevantKittyOpts = (&utils.Once[KittyOpts]{Run: func() KittyOpts { + return read_relevant_kitty_opts(filepath.Join(utils.ConfigDir(), "kitty.conf")) +}}).Get diff --git a/tools/cmd/ssh/utils_test.go b/tools/cmd/ssh/utils_test.go new file mode 100644 index 000000000..2d9990ca3 --- /dev/null +++ b/tools/cmd/ssh/utils_test.go @@ -0,0 +1,68 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package ssh + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "kitty/tools/utils/shlex" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestGetSSHOptions(t *testing.T) { + m := SSHOptions() + if m["w"] != "local_tun[:remote_tun]" { + t.Fatalf("Unexpected set of SSH options: %#v", m) + } +} + +func TestParseSSHArgs(t *testing.T) { + split := func(x string) []string { + ans, err := shlex.Split(x) + if err != nil { + t.Fatal(err) + } + return ans + } + + p := func(args, expected_ssh_args, expected_server_args, expected_extra_args string, expected_passthrough bool) { + ssh_args, server_args, passthrough, extra_args, err := ParseSSHArgs(split(args), "--kitten") + if err != nil { + t.Fatal(err) + } + check := func(a, b any) { + diff := cmp.Diff(a, b) + if diff != "" { + t.Fatalf("Unexpected value for args: %s\n%s", args, diff) + } + } + check(split(expected_ssh_args), ssh_args) + check(split(expected_server_args), server_args) + check(split(expected_extra_args), extra_args) + check(expected_passthrough, passthrough) + } + p(`localhost`, ``, `localhost`, ``, false) + p(`-- localhost`, ``, `localhost`, ``, false) + p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false) + p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false) + p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true) +} + +func TestRelevantKittyOpts(t *testing.T) { + tdir := t.TempDir() + path := filepath.Join(tdir, "kitty.conf") + os.WriteFile(path, []byte("term XXX\nshell_integration changed\nterm abcd"), 0o600) + rko := read_relevant_kitty_opts(path) + if rko.Term != "abcd" { + t.Fatalf("Unexpected TERM: %s", RelevantKittyOpts().Term) + } + if rko.Shell_integration != "changed" { + t.Fatalf("Unexpected shell_integration: %s", RelevantKittyOpts().Shell_integration) + } +} diff --git a/tools/cmd/tool/main.go b/tools/cmd/tool/main.go index 091a8ba87..a83ff5b56 100644 --- a/tools/cmd/tool/main.go +++ b/tools/cmd/tool/main.go @@ -10,6 +10,8 @@ import ( "kitty/tools/cmd/clipboard" "kitty/tools/cmd/edit_in_kitty" "kitty/tools/cmd/icat" + "kitty/tools/cmd/pytest" + "kitty/tools/cmd/ssh" "kitty/tools/cmd/unicode_input" "kitty/tools/cmd/update_self" "kitty/tools/tui" @@ -30,8 +32,12 @@ func KittyToolEntryPoints(root *cli.Command) { clipboard.EntryPoint(root) // icat icat.EntryPoint(root) + // ssh + ssh.EntryPoint(root) // unicode_input unicode_input.EntryPoint(root) + // __pytest__ + pytest.EntryPoint(root) // __hold_till_enter__ root.AddSubCommand(&cli.Command{ Name: "__hold_till_enter__", diff --git a/tools/config/api.go b/tools/config/api.go new file mode 100644 index 000000000..d9bf80779 --- /dev/null +++ b/tools/config/api.go @@ -0,0 +1,192 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package config + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/fs" + "kitty/tools/utils" + "os" + "path/filepath" + "strings" +) + +var _ = fmt.Print + +func StringToBool(x string) bool { + x = strings.ToLower(x) + return x == "y" || x == "yes" || x == "true" +} + +type ConfigLine struct { + Src_file, Line string + Line_number int + Err error +} + +type ConfigParser struct { + LineHandler func(key, val string) error + CommentsHandler func(line string) error + SourceHandler func(text, path string) + + bad_lines []ConfigLine + seen_includes map[string]bool + override_env []string +} + +type Scanner interface { + Scan() bool + Text() string + Err() error +} + +func (self *ConfigParser) BadLines() []ConfigLine { + return self.bad_lines +} + +func (self *ConfigParser) parse(scanner Scanner, name, base_path_for_includes string, depth int) error { + if self.seen_includes[name] { // avoid include loops + return nil + } + self.seen_includes[name] = true + + recurse := func(r io.Reader, nname, base_path_for_includes string) error { + if depth > 32 { + return fmt.Errorf("Too many nested include directives while processing config file: %s", name) + } + escanner := bufio.NewScanner(r) + return self.parse(escanner, nname, base_path_for_includes, depth+1) + } + + lnum := 0 + make_absolute := func(path string) (string, error) { + if path == "" { + return "", fmt.Errorf("Empty include paths not allowed") + } + if !filepath.IsAbs(path) { + path = filepath.Join(base_path_for_includes, path) + } + return path, nil + } + for scanner.Scan() { + line := strings.TrimLeft(scanner.Text(), " ") + lnum++ + if line == "" { + continue + } + if line[0] == '#' { + if self.CommentsHandler != nil { + err := self.CommentsHandler(line) + if err != nil { + self.bad_lines = append(self.bad_lines, ConfigLine{Src_file: name, Line: line, Line_number: lnum, Err: err}) + } + } + continue + } + key, val, _ := strings.Cut(line, " ") + switch key { + default: + err := self.LineHandler(key, val) + if err != nil { + self.bad_lines = append(self.bad_lines, ConfigLine{Src_file: name, Line: line, Line_number: lnum, Err: err}) + } + case "include", "globinclude", "envinclude": + var includes []string + switch key { + case "include": + aval, err := make_absolute(val) + if err == nil { + includes = []string{aval} + } + case "globinclude": + aval, err := make_absolute(val) + if err == nil { + matches, err := filepath.Glob(aval) + if err == nil { + includes = matches + } + } + case "envinclude": + env := self.override_env + if env == nil { + env = os.Environ() + } + for _, x := range env { + key, eval, _ := strings.Cut(x, "=") + is_match, err := filepath.Match(val, key) + if is_match && err == nil { + err := recurse(strings.NewReader(eval), "", base_path_for_includes) + if err != nil { + return err + } + } + } + } + if len(includes) > 0 { + for _, incpath := range includes { + raw, err := os.ReadFile(incpath) + if err == nil { + err := recurse(bytes.NewReader(raw), incpath, filepath.Dir(incpath)) + if err != nil { + return err + } + } else if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("Failed to process include %#v with error: %w", incpath, err) + } + } + } + } + } + return nil +} + +func (self *ConfigParser) ParseFiles(paths ...string) error { + for _, path := range paths { + apath, err := filepath.Abs(path) + if err == nil { + path = apath + } + raw, err := os.ReadFile(path) + if err != nil { + return err + } + scanner := bufio.NewScanner(bytes.NewReader(raw)) + self.seen_includes = make(map[string]bool) + err = self.parse(scanner, path, filepath.Dir(path), 0) + if err != nil { + return err + } + if self.SourceHandler != nil { + self.SourceHandler(utils.UnsafeBytesToString(raw), path) + } + } + return nil +} + +type LinesScanner struct { + lines []string +} + +func (self *LinesScanner) Scan() bool { + return len(self.lines) > 0 +} + +func (self *LinesScanner) Text() string { + ans := self.lines[0] + self.lines = self.lines[1:] + return ans +} + +func (self *LinesScanner) Err() error { + return nil +} + +func (self *ConfigParser) ParseOverrides(overrides ...string) error { + s := LinesScanner{lines: overrides} + self.seen_includes = make(map[string]bool) + return self.parse(&s, "", utils.ConfigDir(), 0) +} diff --git a/tools/config/api_test.go b/tools/config/api_test.go new file mode 100644 index 000000000..223a57f3d --- /dev/null +++ b/tools/config/api_test.go @@ -0,0 +1,62 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestConfigParsing(t *testing.T) { + tdir := t.TempDir() + conf_file := filepath.Join(tdir, "a.conf") + os.Mkdir(filepath.Join(tdir, "sub"), 0o700) + os.WriteFile(conf_file, []byte( + `error main +# ignore me +a one +#: other +include sub/b.conf +b +include non-existent +globinclude sub/c?.conf +`), 0o600) + os.WriteFile(filepath.Join(tdir, "sub/b.conf"), []byte("incb cool\ninclude a.conf"), 0o600) + os.WriteFile(filepath.Join(tdir, "sub/c1.conf"), []byte("inc1 cool"), 0o600) + os.WriteFile(filepath.Join(tdir, "sub/c2.conf"), []byte("inc2 cool\nenvinclude ENVINCLUDE"), 0o600) + os.WriteFile(filepath.Join(tdir, "sub/c.conf"), []byte("inc notcool\nerror sub"), 0o600) + + var parsed_lines []string + pl := func(key, val string) error { + if key == "error" { + return fmt.Errorf("%s", val) + } + parsed_lines = append(parsed_lines, key+" "+val) + return nil + } + + p := ConfigParser{LineHandler: pl, override_env: []string{"ENVINCLUDE=env cool\ninclude c.conf"}} + err := p.ParseFiles(conf_file) + if err != nil { + t.Fatal(err) + } + err = p.ParseOverrides("over one", "over two") + diff := cmp.Diff([]string{"a one", "incb cool", "b ", "inc1 cool", "inc2 cool", "env cool", "inc notcool", "over one", "over two"}, parsed_lines) + if diff != "" { + t.Fatalf("Unexpected parsed config values:\n%s", diff) + } + bad_lines := []string{} + for _, bl := range p.BadLines() { + bad_lines = append(bad_lines, fmt.Sprintf("%s: %d", filepath.Base(bl.Src_file), bl.Line_number)) + } + diff = cmp.Diff([]string{"a.conf: 1", "c.conf: 2"}, bad_lines) + if diff != "" { + t.Fatalf("Unexpected bad lines:\n%s", diff) + } +} diff --git a/tools/themes/collection.go b/tools/themes/collection.go new file mode 100644 index 000000000..dd1f45d7e --- /dev/null +++ b/tools/themes/collection.go @@ -0,0 +1,430 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package themes + +import ( + "archive/zip" + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "kitty/tools/config" + "kitty/tools/utils" + "kitty/tools/utils/style" + "net/http" + "os" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "golang.org/x/exp/maps" +) + +var _ = fmt.Print + +type JSONMetadata struct { + Etag string `json:"etag"` + Timestamp string `json:"timestamp"` +} + +var ErrNoCacheFound = errors.New("No cache found and max cache age is negative") + +func fetch_cached(name, url, cache_path string, max_cache_age time.Duration) (string, error) { + cache_path = filepath.Join(cache_path, name+".zip") + zf, err := zip.OpenReader(cache_path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return "", err + } + + var jm JSONMetadata + if err == nil { + err = json.Unmarshal(utils.UnsafeStringToBytes(zf.Comment), &jm) + if max_cache_age < 0 { + return cache_path, nil + } + cache_age, err := utils.ISO8601Parse(jm.Timestamp) + if err == nil { + if time.Now().Before(cache_age.Add(max_cache_age)) { + return cache_path, nil + } + } + } + if max_cache_age < 0 { + return "", ErrNoCacheFound + } + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "", err + } + if jm.Etag != "" { + req.Header.Add("If-None-Match", jm.Etag) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("Failed to download %s with error: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusNotModified { + return cache_path, nil + } + return "", fmt.Errorf("Failed to download %s with HTTP error: %s", url, resp.Status) + } + var tf, tf2 *os.File + tf, err = os.CreateTemp(filepath.Dir(cache_path), name+".temp-*") + if err == nil { + tf2, err = os.CreateTemp(filepath.Dir(cache_path), name+".temp-*") + } + defer func() { + if tf != nil { + tf.Close() + os.Remove(tf.Name()) + tf = nil + } + if tf2 != nil { + tf2.Close() + os.Remove(tf2.Name()) + tf2 = nil + } + }() + if err != nil { + return "", fmt.Errorf("Failed to create temp file in %s with error: %w", filepath.Dir(cache_path), err) + } + _, err = io.Copy(tf, resp.Body) + if err != nil { + return "", fmt.Errorf("Failed to download %s with error: %w", url, err) + } + r, err := zip.OpenReader(tf.Name()) + if err != nil { + return "", fmt.Errorf("Failed to open downloaded zip file with error: %w", err) + } + w := zip.NewWriter(tf2) + jm.Etag = resp.Header.Get("ETag") + jm.Timestamp = utils.ISO8601Format(time.Now()) + comment, _ := json.Marshal(jm) + w.SetComment(utils.UnsafeBytesToString(comment)) + for _, file := range r.File { + err = w.Copy(file) + if err != nil { + return "", fmt.Errorf("Failed to copy zip file from source to destination archive") + } + } + err = w.Close() + if err != nil { + return "", err + } + tf2.Close() + err = os.Rename(tf2.Name(), cache_path) + if err != nil { + return "", fmt.Errorf("Failed to atomic rename temp file to %s with error: %w", cache_path, err) + } + tf2 = nil + return cache_path, nil +} + +func FetchCached(max_cache_age time.Duration) (string, error) { + return fetch_cached("kitty-themes", "https://codeload.github.com/kovidgoyal/kitty-themes/zip/master", utils.CacheDir(), max_cache_age) +} + +type ThemeMetadata struct { + Name string `json:"name"` + Filepath string `json:"file"` + Is_dark bool `json:"is_dark"` + Num_settings int `json:"num_settings"` + Blurb string `json:"blurb"` + License string `json:"license"` + Upstream string `json:"upstream"` + Author string `json:"author"` +} + +func parse_theme_metadata(path string) (*ThemeMetadata, map[string]string, error) { + var in_metadata, in_blurb, finished_metadata bool + ans := ThemeMetadata{} + settings := map[string]string{} + read_is_dark := func(key, val string) (err error) { + settings[key] = val + if key == "background" { + val = strings.TrimSpace(val) + if val != "" { + bg, err := style.ParseColor(val) + if err == nil { + ans.Is_dark = utils.Max(bg.Red, bg.Green, bg.Green) < 115 + } + } + } + return + } + read_metadata := func(line string) (err error) { + is_block := strings.HasPrefix(line, "## ") + if in_metadata && !is_block { + finished_metadata = true + } + if finished_metadata { + return + } + if !in_metadata && is_block { + in_metadata = true + } + if !in_metadata { + return + } + line = line[3:] + if in_blurb { + ans.Blurb += " " + line + return + } + key, val, found := strings.Cut(line, ":") + if !found { + return + } + key = strings.TrimSpace(strings.ToLower(key)) + val = strings.TrimSpace(val) + switch key { + case "name": + ans.Name = val + case "author": + ans.Author = val + case "upstream": + ans.Upstream = val + case "blurb": + ans.Blurb = val + in_blurb = true + case "license": + ans.License = val + } + return + } + cp := config.ConfigParser{LineHandler: read_is_dark, CommentsHandler: read_metadata} + err := cp.ParseFiles(path) + if err != nil { + return nil, nil, err + } + ans.Num_settings = len(settings) + return &ans, settings, nil +} + +type Theme struct { + metadata *ThemeMetadata + + code string + settings map[string]string + zip_reader *zip.File + is_user_defined bool +} + +func (self *Theme) load_code() (string, error) { + if self.zip_reader != nil { + f, err := self.zip_reader.Open() + self.zip_reader = nil + if err != nil { + return "", err + } + defer f.Close() + data, err := io.ReadAll(f) + if err != nil { + return "", err + } + self.code = utils.UnsafeBytesToString(data) + } + return self.code, nil +} + +func (self *Theme) Settings() (map[string]string, error) { + if self.zip_reader != nil { + code, err := self.load_code() + if err != nil { + return nil, err + } + self.settings = make(map[string]string, 64) + scanner := bufio.NewScanner(strings.NewReader(code)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line != "" && line[0] != '#' { + key, val, found := strings.Cut(line, " ") + if found { + self.settings[key] = val + } + } + } + } + return self.settings, nil +} + +func (self *Theme) AsEscapeCodes() (string, error) { + settings, err := self.Settings() + if err != nil { + return "", err + } + w := strings.Builder{} + w.Grow(4096) + + set_color := func(i int, sharp string) { + w.WriteByte(';') + w.WriteString(strconv.Itoa(i)) + w.WriteByte(';') + w.WriteString(sharp) + } + + set_default_color := func(name, defval string, num int) { + w.WriteString("\033]") + defer func() { w.WriteString("\033\\") }() + val, found := settings[name] + if !found { + val = defval + } + if val != "" { + rgba, err := style.ParseColor(val) + if err == nil { + w.WriteString(strconv.Itoa(num)) + w.WriteByte(';') + w.WriteString(rgba.AsRGBSharp()) + return + } + } + w.WriteByte('1') + w.WriteString(strconv.Itoa(num)) + } + set_default_color("foreground", style.DefaultColors.Foreground, 10) + set_default_color("background", style.DefaultColors.Background, 11) + set_default_color("cursor", style.DefaultColors.Cursor, 12) + set_default_color("selection_background", style.DefaultColors.SelectionBg, 17) + set_default_color("selection_foreground", style.DefaultColors.SelectionFg, 19) + + w.WriteString("\033]4") + for i := 0; i < 256; i++ { + key := "color" + strconv.Itoa(i) + val := settings[key] + if val != "" { + rgba, err := style.ParseColor(val) + if err == nil { + set_color(i, rgba.AsRGBSharp()) + continue + } + } + rgba := style.RGBA{} + rgba.FromRGB(style.ColorTable[i]) + set_color(i, rgba.AsRGBSharp()) + } + w.WriteString("\033\\") + return w.String(), nil +} + +type Themes struct { + name_map map[string]*Theme + index_map []string +} + +var camel_case_pat = (&utils.Once[*regexp.Regexp]{Run: func() *regexp.Regexp { + return regexp.MustCompile(`([a-z])([A-Z])`) +}}).Get + +func theme_name_from_file_name(fname string) string { + fname = fname[:len(fname)-len(path.Ext(fname))] + fname = strings.ReplaceAll(fname, "_", " ") + fname = camel_case_pat().ReplaceAllString(fname, "$1 $2") + return strings.Join(utils.Map(strings.Split(fname, " "), strings.Title), " ") +} + +func (self *Themes) AddFromFile(path string) (*Theme, error) { + m, conf, err := parse_theme_metadata(path) + if err != nil { + return nil, err + } + if m.Name == "" { + m.Name = theme_name_from_file_name(filepath.Base(path)) + } + t := Theme{metadata: m, is_user_defined: true, settings: conf} + self.name_map[m.Name] = &t + return &t, nil + +} + +func (self *Themes) add_from_dir(dirpath string) error { + entries, err := os.ReadDir(dirpath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + err = nil + } + return err + } + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".conf") { + if _, err = self.AddFromFile(filepath.Join(dirpath, e.Name())); err != nil { + return err + } + } + } + return nil +} + +func (self *Themes) add_from_zip_file(zippath string) (io.Closer, error) { + r, err := zip.OpenReader(zippath) + if err != nil { + return nil, err + } + name_map := make(map[string]*zip.File, len(r.File)) + var themes []ThemeMetadata + theme_dir := "" + for _, file := range r.File { + name_map[file.Name] = file + if path.Base(file.Name) == "themes.json" { + theme_dir = path.Dir(file.Name) + fr, err := file.Open() + if err != nil { + return nil, fmt.Errorf("Error while opening %s from the ZIP file: %w", file.Name, err) + } + defer fr.Close() + raw, err := io.ReadAll(fr) + if err != nil { + return nil, fmt.Errorf("Error while reading %s from the ZIP file: %w", file.Name, err) + } + err = json.Unmarshal(raw, &themes) + if err != nil { + return nil, fmt.Errorf("Error while decoding %s: %w", file.Name, err) + } + } + } + if theme_dir == "" { + return nil, fmt.Errorf("No themes.json found in ZIP file") + } + for _, theme := range themes { + key := path.Join(theme_dir, theme.Filepath) + f := name_map[key] + if f != nil { + t := Theme{metadata: &theme, zip_reader: f} + self.name_map[theme.Name] = &t + } + } + return r, nil +} + +func (self *Themes) ThemeByName(name string) *Theme { + return self.name_map[name] +} + +func LoadThemes(cache_age time.Duration) (ans *Themes, closer io.Closer, err error) { + zip_path, err := FetchCached(cache_age) + ans = &Themes{name_map: make(map[string]*Theme)} + if err != nil { + return nil, nil, err + } + if closer, err = ans.add_from_zip_file(zip_path); err != nil { + return nil, nil, err + } + if err = ans.add_from_dir(filepath.Join(utils.ConfigDir(), "themes")); err != nil { + return nil, nil, err + } + ans.index_map = maps.Keys(ans.name_map) + ans.index_map = utils.StableSortWithKey(ans.index_map, strings.ToLower) + return ans, closer, nil +} + +func ThemeFromFile(path string) (*Theme, error) { + ans := &Themes{name_map: make(map[string]*Theme)} + return ans.AddFromFile(path) +} diff --git a/tools/themes/collection_test.go b/tools/themes/collection_test.go new file mode 100644 index 000000000..c8b21c30b --- /dev/null +++ b/tools/themes/collection_test.go @@ -0,0 +1,157 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package themes + +import ( + "archive/zip" + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestThemeCollections(t *testing.T) { + for fname, expected := range map[string]string{ + "moose": "Moose", + "mooseCat": "Moose Cat", + "a_bC": "A B C", + } { + actual := theme_name_from_file_name(fname) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Fatalf("Unexpected theme name for %s:\n%s", fname, diff) + } + } + + tdir := t.TempDir() + + pt := func(expected ThemeMetadata, lines ...string) { + os.WriteFile(filepath.Join(tdir, "temp.conf"), []byte(strings.Join(lines, "\n")), 0o600) + actual, _, err := parse_theme_metadata(filepath.Join(tdir, "temp.conf")) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(&expected, actual); diff != "" { + t.Fatalf("Failed to parse:\n%s\n\n%s", strings.Join(lines, "\n"), diff) + } + } + pt(ThemeMetadata{Name: "XYZ", Blurb: "a b", Author: "A", Is_dark: true, Num_settings: 2}, + "# some crap", " ", "## ", "## author: A", "## name: XYZ", "## blurb: a", "## b", "", "color red", "background black", "include inc.conf") + os.WriteFile(filepath.Join(tdir, "inc.conf"), []byte("background white"), 0o600) + pt(ThemeMetadata{Name: "XYZ", Blurb: "a b", Author: "A", Num_settings: 2}, + "# some crap", " ", "## ", "## author: A", "## name: XYZ", "## blurb: a", "## b", "", "color red", "background black", "include inc.conf") + + buf := bytes.Buffer{} + zw := zip.NewWriter(&buf) + fw, _ := zw.Create("x/themes.json") + fw.Write([]byte(`[ + { + "author": "X Y", + "blurb": "A dark color scheme for the kitty terminal.", + "file": "themes/Alabaster_Dark.conf", + "is_dark": true, + "license": "MIT", + "name": "Alabaster Dark", + "num_settings": 30, + "upstream": "https://xxx.com" + }, + { + "name": "Empty", "file": "empty.conf" + } + ]`)) + fw, _ = zw.Create("x/empty.conf") + fw.Write([]byte("empty")) + fw, _ = zw.Create("x/themes/Alabaster_Dark.conf") + fw.Write([]byte("alabaster")) + zw.Close() + + received_etag := "" + request_count := 0 + check_etag := true + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + request_count++ + received_etag = r.Header.Get("If-None-Match") + if check_etag && received_etag == `"xxx"` { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Add("ETag", `"xxx"`) + w.Write(buf.Bytes()) + })) + defer ts.Close() + + _, err := fetch_cached("test", ts.URL, tdir, 0) + if err != nil { + t.Fatal(err) + } + r, err := zip.OpenReader(filepath.Join(tdir, "test.zip")) + if err != nil { + t.Fatal(err) + } + var jm JSONMetadata + err = json.Unmarshal([]byte(r.Comment), &jm) + if err != nil { + t.Fatal(err) + } + if jm.Etag != `"xxx"` { + t.Fatalf("Unexpected ETag: %#v", jm.Etag) + } + _, err = fetch_cached("test", ts.URL, tdir, time.Hour) + if err != nil { + t.Fatal(err) + } + if request_count != 1 { + t.Fatal("Cached zip file was not used") + } + before, _ := os.Stat(filepath.Join(tdir, "test.zip")) + _, err = fetch_cached("test", ts.URL, tdir, 0) + if err != nil { + t.Fatal(err) + } + if request_count != 2 { + t.Fatal("Cached zip file was incorrectly used") + } + if received_etag != `"xxx"` { + t.Fatalf("Got invalid ETag: %#v", received_etag) + } + after, _ := os.Stat(filepath.Join(tdir, "test.zip")) + if before.ModTime() != after.ModTime() { + t.Fatal("Cached zip file was incorrectly re-downloaded") + } + check_etag = false + _, err = fetch_cached("test", ts.URL, tdir, 0) + if err != nil { + t.Fatal(err) + } + after2, _ := os.Stat(filepath.Join(tdir, "test.zip")) + if after2.ModTime() != after.ModTime() { + t.Fatal("Cached zip file was incorrectly not re-downloaded") + } + coll := Themes{name_map: map[string]*Theme{}} + closer, err := coll.add_from_zip_file(filepath.Join(tdir, "test.zip")) + if err != nil { + t.Fatal(err) + } + defer closer.Close() + if code, err := coll.ThemeByName("Empty").load_code(); code != "empty" { + if err != nil { + t.Fatal(err) + } + t.Fatal("failed to load code for empty theme") + } + if code, err := coll.ThemeByName("Alabaster Dark").load_code(); code != "alabaster" { + if err != nil { + t.Fatal(err) + } + t.Fatal("failed to load code for alabaster theme") + } +} diff --git a/tools/tty/tty.go b/tools/tty/tty.go index 88391f361..7904a3b7a 100644 --- a/tools/tty/tty.go +++ b/tools/tty/tty.go @@ -148,6 +148,13 @@ func (self *Term) Close() error { return err } +func (self *Term) WasEchoOnOriginally() bool { + if len(self.states) > 0 { + return self.states[0].Lflag&unix.ECHO != 0 + } + return false +} + func (self *Term) Tcgetattr(ans *unix.Termios) error { return eintr_retry_noret(func() error { return Tcgetattr(self.Fd(), ans) }) } @@ -256,13 +263,19 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error) } num_ready, err := pselect() if err != nil { - return + return 0, err } if num_ready == 0 { err = os.ErrDeadlineExceeded - return + return 0, err + } + for { + n, err = self.Read(b) + if errors.Is(err, unix.EINTR) { + continue + } + return n, err } - return self.Read(b) } func (self *Term) Read(b []byte) (int, error) { @@ -273,10 +286,14 @@ func (self *Term) Write(b []byte) (int, error) { return self.os_file.Write(b) } +func is_temporary_error(err error) bool { + return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite) +} + func (self *Term) WriteAll(b []byte) error { for len(b) > 0 { n, err := self.os_file.Write(b) - if err != nil && !errors.Is(err, io.ErrShortWrite) { + if err != nil && !is_temporary_error(err) { return err } b = b[n:] @@ -284,6 +301,10 @@ func (self *Term) WriteAll(b []byte) error { return nil } +func (self *Term) WriteAllString(s string) error { + return self.WriteAll(utils.UnsafeStringToBytes(s)) +} + func (self *Term) WriteString(b string) (int, error) { return self.os_file.WriteString(b) } diff --git a/tools/tui/dcs_to_kitty.go b/tools/tui/dcs_to_kitty.go new file mode 100644 index 000000000..4621afce3 --- /dev/null +++ b/tools/tui/dcs_to_kitty.go @@ -0,0 +1,28 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package tui + +import ( + "encoding/base64" + "fmt" + + "kitty/tools/utils" +) + +var _ = fmt.Print + +func DCSToKitty(msgtype, payload string) (string, error) { + data := base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(payload)) + ans := "\x1bP@kitty-" + msgtype + "|" + data + tmux := TmuxSocketAddress() + if tmux != "" { + err := TmuxAllowPassthrough() + if err != nil { + return "", err + } + ans = "\033Ptmux;\033" + ans + "\033\033\\\033\\" + } else { + ans += "\033\\" + } + return ans, nil +} diff --git a/tools/tui/tmux.go b/tools/tui/tmux.go new file mode 100644 index 000000000..95b56dcda --- /dev/null +++ b/tools/tui/tmux.go @@ -0,0 +1,56 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package tui + +import ( + "fmt" + "kitty/tools/utils" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + + "github.com/shirou/gopsutil/v3/process" + "golang.org/x/sys/unix" +) + +var _ = fmt.Print + +func tmux_socket_address() (socket string) { + socket = os.Getenv("TMUX") + if socket == "" { + return "" + } + addr, pid_str, found := strings.Cut(socket, ",") + if !found { + return "" + } + if unix.Access(addr, unix.R_OK|unix.W_OK) != nil { + return "" + } + pid, err := strconv.ParseInt(pid_str, 10, 32) + if err != nil { + return "" + } + p, err := process.NewProcess(int32(pid)) + if err != nil { + return "" + } + cmd, err := p.CmdlineSlice() + if err != nil { + return "" + } + if len(cmd) > 0 && strings.ToLower(filepath.Base(cmd[0])) != "tmux" { + return "" + } + return socket +} + +var TmuxSocketAddress = (&utils.Once[string]{Run: tmux_socket_address}).Get + +func tmux_allow_passthrough() error { + return exec.Command("tmux", "set", "-p", "allow-passthrough", "on").Run() +} + +var TmuxAllowPassthrough = (&utils.Once[error]{Run: tmux_allow_passthrough}).Get diff --git a/tools/unicode_names/query.go b/tools/unicode_names/query.go index 6fe9d3981..6e8e20a90 100644 --- a/tools/unicode_names/query.go +++ b/tools/unicode_names/query.go @@ -4,11 +4,9 @@ package unicode_names import ( "bytes" - "compress/zlib" _ "embed" "encoding/binary" "fmt" - "io" "strings" "sync" "time" @@ -64,33 +62,8 @@ func parse_record(record []byte, mark uint16) { var parse_once sync.Once -func read_all(r io.Reader, expected_size int) ([]byte, error) { - b := make([]byte, 0, expected_size) - for { - if len(b) == cap(b) { - // Add more capacity (let append pick how much). - b = append(b, 0)[:len(b)] - } - n, err := r.Read(b[len(b):cap(b)]) - b = b[:len(b)+n] - if err != nil { - if err == io.EOF { - err = nil - } - return b, err - } - } -} - func parse_data() { - compressed := utils.UnsafeStringToBytes(unicode_name_data) - uncompressed_size := binary.LittleEndian.Uint32(compressed) - r, _ := zlib.NewReader(bytes.NewReader(compressed[4:])) - defer r.Close() - raw, err := read_all(r, int(uncompressed_size)) - if err != nil { - panic(err) - } + raw := utils.ReadCompressedEmbeddedData(unicode_name_data) num_of_lines := binary.LittleEndian.Uint32(raw) raw = raw[4:] num_of_words := binary.LittleEndian.Uint32(raw) diff --git a/tools/utils/atomic-write.go b/tools/utils/atomic-write.go index 1b9402f37..41f83896d 100644 --- a/tools/utils/atomic-write.go +++ b/tools/utils/atomic-write.go @@ -12,6 +12,33 @@ import ( var _ = fmt.Print +func AtomicCreateSymlink(oldname, newname string) (err error) { + err = os.Symlink(oldname, newname) + if err == nil { + return nil + } + if !errors.Is(err, fs.ErrExist) { + return err + } + if et, err := os.Readlink(newname); err == nil && et == oldname { + return nil + } + for { + tempname := newname + RandomFilename() + err = os.Symlink(oldname, tempname) + if err == nil { + err = os.Rename(tempname, newname) + if err != nil { + os.Remove(tempname) + } + return err + } + if !errors.Is(err, fs.ErrExist) { + return err + } + } +} + func AtomicWriteFile(path string, data []byte, perm os.FileMode) (err error) { npath, err := filepath.EvalSymlinks(path) if errors.Is(err, fs.ErrNotExist) { diff --git a/tools/utils/embed.go b/tools/utils/embed.go new file mode 100644 index 000000000..c485477ba --- /dev/null +++ b/tools/utils/embed.go @@ -0,0 +1,46 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package utils + +import ( + "bytes" + "compress/bzip2" + "encoding/binary" + "fmt" + "io" +) + +var _ = fmt.Print + +func ReadAll(r io.Reader, expected_size int) ([]byte, error) { + b := make([]byte, 0, expected_size) + for { + if len(b) == cap(b) { + // Add more capacity (let append pick how much). + b = append(b, 0)[:len(b)] + } + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return b, err + } + } +} + +func ReadCompressedEmbeddedData(raw string) []byte { + compressed := UnsafeStringToBytes(raw) + uncompressed_size := binary.LittleEndian.Uint32(compressed) + r := bzip2.NewReader(bytes.NewReader(compressed[4:])) + ans, err := ReadAll(r, int(uncompressed_size)) + if err != nil { + panic(err) + } + return ans +} + +func ReaderForCompressedEmbeddedData(raw string) io.Reader { + return bzip2.NewReader(bytes.NewReader(UnsafeStringToBytes(raw)[4:])) +} diff --git a/tools/utils/iso8601.go b/tools/utils/iso8601.go new file mode 100644 index 000000000..2c4f664af --- /dev/null +++ b/tools/utils/iso8601.go @@ -0,0 +1,166 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package utils + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +var _ = fmt.Print + +func is_digit(x byte) bool { + return '0' <= x && x <= '9' +} + +// The following is copied from the Go standard library to implement date range validation logic +// equivalent to the behaviour of Go's time.Parse. + +func isLeap(year int) bool { + return year%4 == 0 && (year%100 != 0 || year%400 == 0) +} + +// daysInMonth is the number of days for non-leap years in each calendar month starting at 1 +var daysInMonth = [13]int{0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31} + +func daysIn(m time.Month, year int) int { + if m == time.February && isLeap(year) { + return 29 + } + return daysInMonth[int(m)] +} + +func ISO8601Parse(raw string) (time.Time, error) { + orig := raw + raw = strings.TrimSpace(raw) + + required_number := func(num_digits int) (int, error) { + if len(raw) < num_digits { + return 0, fmt.Errorf("Insufficient digits") + } + text := raw[:num_digits] + raw = raw[num_digits:] + ans, err := strconv.ParseUint(text, 10, 32) + return int(ans), err + + } + optional_separator := func(x byte) bool { + if len(raw) > 0 && raw[0] == x { + raw = raw[1:] + } + return len(raw) > 0 && is_digit(raw[0]) + } + + errf := func(msg string) (time.Time, error) { + return time.Time{}, fmt.Errorf("Invalid ISO8601 timestamp: %#v. %s", orig, msg) + } + + optional_separator('+') + year, err := required_number(4) + if err != nil { + return errf("timestamp does not start with a 4 digit year") + } + var month int = 1 + var day int = 1 + if optional_separator('-') { + month, err = required_number(2) + if err != nil { + return errf("timestamp does not have a valid 2 digit month") + } + if optional_separator('-') { + day, err = required_number(2) + if err != nil { + return errf("timestamp does not have a valid 2 digit day") + } + } + } + + var hour, minute, second, nsec int + + if len(raw) > 0 && (raw[0] == 'T' || raw[0] == ' ') { + raw = raw[1:] + hour, err = required_number(2) + if err != nil { + return errf("timestamp does not have a valid 2 digit hour") + } + if optional_separator(':') { + minute, err = required_number(2) + if err != nil { + return errf("timestamp does not have a valid 2 digit minute") + } + if optional_separator(':') { + second, err = required_number(2) + if err != nil { + return errf("timestamp does not have a valid 2 digit second") + } + } + } + if len(raw) > 0 && (raw[0] == '.' || raw[0] == ',') { + raw = raw[1:] + num_digits := 0 + for len(raw) > num_digits && is_digit(raw[num_digits]) { + num_digits++ + } + text := raw[:num_digits] + raw = raw[num_digits:] + extra := 9 - len(text) + if extra < 0 { + text = text[:9] + } + if text != "" { + n, err := strconv.ParseUint(text, 10, 64) + if err != nil { + return errf("timestamp does not have a valid nanosecond field") + } + nsec = int(n) + for ; extra > 0; extra-- { + nsec *= 10 + } + } + } + } + switch { + case month < 1 || month > 12: + return errf("timestamp has invalid month value") + case day < 1 || day > 31 || day > daysIn(time.Month(month), year): + return errf("timestamp has invalid day value") + case hour < 0 || hour > 23: + return errf("timestamp has invalid hour value") + case minute < 0 || minute > 59: + return errf("timestamp has invalid minute value") + case second < 0 || second > 59: + return errf("timestamp has invalid second value") + } + loc := time.UTC + tzsign, tzhour, tzminute := 0, 0, 0 + + if len(raw) > 0 { + switch raw[0] { + case '+': + tzsign = 1 + case '-': + tzsign = -1 + } + } + if tzsign != 0 { + raw = raw[1:] + tzhour, err = required_number(2) + if err != nil { + return errf("timestamp has invalid timezone hour") + } + optional_separator(':') + tzminute, err = required_number(2) + if err != nil { + tzminute = 0 + } + seconds := tzhour*3600 + tzminute*60 + loc = time.FixedZone("", tzsign*seconds) + } + return time.Date(year, time.Month(month), day, hour, minute, second, nsec, loc), err +} + +func ISO8601Format(x time.Time) string { + return x.Format(time.RFC3339Nano) +} diff --git a/tools/utils/iso8601_test.go b/tools/utils/iso8601_test.go new file mode 100644 index 000000000..648dbe6c3 --- /dev/null +++ b/tools/utils/iso8601_test.go @@ -0,0 +1,40 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package utils + +import ( + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +var _ = fmt.Print + +func TestISO8601(t *testing.T) { + now := time.Now() + + tt := func(raw string, expected time.Time) { + actual, err := ISO8601Parse(raw) + if err != nil { + t.Fatalf("Parsing: %#v failed with error: %s", raw, err) + } + if diff := cmp.Diff(expected, actual); diff != "" { + t.Fatalf("Parsing: %#v failed:\n%s", raw, diff) + } + } + + tt(ISO8601Format(now), now) + tt("2023-02-08T07:24:09.551975+00:00", time.Date(2023, 2, 8, 7, 24, 9, 551975000, time.UTC)) + tt("2023-02-08T07:24:09.551975Z", time.Date(2023, 2, 8, 7, 24, 9, 551975000, time.UTC)) + tt("2023", time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)) + tt("2023-11-13", time.Date(2023, 11, 13, 0, 0, 0, 0, time.UTC)) + tt("2023-11-13 07:23", time.Date(2023, 11, 13, 7, 23, 0, 0, time.UTC)) + tt("2023-11-13 07:23:01", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC)) + tt("2023-11-13 07:23:01.", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC)) + tt("2023-11-13 07:23:01.0", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC)) + tt("2023-11-13 07:23:01.1", time.Date(2023, 11, 13, 7, 23, 1, 100000000, time.UTC)) + tt("202311-13 07", time.Date(2023, 11, 13, 7, 0, 0, 0, time.UTC)) + tt("20231113 0705", time.Date(2023, 11, 13, 7, 5, 0, 0, time.UTC)) +} diff --git a/tools/utils/mimetypes.go b/tools/utils/mimetypes.go index 80a38935a..b6d01d2f8 100644 --- a/tools/utils/mimetypes.go +++ b/tools/utils/mimetypes.go @@ -11,12 +11,9 @@ import ( "os" "path/filepath" "strings" - "sync" ) var _ = fmt.Print -var user_mime_only_once sync.Once -var user_defined_mime_map map[string]string func load_mime_file(filename string, mime_map map[string]string) error { f, err := os.Open(filename) @@ -45,18 +42,19 @@ func load_mime_file(filename string, mime_map map[string]string) error { return nil } -func load_user_mime_maps() { +var UserMimeMap = (&Once[map[string]string]{Run: func() map[string]string { conf_path := filepath.Join(ConfigDir(), "mime.types") - err := load_mime_file(conf_path, user_defined_mime_map) + ans := make(map[string]string, 32) + err := load_mime_file(conf_path, ans) if err != nil && !errors.Is(err, fs.ErrNotExist) { fmt.Fprintln(os.Stderr, "Failed to parse", conf_path, "for MIME types with error:", err) } -} + return ans +}}).Get func GuessMimeType(filename string) string { - user_mime_only_once.Do(load_user_mime_maps) ext := filepath.Ext(filename) - mime_with_parameters := user_defined_mime_map[ext] + mime_with_parameters := UserMimeMap()[ext] if mime_with_parameters == "" { mime_with_parameters = mime.TypeByExtension(ext) } diff --git a/tools/utils/misc.go b/tools/utils/misc.go index ca5b8702b..b63250f05 100644 --- a/tools/utils/misc.go +++ b/tools/utils/misc.go @@ -57,6 +57,14 @@ func Filter[T any](s []T, f func(x T) bool) []T { return ans } +func Map[T any](s []T, f func(x T) T) []T { + ans := make([]T, 0, len(s)) + for _, x := range s { + ans = append(ans, f(x)) + } + return ans +} + func Sort[T any](s []T, less func(a, b T) bool) []T { sort.Slice(s, func(i, j int) bool { return less(s[i], s[j]) }) return s diff --git a/tools/utils/once.go b/tools/utils/once.go new file mode 100644 index 000000000..03734448d --- /dev/null +++ b/tools/utils/once.go @@ -0,0 +1,35 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package utils + +import ( + "fmt" + "sync" + "sync/atomic" +) + +var _ = fmt.Print + +type Once[T any] struct { + done uint32 + mutex sync.Mutex + cached_val T + + Run func() T +} + +func (self *Once[T]) Get() T { + if atomic.LoadUint32(&self.done) == 0 { + self.do_slow() + } + return self.cached_val +} + +func (self *Once[T]) do_slow() { + self.mutex.Lock() + defer self.mutex.Unlock() + if atomic.LoadUint32(&self.done) == 0 { + defer atomic.StoreUint32(&self.done, 1) + self.cached_val = self.Run() + } +} diff --git a/tools/utils/once_test.go b/tools/utils/once_test.go new file mode 100644 index 000000000..c748dd16a --- /dev/null +++ b/tools/utils/once_test.go @@ -0,0 +1,24 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package utils + +import ( + "fmt" + "testing" +) + +var _ = fmt.Print + +func TestOnce(t *testing.T) { + num := 0 + var G = (&Once[string]{Run: func() string { + num++ + return fmt.Sprintf("%d", num) + }}).Get + G() + G() + G() + if num != 1 { + t.Fatalf("num unexpectedly: %d", num) + } +} diff --git a/tools/utils/paths.go b/tools/utils/paths.go index 7e65281c2..a41d3e78a 100644 --- a/tools/utils/paths.go +++ b/tools/utils/paths.go @@ -3,13 +3,18 @@ package utils import ( + "crypto/rand" + "encoding/base32" + "fmt" "io/fs" + not_rand "math/rand" "os" + "os/exec" "os/user" "path/filepath" "runtime" + "strconv" "strings" - "sync" "golang.org/x/sys/unix" ) @@ -57,61 +62,57 @@ func Abspath(path string) string { return path } -var config_dir, kitty_exe, cache_dir string -var kitty_exe_err error -var config_dir_once, kitty_exe_once, cache_dir_once sync.Once - -func find_kitty_exe() { +var KittyExe = (&Once[string]{Run: func() string { exe, err := os.Executable() if err == nil { - kitty_exe = filepath.Join(filepath.Dir(exe), "kitty") - kitty_exe_err = unix.Access(kitty_exe, unix.X_OK) - } else { - kitty_exe_err = err + return filepath.Join(filepath.Dir(exe), "kitty") } -} + return "" +}}).Get -func KittyExe() (string, error) { - kitty_exe_once.Do(find_kitty_exe) - return kitty_exe, kitty_exe_err -} - -func find_config_dir() { - if os.Getenv("KITTY_CONFIG_DIRECTORY") != "" { - config_dir = Abspath(Expanduser(os.Getenv("KITTY_CONFIG_DIRECTORY"))) - } else { - var locations []string - if os.Getenv("XDG_CONFIG_HOME") != "" { - locations = append(locations, os.Getenv("XDG_CACHE_HOME")) +var ConfigDir = (&Once[string]{Run: func() (config_dir string) { + if kcd := os.Getenv("KITTY_CONFIG_DIRECTORY"); kcd != "" { + return Abspath(Expanduser(kcd)) + } + var locations []string + seen := NewSet[string]() + add := func(x string) { + x = Abspath(Expanduser(x)) + if !seen.Has(x) { + seen.Add(x) + locations = append(locations, x) } - locations = append(locations, Expanduser("~/.config")) - if runtime.GOOS == "darwin" { - locations = append(locations, Expanduser("~/Library/Preferences")) + } + if xh := os.Getenv("XDG_CONFIG_HOME"); xh != "" { + add(xh) + } + if dirs := os.Getenv("XDG_CONFIG_DIRS"); dirs != "" { + for _, candidate := range strings.Split(dirs, ":") { + add(candidate) } - for _, loc := range locations { - if loc != "" { - q := filepath.Join(loc, "kitty") - if _, err := os.Stat(filepath.Join(q, "kitty.conf")); err == nil { - config_dir = q - break - } - } - } - for _, loc := range locations { - if loc != "" { - config_dir = filepath.Join(loc, "kitty") - break + } + add("~/.config") + if runtime.GOOS == "darwin" { + add("~/Library/Preferences") + } + for _, loc := range locations { + if loc != "" { + q := filepath.Join(loc, "kitty") + if _, err := os.Stat(filepath.Join(q, "kitty.conf")); err == nil { + config_dir = q + return } } } -} + config_dir = os.Getenv("XDG_CONFIG_HOME") + if config_dir == "" { + config_dir = "~/.config" + } + config_dir = filepath.Join(Expanduser(config_dir), "kitty") + return +}}).Get -func ConfigDir() string { - config_dir_once.Do(find_config_dir) - return config_dir -} - -func find_cache_dir() { +var CacheDir = (&Once[string]{Run: func() (cache_dir string) { candidate := "" if edir := os.Getenv("KITTY_CACHE_DIRECTORY"); edir != "" { candidate = Abspath(Expanduser(edir)) @@ -125,13 +126,71 @@ func find_cache_dir() { candidate = filepath.Join(Expanduser(candidate), "kitty") } os.MkdirAll(candidate, 0o755) - cache_dir = candidate + return candidate +}}).Get + +func macos_user_cache_dir() string { + // Sadly Go does not provide confstr() so we use this hack. + // Note that given a user generateduid and uid we can derive this by using + // the algorithm at https://github.com/ydkhatri/MacForensics/blob/master/darwin_path_generator.py + // but I cant find a good way to get the generateduid. Requires calling dscl in which case we might as well call getconf + // The data is in /var/db/dslocal/nodes/Default/users/.plist but it needs root + // So instead we use various hacks to get it quickly, falling back to running /usr/bin/getconf + + is_ok := func(m string) bool { + s, err := os.Stat(m) + if err != nil { + return false + } + stat, ok := s.Sys().(unix.Stat_t) + return ok && s.IsDir() && int(stat.Uid) == os.Geteuid() && s.Mode().Perm() == 0o700 && unix.Access(m, unix.X_OK|unix.W_OK|unix.R_OK) == nil + } + + if tdir := strings.TrimRight(os.Getenv("TMPDIR"), "/"); filepath.Base(tdir) == "T" { + if m := filepath.Join(filepath.Dir(tdir), "C"); is_ok(m) { + return m + } + } + + matches, err := filepath.Glob("/private/var/folders/*/*/C") + if err == nil { + for _, m := range matches { + if is_ok(m) { + return m + } + } + } + out, err := exec.Command("/usr/bin/getconf", "DARWIN_USER_CACHE_DIR").Output() + if err == nil { + return strings.TrimRight(strings.TrimSpace(UnsafeBytesToString(out)), "/") + } + return "" } -func CacheDir() string { - cache_dir_once.Do(find_cache_dir) - return cache_dir -} +var RuntimeDir = (&Once[string]{Run: func() (runtime_dir string) { + var candidate string + if q := os.Getenv("KITTY_RUNTIME_DIRECTORY"); q != "" { + candidate = q + } else if runtime.GOOS == "darwin" { + candidate = macos_user_cache_dir() + } else if q := os.Getenv("XDG_RUNTIME_DIR"); q != "" { + candidate = q + } + candidate = strings.TrimRight(candidate, "/") + if candidate == "" { + q := fmt.Sprintf("/run/user/%d", os.Geteuid()) + if s, err := os.Stat(q); err == nil && s.IsDir() && unix.Access(q, unix.X_OK|unix.R_OK|unix.W_OK) == nil { + candidate = q + } else { + candidate = filepath.Join(CacheDir(), "run") + } + } + os.MkdirAll(candidate, 0o700) + if s, err := os.Stat(candidate); err == nil && s.Mode().Perm() != 0o700 { + os.Chmod(candidate, 0o700) + } + return candidate +}}).Get type Walk_callback func(path, abspath string, d fs.DirEntry, err error) error @@ -205,3 +264,21 @@ func WalkWithSymlink(dirpath string, callback Walk_callback, transformers ...fun seen: make(map[string]bool), real_callback: callback, transform_func: transform, needs_recurse_func: needs_symlink_recurse} return sw.walk(dirpath) } + +func RandomFilename() string { + b := []byte{0, 0, 0, 0, 0, 0, 0, 0} + _, err := rand.Read(b) + if err != nil { + return strconv.FormatUint(uint64(not_rand.Uint32()), 16) + } + return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) + +} + +func ResolveConfPath(path string) string { + cs := os.ExpandEnv(Expanduser(path)) + if !filepath.IsAbs(cs) { + cs = filepath.Join(ConfigDir(), cs) + } + return cs +} diff --git a/tools/utils/paths/well_known.go b/tools/utils/paths/well_known.go new file mode 100644 index 000000000..40490405f --- /dev/null +++ b/tools/utils/paths/well_known.go @@ -0,0 +1,65 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package paths + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "kitty/tools/utils" +) + +var _ = fmt.Print + +type Ctx struct { + home, cwd string +} + +func (ctx *Ctx) SetHome(val string) { + ctx.home = val +} + +func (ctx *Ctx) SetCwd(val string) { + ctx.cwd = val +} + +func (ctx *Ctx) HomePath() (ans string) { + ans = ctx.home + if ans == "" { + ans = utils.Expanduser("~") + } + return +} + +func (ctx *Ctx) CwdPath() (ans string) { + ans = ctx.cwd + if ans == "" { + var err error + ans, err = os.Getwd() + if err != nil { + ans = "." + } + } + return +} + +func abspath(path, base string) (ans string) { + return filepath.Join(base, path) +} + +func (ctx *Ctx) Abspath(path string) (ans string) { + return abspath(path, ctx.CwdPath()) +} + +func (ctx *Ctx) AbspathFromHome(path string) (ans string) { + return abspath(path, ctx.HomePath()) +} + +func (ctx *Ctx) ExpandHome(path string) (ans string) { + if strings.HasPrefix(path, "~/") { + return ctx.AbspathFromHome(path) + } + return path +} diff --git a/tools/utils/secrets/tokens.go b/tools/utils/secrets/tokens.go new file mode 100644 index 000000000..59ba76db7 --- /dev/null +++ b/tools/utils/secrets/tokens.go @@ -0,0 +1,42 @@ +// License: GPLv3 Copyright: 2023, Kovid Goyal, + +package secrets + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" +) + +var _ = fmt.Print + +const DEFAULT_NUM_OF_BYTES_FOR_TOKEN = 32 + +func TokenBytes(nbytes ...int) ([]byte, error) { + if len(nbytes) == 0 { + nbytes = []int{DEFAULT_NUM_OF_BYTES_FOR_TOKEN} + } + buf := make([]byte, nbytes[0]) + _, err := rand.Read(buf) + if err != nil { + return nil, err + } + return buf, nil +} + +func TokenHex(nbytes ...int) (string, error) { + b, err := TokenBytes(nbytes...) + if err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func TokenBase64(nbytes ...int) (string, error) { + b, err := TokenBytes(nbytes...) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(b), nil +} diff --git a/tools/utils/set.go b/tools/utils/set.go index 56b8db0d3..cff09034c 100644 --- a/tools/utils/set.go +++ b/tools/utils/set.go @@ -4,6 +4,8 @@ package utils import ( "fmt" + + "golang.org/x/exp/maps" ) var _ = fmt.Print @@ -22,6 +24,10 @@ func (self *Set[T]) AddItems(val ...T) { } } +func (self *Set[T]) String() string { + return fmt.Sprintf("%#v", maps.Keys(self.items)) +} + func (self *Set[T]) Remove(val T) { delete(self.items, val) } @@ -68,6 +74,25 @@ func (self *Set[T]) Intersect(other *Set[T]) (ans *Set[T]) { return } +func (self *Set[T]) Subtract(other *Set[T]) (ans *Set[T]) { + ans = NewSet[T](self.Len()) + for x := range self.items { + if !other.Has(x) { + ans.items[x] = struct{}{} + } + } + return ans +} + +func (self *Set[T]) IsSubsetOf(other *Set[T]) bool { + for x := range self.items { + if !other.Has(x) { + return false + } + } + return true +} + func NewSet[T comparable](capacity ...int) (ans *Set[T]) { if len(capacity) == 0 { ans = &Set[T]{items: make(map[T]struct{}, 8)} @@ -76,3 +101,9 @@ func NewSet[T comparable](capacity ...int) (ans *Set[T]) { } return } + +func NewSetWithItems[T comparable](items ...T) (ans *Set[T]) { + ans = NewSet[T](len(items)) + ans.AddItems(items...) + return ans +} diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index 79a04a716..5450deb41 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -3,15 +3,16 @@ package shm import ( - "crypto/rand" - "encoding/base32" + "encoding/binary" "errors" "fmt" - not_rand "math/rand" + "io" + "io/fs" "os" - "strconv" "strings" + "kitty/tools/cli" + "golang.org/x/sys/unix" ) @@ -43,15 +44,6 @@ func prefix_and_suffix(pattern string) (prefix, suffix string, err error) { return prefix, suffix, nil } -func next_random() string { - b := make([]byte, 8) - _, err := rand.Read(b) - if err != nil { - return strconv.FormatUint(uint64(not_rand.Uint32()), 16) - } - return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) -} - type MMap interface { Close() error Unlink() error @@ -59,6 +51,13 @@ type MMap interface { Name() string IsFileSystemBacked() bool FileSystemName() string + Stat() (fs.FileInfo, error) + Flush() error + Seek(offset int64, whence int) (int64, error) + Read(b []byte) (int, error) + ReadWithSize() ([]byte, error) + Write(p []byte) (n int, err error) + WriteWithSize([]byte) error } type AccessFlags int @@ -109,3 +108,78 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) { } return } + +func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) { + p := buf + for len(p) > 0 { + n, err := f.Read(p) + p = p[n:] + if err != nil { + if len(p) == 0 && errors.Is(err, io.EOF) { + err = nil + } + return buf[:len(buf)-len(p)], err + } + } + return buf, nil +} + +func read_with_size(f *os.File) ([]byte, error) { + szbuf := []byte{0, 0, 0, 0} + szbuf, err := read_till_buf_full(f, szbuf) + if err != nil { + return nil, err + } + size := int(binary.BigEndian.Uint32(szbuf)) + return read_till_buf_full(f, make([]byte, size)) +} + +func write_with_size(f *os.File, b []byte) error { + szbuf := []byte{0, 0, 0, 0} + binary.BigEndian.PutUint32(szbuf, uint32(len(b))) + _, err := f.Write(szbuf) + if err == nil { + _, err = f.Write(b) + } + return err +} + +func test_integration_with_python(args []string) (rc int, err error) { + switch args[0] { + default: + return 1, fmt.Errorf("Unknown test type: %s", args[0]) + case "read": + data, err := ReadWithSizeAndUnlink(args[1]) + if err != nil { + return 1, err + } + _, err = os.Stdout.Write(data) + if err != nil { + return 1, err + } + case "write": + data, err := io.ReadAll(os.Stdin) + if err != nil { + return 1, err + } + mmap, err := CreateTemp("shmtest-", uint64(len(data)+4)) + if err != nil { + return 1, err + } + mmap.WriteWithSize(data) + mmap.Close() + fmt.Println(mmap.Name()) + } + return 0, nil +} + +func TestEntryPoint(root *cli.Command) { + root.AddSubCommand(&cli.Command{ + Name: "shm", + OnlyArgsAllowed: true, + Run: func(cmd *cli.Command, args []string) (rc int, err error) { + return test_integration_with_python(args) + }, + }) + +} diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index d5866f3b9..15a7d8321 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -8,10 +8,13 @@ import ( "errors" "fmt" "io/fs" - "kitty/tools/utils" "os" "path/filepath" "runtime" + + "kitty/tools/utils" + + "golang.org/x/sys/unix" ) var _ = fmt.Print @@ -39,6 +42,10 @@ func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, speci return &file_based_mmap{f: f, region: region, special_name: special_name}, nil } +func (self *file_based_mmap) Stat() (fs.FileInfo, error) { + return self.f.Stat() +} + func (self *file_based_mmap) Name() string { if self.special_name != "" { return self.special_name @@ -46,6 +53,30 @@ func (self *file_based_mmap) Name() string { return filepath.Base(self.f.Name()) } +func (self *file_based_mmap) Flush() error { + return unix.Msync(self.region, unix.MS_SYNC) +} + +func (self *file_based_mmap) Seek(offset int64, whence int) (int64, error) { + return self.f.Seek(offset, whence) +} + +func (self *file_based_mmap) Read(b []byte) (int, error) { + return self.f.Read(b) +} + +func (self *file_based_mmap) Write(b []byte) (int, error) { + return self.f.Write(b) +} + +func (self *file_based_mmap) WriteWithSize(b []byte) error { + return write_with_size(self.f, b) +} + +func (self *file_based_mmap) ReadWithSize() ([]byte, error) { + return read_with_size(self.f) +} + func (self *file_based_mmap) FileSystemName() string { return self.f.Name() } @@ -92,7 +123,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) { var f *os.File try := 0 for { - name := prefix + next_random() + suffix + name := prefix + utils.RandomFilename() + suffix path := file_path_from_name(name) f, err = os.OpenFile(path, os.O_EXCL|os.O_CREATE|os.O_RDWR, 0600) if err != nil { @@ -113,7 +144,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) { return file_mmap(f, size, WRITE, true, special_name) } -func Open(name string, size uint64) (MMap, error) { +func open(name string) (*os.File, error) { ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0) if err != nil { if errors.Is(err, fs.ErrNotExist) { @@ -123,5 +154,29 @@ func Open(name string, size uint64) (MMap, error) { } return nil, err } + return ans, nil +} + +func Open(name string, size uint64) (MMap, error) { + ans, err := open(name) + if err != nil { + return nil, err + } return file_mmap(ans, size, READ, false, name) } + +func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) { + f, err := open(name) + if err != nil { + return nil, err + } + defer f.Close() + defer os.Remove(f.Name()) + for _, cb := range file_callback { + err = cb(f) + if err != nil { + return nil, err + } + } + return read_with_size(f) +} diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index 8974af3f5..dcabe70aa 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -11,6 +11,8 @@ import ( "strings" "unsafe" + "kitty/tools/utils" + "golang.org/x/sys/unix" ) @@ -86,11 +88,38 @@ func syscall_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (M func (self *syscall_based_mmap) Name() string { return self.f.Name() } +func (self *syscall_based_mmap) Stat() (fs.FileInfo, error) { + return self.f.Stat() +} + +func (self *syscall_based_mmap) Flush() error { + return unix.Msync(self.region, unix.MS_SYNC) +} func (self *syscall_based_mmap) Slice() []byte { return self.region } +func (self *syscall_based_mmap) Seek(offset int64, whence int) (int64, error) { + return self.f.Seek(offset, whence) +} + +func (self *syscall_based_mmap) Read(b []byte) (int, error) { + return self.f.Read(b) +} + +func (self *syscall_based_mmap) Write(b []byte) (int, error) { + return self.f.Write(b) +} + +func (self *syscall_based_mmap) WriteWithSize(b []byte) error { + return write_with_size(self.f, b) +} + +func (self *syscall_based_mmap) ReadWithSize() ([]byte, error) { + return read_with_size(self.f) +} + func (self *syscall_based_mmap) Close() (err error) { if self.region != nil { self.f.Close() @@ -124,7 +153,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) { var f *os.File try := 0 for { - name := prefix + next_random() + suffix + name := prefix + utils.RandomFilename() + suffix if len(name) > SHM_NAME_MAX { return nil, ErrPatternTooLong } @@ -151,3 +180,19 @@ func Open(name string, size uint64) (MMap, error) { } return syscall_mmap(ans, size, READ, false) } + +func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) { + f, err := shm_open(name, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + defer f.Close() + defer shm_unlink(f.Name()) + for _, cb := range file_callback { + err = cb(f) + if err != nil { + return nil, err + } + } + return read_with_size(f) +} diff --git a/tools/utils/shm/shm_test.go b/tools/utils/shm/shm_test.go index 6d6c2817e..086731f8e 100644 --- a/tools/utils/shm/shm_test.go +++ b/tools/utils/shm/shm_test.go @@ -23,6 +23,10 @@ func TestSHM(t *testing.T) { } copy(mm.Slice(), data) + err = mm.Flush() + if err != nil { + t.Fatalf("Failed to msync() with error: %v", err) + } err = mm.Close() if err != nil { t.Fatalf("Failed to close with error: %v", err) diff --git a/tools/utils/style/wrapper.go b/tools/utils/style/wrapper.go index fd94c8a7a..e8cdf969f 100644 --- a/tools/utils/style/wrapper.go +++ b/tools/utils/style/wrapper.go @@ -57,6 +57,10 @@ type RGBA struct { Red, Green, Blue, Inverse_alpha uint8 } +func (self RGBA) AsRGBSharp() string { + return fmt.Sprintf("#%02x%02x%02x", self.Red, self.Green, self.Blue) +} + func (self *RGBA) parse_rgb_strings(r string, g string, b string) bool { var rv, gv, bv uint64 var err error @@ -77,6 +81,12 @@ func (self *RGBA) AsRGB() uint32 { return uint32(self.Blue) | (uint32(self.Green) << 8) | (uint32(self.Red) << 16) } +func (self *RGBA) FromRGB(col uint32) { + self.Red = uint8((col >> 16) & 0xff) + self.Green = uint8((col >> 8) & 0xff) + self.Blue = uint8((col) & 0xff) +} + type color_type struct { is_numbered bool val RGBA diff --git a/tools/utils/which.go b/tools/utils/which.go index 184221803..1a4c194d9 100644 --- a/tools/utils/which.go +++ b/tools/utils/which.go @@ -13,15 +13,18 @@ import ( var _ = fmt.Print -func Which(cmd string) string { +func Which(cmd string, paths ...string) string { if strings.Contains(cmd, string(os.PathSeparator)) { return "" } - path := os.Getenv("PATH") - if path == "" { - return "" + if len(paths) == 0 { + path := os.Getenv("PATH") + if path == "" { + return "" + } + paths = strings.Split(path, string(os.PathListSeparator)) } - for _, dir := range strings.Split(path, string(os.PathListSeparator)) { + for _, dir := range paths { q := filepath.Join(dir, cmd) if unix.Access(q, unix.X_OK) == nil { s, err := os.Stat(q)