Merge branch 'ssh'

This commit is contained in:
Kovid Goyal 2023-02-28 12:45:51 +05:30
commit 9135ba138e
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
69 changed files with 4521 additions and 1567 deletions

View File

@ -62,6 +62,9 @@ Detailed list of changes
- macOS: Fix the maximized window not taking up full space when the title bar is hidden or when :opt:`resize_in_steps` is configured (:iss:`6021`) - macOS: Fix the maximized window not taking up full space when the title bar is hidden or when :opt:`resize_in_steps` is configured (:iss:`6021`)
- ssh kitten: Change the syntax of glob patterns slightly to match common usage
elsewhere. Now the syntax is the same a "extendedglob" in most shells.
0.27.1 [2023-02-07] 0.27.1 [2023-02-07]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -226,15 +226,16 @@ def commit_role(
# CLI docs {{{ # CLI docs {{{
def write_cli_docs(all_kitten_names: Iterable[str]) -> None: def write_cli_docs(all_kitten_names: Iterable[str]) -> None:
from kittens.ssh.copy import option_text from kittens.ssh.main import copy_message, option_text
from kittens.ssh.options.definition import copy_message
from kitty.cli import option_spec_as_rst from kitty.cli import option_spec_as_rst
from kitty.launch import options_spec as launch_options_spec
with open('generated/ssh-copy.rst', 'w') as f: with open('generated/ssh-copy.rst', 'w') as f:
f.write(option_spec_as_rst( f.write(option_spec_as_rst(
appname='copy', ospec=option_text, heading_char='^', appname='copy', ospec=option_text, heading_char='^',
usage='file-or-dir-to-copy ...', message=copy_message usage='file-or-dir-to-copy ...', message=copy_message
)) ))
del sys.modules['kittens.ssh.main']
from kitty.launch import options_spec as launch_options_spec
with open('generated/launch.rst', 'w') as f: with open('generated/launch.rst', 'w') as f:
f.write(option_spec_as_rst( f.write(option_spec_as_rst(
appname='launch', ospec=launch_options_spec, heading_char='_', appname='launch', ospec=launch_options_spec, heading_char='_',
@ -525,9 +526,9 @@ def write_conf_docs(app: Any, all_kitten_names: Iterable[str]) -> None:
from kittens.runner import get_kitten_conf_docs from kittens.runner import get_kitten_conf_docs
for kitten in all_kitten_names: for kitten in all_kitten_names:
definition = get_kitten_conf_docs(kitten) defn = get_kitten_conf_docs(kitten)
if definition: if defn is not None:
generate_default_config(definition, f'kitten-{kitten}') generate_default_config(defn, f'kitten-{kitten}')
from kitty.actions import as_rst from kitty.actions import as_rst
with open('generated/actions.rst', 'w', encoding='utf-8') as f: with open('generated/actions.rst', 'w', encoding='utf-8') as f:

View File

@ -51,8 +51,6 @@ def main() -> None:
from kittens.diff.options.definition import definition as kd from kittens.diff.options.definition import definition as kd
write_output('kittens.diff', kd) write_output('kittens.diff', kd)
from kittens.ssh.options.definition import definition as sd
write_output('kittens.ssh', sd)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,15 +1,17 @@
#!./kitty/launcher/kitty +launch #!./kitty/launcher/kitty +launch
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net> # License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import bz2
import io import io
import json import json
import os import os
import struct import struct
import subprocess import subprocess
import sys import sys
import zlib import tarfile
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from functools import lru_cache from functools import lru_cache
from itertools import chain
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Union from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Union
import kitty.constants as kc import kitty.constants as kc
@ -22,6 +24,8 @@ from kitty.cli import (
parse_option_spec, parse_option_spec,
serialize_as_go_string, serialize_as_go_string,
) )
from kitty.conf.generate import gen_go_code
from kitty.conf.types import Definition
from kitty.guess_mime_type import text_mimes from kitty.guess_mime_type import text_mimes
from kitty.key_encoding import config_mod_map from kitty.key_encoding import config_mod_map
from kitty.key_names import character_key_name_aliases, functional_key_name_aliases from kitty.key_names import character_key_name_aliases, functional_key_name_aliases
@ -38,7 +42,7 @@ def newer(dest: str, *sources: str) -> bool:
dtime = os.path.getmtime(dest) dtime = os.path.getmtime(dest)
except OSError: except OSError:
return True return True
for s in sources: for s in chain(sources, (__file__,)):
with suppress(FileNotFoundError): with suppress(FileNotFoundError):
if os.path.getmtime(s) >= dtime: if os.path.getmtime(s) >= dtime:
return True return True
@ -318,8 +322,45 @@ def wrapped_kittens() -> Sequence[str]:
raise Exception('Failed to read wrapped kittens from kitty wrapper script') raise Exception('Failed to read wrapped kittens from kitty wrapper script')
def generate_conf_parser(kitten: str, defn: Definition) -> None:
with replace_if_needed(f'tools/cmd/{kitten}/conf_generated.go'):
print(f'package {kitten}')
print(gen_go_code(defn))
def generate_extra_cli_parser(name: str, spec: str) -> None:
print('import "kitty/tools/cli"')
go_opts = tuple(go_options_for_seq(parse_option_spec(spec)[0]))
print(f'type {name}_options struct ''{')
for opt in go_opts:
print(opt.struct_declaration())
print('}')
print(f'func parse_{name}_args(args []string) (*{name}_options, []string, error) ''{')
print(f'root := cli.Command{{Name: `{name}` }}')
for opt in go_opts:
print(opt.as_option('root'))
print('cmd, err := root.ParseArgs(args)')
print('if err != nil { return nil, nil, err }')
print(f'var opts {name}_options')
print('err = cmd.GetOptionValues(&opts)')
print('if err != nil { return nil, nil, err }')
print('return &opts, cmd.Args, nil')
print('}')
def kitten_clis() -> None: def kitten_clis() -> None:
from kittens.runner import get_kitten_conf_docs, get_kitten_extra_cli_parsers
for kitten in wrapped_kittens(): for kitten in wrapped_kittens():
defn = get_kitten_conf_docs(kitten)
if defn is not None:
generate_conf_parser(kitten, defn)
ecp = get_kitten_extra_cli_parsers(kitten)
if ecp:
for name, spec in ecp.items():
with replace_if_needed(f'tools/cmd/{kitten}/{name}_cli_generated.go'):
print(f'package {kitten}')
generate_extra_cli_parser(name, spec)
with replace_if_needed(f'tools/cmd/{kitten}/cli_generated.go'): with replace_if_needed(f'tools/cmd/{kitten}/cli_generated.go'):
od = [] od = []
kcd = kitten_cli_docs(kitten) kcd = kitten_cli_docs(kitten)
@ -329,10 +370,11 @@ def kitten_clis() -> None:
print('func create_cmd(root *cli.Command, run_func func(*cli.Command, *Options, []string)(int, error)) {') print('func create_cmd(root *cli.Command, run_func func(*cli.Command, *Options, []string)(int, error)) {')
print('ans := root.AddSubCommand(&cli.Command{') print('ans := root.AddSubCommand(&cli.Command{')
print(f'Name: "{kitten}",') print(f'Name: "{kitten}",')
print(f'ShortDescription: "{serialize_as_go_string(kcd["short_desc"])}",') if kcd:
if kcd['usage']: print(f'ShortDescription: "{serialize_as_go_string(kcd["short_desc"])}",')
print(f'Usage: "[options] {serialize_as_go_string(kcd["usage"])}",') if kcd['usage']:
print(f'HelpText: "{serialize_as_go_string(kcd["help_text"])}",') print(f'Usage: "[options] {serialize_as_go_string(kcd["usage"])}",')
print(f'HelpText: "{serialize_as_go_string(kcd["help_text"])}",')
print('Run: func(cmd *cli.Command, args []string) (int, error) {') print('Run: func(cmd *cli.Command, args []string) (int, error) {')
print('opts := Options{}') print('opts := Options{}')
print('err := cmd.GetOptionValues(&opts)') print('err := cmd.GetOptionValues(&opts)')
@ -351,6 +393,8 @@ def kitten_clis() -> None:
print("clone := root.AddClone(ans.Group, ans)") print("clone := root.AddClone(ans.Group, ans)")
print('clone.Hidden = false') print('clone.Hidden = false')
print(f'clone.Name = "{serialize_as_go_string(kitten.replace("_", "-"))}"') print(f'clone.Name = "{serialize_as_go_string(kitten.replace("_", "-"))}"')
if not kcd:
print('specialize_command(ans)')
print('}') print('}')
print('type Options struct {') print('type Options struct {')
print('\n'.join(od)) print('\n'.join(od))
@ -383,11 +427,24 @@ def generate_spinners() -> str:
def generate_color_names() -> str: def generate_color_names() -> str:
selfg = "" if Options.selection_foreground is None else Options.selection_foreground.as_sharp
selbg = "" if Options.selection_background is None else Options.selection_background.as_sharp
cursor = "" if Options.cursor is None else Options.cursor.as_sharp
return 'package style\n\nvar ColorNames = map[string]RGBA{' + '\n'.join( return 'package style\n\nvar ColorNames = map[string]RGBA{' + '\n'.join(
f'\t"{name}": RGBA{{ Red:{val.red}, Green:{val.green}, Blue:{val.blue} }},' f'\t"{name}": RGBA{{ Red:{val.red}, Green:{val.green}, Blue:{val.blue} }},'
for name, val in color_names.items() for name, val in color_names.items()
) + '\n}' + '\n\nvar ColorTable = [256]uint32{' + ', '.join( ) + '\n}' + '\n\nvar ColorTable = [256]uint32{' + ', '.join(
f'{x}' for x in Options.color_table) + '}\n' f'{x}' for x in Options.color_table) + '}\n' + f'''
var DefaultColors = struct {{
Foreground, Background, Cursor, SelectionFg, SelectionBg string
}}{{
Foreground: "{Options.foreground.as_sharp}",
Background: "{Options.background.as_sharp}",
Cursor: "{cursor}",
SelectionFg: "{selfg}",
SelectionBg: "{selbg}",
}}
'''
def load_ref_map() -> Dict[str, Dict[str, str]]: def load_ref_map() -> Dict[str, Dict[str, str]]:
@ -399,6 +456,8 @@ def load_ref_map() -> Dict[str, Dict[str, str]]:
def generate_constants() -> str: def generate_constants() -> str:
from kitty.options.types import Options
from kitty.options.utils import allowed_shell_integration_values
ref_map = load_ref_map() ref_map = load_ref_map()
dp = ", ".join(map(lambda x: f'"{serialize_as_go_string(x)}"', kc.default_pager_for_help)) dp = ", ".join(map(lambda x: f'"{serialize_as_go_string(x)}"', kc.default_pager_for_help))
return f'''\ return f'''\
@ -410,6 +469,7 @@ type VersionType struct {{
const VersionString string = "{kc.str_version}" const VersionString string = "{kc.str_version}"
const WebsiteBaseURL string = "{kc.website_base_url}" const WebsiteBaseURL string = "{kc.website_base_url}"
const VCSRevision string = "" const VCSRevision string = ""
const SSHControlMasterTemplate = "{kc.ssh_control_master_template}"
const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}" const RC_ENCRYPTION_PROTOCOL_VERSION string = "{kc.RC_ENCRYPTION_PROTOCOL_VERSION}"
const IsFrozenBuild bool = false const IsFrozenBuild bool = false
const IsStandaloneBuild bool = false const IsStandaloneBuild bool = false
@ -421,6 +481,12 @@ var CharacterKeyNameAliases = map[string]string{serialize_go_dict(character_key_
var ConfigModMap = map[string]uint16{serialize_go_dict(config_mod_map)} var ConfigModMap = map[string]uint16{serialize_go_dict(config_mod_map)}
var RefMap = map[string]string{serialize_go_dict(ref_map['ref'])} var RefMap = map[string]string{serialize_go_dict(ref_map['ref'])}
var DocTitleMap = map[string]string{serialize_go_dict(ref_map['doc'])} var DocTitleMap = map[string]string{serialize_go_dict(ref_map['doc'])}
var AllowedShellIntegrationValues = []string{{ {str(sorted(allowed_shell_integration_values))[1:-1].replace("'", '"')} }}
var KittyConfigDefaults = struct {{
Term, Shell_integration string
}}{{
Term: "{Options.term}", Shell_integration: "{' '.join(Options.shell_integration)}",
}}
''' # }}} ''' # }}}
@ -598,6 +664,11 @@ def generate_textual_mimetypes() -> str:
return '\n'.join(ans) return '\n'.join(ans)
def write_compressed_data(data: bytes, d: BinaryIO) -> None:
d.write(struct.pack('<I', len(data)))
d.write(bz2.compress(data))
def generate_unicode_names(src: TextIO, dest: BinaryIO) -> None: def generate_unicode_names(src: TextIO, dest: BinaryIO) -> None:
num_names, num_of_words = map(int, next(src).split()) num_names, num_of_words = map(int, next(src).split())
gob = io.BytesIO() gob = io.BytesIO()
@ -612,9 +683,31 @@ def generate_unicode_names(src: TextIO, dest: BinaryIO) -> None:
if aliases: if aliases:
record += aliases.encode() record += aliases.encode()
gob.write(struct.pack('<H', len(record)) + record) gob.write(struct.pack('<H', len(record)) + record)
data = gob.getvalue() write_compressed_data(gob.getvalue(), dest)
dest.write(struct.pack('<I', len(data)))
dest.write(zlib.compress(data, zlib.Z_BEST_COMPRESSION))
def generate_ssh_kitten_data() -> None:
files = {
'terminfo/kitty.terminfo', 'terminfo/x/xterm-kitty',
}
for dirpath, dirnames, filenames in os.walk('shell-integration'):
for f in filenames:
path = os.path.join(dirpath, f)
files.add(path.replace(os.sep, '/'))
dest = 'tools/cmd/ssh/data_generated.bin'
def normalize(t: tarfile.TarInfo) -> tarfile.TarInfo:
t.uid = t.gid = 0
t.uname = t.gname = ''
return t
if newer(dest, *files):
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode='w') as tf:
for f in sorted(files):
tf.add(f, filter=normalize)
with open(dest, 'wb') as d:
write_compressed_data(buf.getvalue(), d)
def main() -> None: def main() -> None:
@ -633,6 +726,7 @@ def main() -> None:
if newer('tools/unicode_names/data_generated.bin', 'tools/unicode_names/names.txt'): if newer('tools/unicode_names/data_generated.bin', 'tools/unicode_names/names.txt'):
with open('tools/unicode_names/data_generated.bin', 'wb') as dest, open('tools/unicode_names/names.txt') as src: with open('tools/unicode_names/data_generated.bin', 'wb') as dest, open('tools/unicode_names/names.txt') as src:
generate_unicode_names(src, dest) generate_unicode_names(src, dest)
generate_ssh_kitten_data()
update_completion() update_completion()
update_at_commands() update_at_commands()

14
go.mod
View File

@ -4,14 +4,24 @@ go 1.20
require ( require (
github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924
github.com/bmatcuk/doublestar v1.3.4
github.com/disintegration/imaging v1.6.2 github.com/disintegration/imaging v1.6.2
github.com/google/go-cmp v0.5.8 github.com/google/go-cmp v0.5.9
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f
github.com/seancfoley/ipaddress-go v1.5.3 github.com/seancfoley/ipaddress-go v1.5.3
github.com/shirou/gopsutil/v3 v3.23.1
golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b
golang.org/x/image v0.5.0 golang.org/x/image v0.5.0
golang.org/x/sys v0.4.0 golang.org/x/sys v0.4.0
) )
require github.com/seancfoley/bintree v1.2.1 // indirect require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/seancfoley/bintree v1.2.1 // indirect
github.com/tklauser/go-sysconf v0.3.11 // indirect
github.com/tklauser/numcpus v0.6.0 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
)

41
go.sum
View File

@ -1,18 +1,47 @@
github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 h1:DG4UyTVIujioxwJc8Zj8Nabz1L1wTgQ/xNBSQDfdP3I= github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 h1:DG4UyTVIujioxwJc8Zj8Nabz1L1wTgQ/xNBSQDfdP3I=
github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924/go.mod h1:+NaH2gLeY6RPBPPQf4aRotPPStg+eXc8f9ZaE4vRfD4= github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924/go.mod h1:+NaH2gLeY6RPBPPQf4aRotPPStg+eXc8f9ZaE4vRfD4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f h1:Ko4+g6K16vSyUrtd/pPXuQnWsiHe5BYptEtTxfwYwCc= github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f h1:Ko4+g6K16vSyUrtd/pPXuQnWsiHe5BYptEtTxfwYwCc=
github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f/go.mod h1:eHzfhOKbTGJEGPSdMHzU6jft192tHHt2Bu2vIZArvC0= github.com/jamesruan/go-rfc1924 v0.0.0-20170108144916-2767ca7c638f/go.mod h1:eHzfhOKbTGJEGPSdMHzU6jft192tHHt2Bu2vIZArvC0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/seancfoley/bintree v1.2.1 h1:Z/iNjRKkXnn0CTW7jDQYtjW5fz2GH1yWvOTJ4MrMvdo= github.com/seancfoley/bintree v1.2.1 h1:Z/iNjRKkXnn0CTW7jDQYtjW5fz2GH1yWvOTJ4MrMvdo=
github.com/seancfoley/bintree v1.2.1/go.mod h1:hIUabL8OFYyFVTQ6azeajbopogQc2l5C/hiXMcemWNU= github.com/seancfoley/bintree v1.2.1/go.mod h1:hIUabL8OFYyFVTQ6azeajbopogQc2l5C/hiXMcemWNU=
github.com/seancfoley/ipaddress-go v1.5.3 h1:fLnn4nsatd2rp3IJsVWriXv5gXn2Qiy8uxjxe4iZtTg= github.com/seancfoley/ipaddress-go v1.5.3 h1:fLnn4nsatd2rp3IJsVWriXv5gXn2Qiy8uxjxe4iZtTg=
github.com/seancfoley/ipaddress-go v1.5.3/go.mod h1:fpvVPC+Jso+YEhNcNiww8HQmBgKP8T4T6BTp1SLxxIo= github.com/seancfoley/ipaddress-go v1.5.3/go.mod h1:fpvVPC+Jso+YEhNcNiww8HQmBgKP8T4T6BTp1SLxxIo=
github.com/shirou/gopsutil/v3 v3.23.1 h1:a9KKO+kGLKEvcPIs4W62v0nu3sciVDOOOPUD0Hz7z/4=
github.com/shirou/gopsutil/v3 v3.23.1/go.mod h1:NN6mnm5/0k8jw4cBfCnJtr5L7ErOTg18tMNpgFkn0hA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI=
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg=
github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b h1:EqBVA+nNsObCwQoBEHy4wLU0pi7i8a4AL3pbItPdPkE= golang.org/x/exp v0.0.0-20230202163644-54bba9f4231b h1:EqBVA+nNsObCwQoBEHy4wLU0pi7i8a4AL3pbItPdPkE=
@ -27,10 +56,13 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
@ -43,3 +75,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -7,7 +7,7 @@ import os
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, cast from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, Optional, cast
from kitty.constants import list_kitty_resources from kitty.constants import list_kitty_resources
from kitty.types import run_once from kitty.types import run_once
@ -171,7 +171,7 @@ def get_kitten_completer(kitten: str) -> Any:
return ans return ans
def get_kitten_conf_docs(kitten: str) -> Definition: def get_kitten_conf_docs(kitten: str) -> Optional[Definition]:
setattr(sys, 'options_definition', None) setattr(sys, 'options_definition', None)
run_kitten(kitten, run_name='__conf__') run_kitten(kitten, run_name='__conf__')
ans = getattr(sys, 'options_definition') ans = getattr(sys, 'options_definition')
@ -179,6 +179,14 @@ def get_kitten_conf_docs(kitten: str) -> Definition:
return cast(Definition, ans) return cast(Definition, ans)
def get_kitten_extra_cli_parsers(kitten: str) -> Dict[str,str]:
setattr(sys, 'extra_cli_parsers', {})
run_kitten(kitten, run_name='__extra_cli_parsers__')
ans = getattr(sys, 'extra_cli_parsers')
delattr(sys, 'extra_cli_parsers')
return cast(Dict[str, str], ans)
def main() -> None: def main() -> None:
try: try:
args = sys.argv[1:] args = sys.argv[1:]

View File

@ -1,70 +0,0 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import fnmatch
import os
from typing import Any, Dict, Iterable, Optional
from kitty.conf.utils import load_config as _load_config
from kitty.conf.utils import parse_config_base, resolve_config
from kitty.constants import config_dir
from .options.types import Options as SSHOptions
from .options.types import defaults
SYSTEM_CONF = '/etc/xdg/kitty/ssh.conf'
defconf = os.path.join(config_dir, 'ssh.conf')
def host_matches(mpat: str, hostname: str, username: str) -> bool:
for pat in mpat.split():
upat = '*'
if '@' in pat:
upat, pat = pat.split('@', 1)
if fnmatch.fnmatchcase(hostname, pat) and fnmatch.fnmatchcase(username, upat):
return True
return False
def load_config(*paths: str, overrides: Optional[Iterable[str]] = None, hostname: str = '!', username: str = '') -> SSHOptions:
from .options.parse import create_result_dict, merge_result_dicts, parse_conf_item
from .options.utils import first_seen_positions, get_per_hosts_dict, init_results_dict
def merge_dicts(base: Dict[str, Any], vals: Dict[str, Any]) -> Dict[str, Any]:
base_phd = get_per_hosts_dict(base)
vals_phd = get_per_hosts_dict(vals)
for hostname in base_phd:
vals_phd[hostname] = merge_result_dicts(base_phd[hostname], vals_phd.get(hostname, {}))
ans: Dict[str, Any] = vals_phd.pop(vals['hostname'])
ans['per_host_dicts'] = vals_phd
return ans
def parse_config(lines: Iterable[str]) -> Dict[str, Any]:
ans: Dict[str, Any] = init_results_dict(create_result_dict())
parse_config_base(lines, parse_conf_item, ans)
return ans
overrides = tuple(overrides) if overrides is not None else ()
first_seen_positions.clear()
first_seen_positions['*'] = 0
opts_dict, paths = _load_config(
defaults, parse_config, merge_dicts, *paths, overrides=overrides, initialize_defaults=init_results_dict)
phd = get_per_hosts_dict(opts_dict)
final_dict: Dict[str, Any] = {}
for hostname_pat in sorted(phd, key=first_seen_positions.__getitem__):
if host_matches(hostname_pat, hostname, username):
od = phd[hostname_pat]
for k, v in od.items():
if isinstance(v, dict):
bv = final_dict.setdefault(k, {})
bv.update(v)
else:
final_dict[k] = v
first_seen_positions.clear()
return SSHOptions(final_dict)
def init_config(hostname: str, username: str, overrides: Optional[Iterable[str]] = None) -> SSHOptions:
config = tuple(resolve_config(SYSTEM_CONF, defconf))
return load_config(*config, overrides=overrides, hostname=hostname, username=username)

View File

@ -1,117 +0,0 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import glob
import os
import shlex
import uuid
from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Sequence, Tuple
from kitty.cli import parse_args
from kitty.cli_stub import CopyCLIOptions
from kitty.types import run_once
from ..transfer.utils import expand_home, home_path
@run_once
def option_text() -> str:
return '''
--glob
type=bool-set
Interpret file arguments as glob patterns.
--dest
The destination on the remote host to copy to. Relative paths are resolved
relative to HOME on the remote host. When this option is not specified, the
local file path is used as the remote destination (with the HOME directory
getting automatically replaced by the remote HOME). Note that environment
variables and ~ are not expanded.
--exclude
type=list
A glob pattern. Files with names matching this pattern are excluded from being
transferred. Useful when adding directories. Can
be specified multiple times, if any of the patterns match the file will be
excluded. To exclude a directory use a pattern like */directory_name/*.
--symlink-strategy
default=preserve
choices=preserve,resolve,keep-path
Control what happens if the specified path is a symlink. The default is to preserve
the symlink, re-creating it on the remote machine. Setting this to :code:`resolve`
will cause the symlink to be followed and its target used as the file/directory to copy.
The value of :code:`keep-path` is the same as :code:`resolve` except that the remote
file path is derived from the symlink's path instead of the path of the symlink's target.
Note that this option does not apply to symlinks encountered while recursively copying directories.
'''
def parse_copy_args(args: Optional[Sequence[str]] = None) -> Tuple[CopyCLIOptions, List[str]]:
args = list(args or ())
try:
opts, args = parse_args(result_class=CopyCLIOptions, args=args, ospec=option_text)
except SystemExit as e:
raise CopyCLIError from e
return opts, args
def resolve_file_spec(spec: str, is_glob: bool) -> Iterator[str]:
ans = os.path.expandvars(expand_home(spec))
if not os.path.isabs(ans):
ans = expand_home(f'~/{ans}')
if is_glob:
files = glob.glob(ans)
if not files:
raise CopyCLIError(f'{spec} does not exist')
else:
if not os.path.exists(ans):
raise CopyCLIError(f'{spec} does not exist')
files = [ans]
for x in files:
yield os.path.normpath(x).replace(os.sep, '/')
class CopyCLIError(ValueError):
pass
def get_arcname(loc: str, dest: Optional[str], home: str) -> str:
if dest:
arcname = dest
else:
arcname = os.path.normpath(loc)
if arcname.startswith(home):
arcname = os.path.relpath(arcname, home)
arcname = os.path.normpath(arcname).replace(os.sep, '/')
prefix = 'root' if arcname.startswith('/') else 'home/'
return prefix + arcname
class CopyInstruction(NamedTuple):
local_path: str
arcname: str
exclude_patterns: Tuple[str, ...]
def parse_copy_instructions(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]:
opts, args = parse_copy_args(shlex.split(val))
locations: List[str] = []
for a in args:
locations.extend(resolve_file_spec(a, opts.glob))
if not locations:
raise CopyCLIError('No files to copy specified')
if len(locations) > 1 and opts.dest:
raise CopyCLIError('Specifying a remote location with more than one file is not supported')
home = home_path()
for loc in locations:
if opts.symlink_strategy != 'preserve':
rp = os.path.realpath(loc)
else:
rp = loc
arcname = get_arcname(rp if opts.symlink_strategy == 'resolve' else loc, opts.dest, home)
yield str(uuid.uuid4()), CopyInstruction(rp, arcname, tuple(opts.exclude))

View File

@ -1,728 +1,214 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# License: GPL v3 Copyright: 2018, Kovid Goyal <kovid at kovidgoyal.net> # License: GPL v3 Copyright: 2018, Kovid Goyal <kovid at kovidgoyal.net>
import fnmatch
import glob
import io
import json
import os
import re
import secrets
import shlex
import shutil
import stat
import subprocess
import sys import sys
import tarfile from typing import List, Optional
import tempfile
import termios
import time
import traceback
from base64 import standard_b64decode, standard_b64encode
from contextlib import contextmanager, suppress
from getpass import getuser
from select import select
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Tuple, Union, cast
from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_control_master_template, str_version, terminfo_dir from kitty.conf.types import Definition
from kitty.shell_integration import as_str_literal
from kitty.shm import SharedMemory
from kitty.types import run_once from kitty.types import run_once
from kitty.utils import SSHConnectionData, expandvars, resolve_abs_or_config_path
from kitty.utils import set_echo as turn_off_echo
from ..tui.operations import RESTORE_PRIVATE_MODE_VALUES, SAVE_PRIVATE_MODE_VALUES, Mode, restore_colors, save_colors, set_mode copy_message = '''\
from ..tui.utils import kitty_opts, running_in_tmux Copy files and directories from local to remote hosts. The specified files are
from .config import init_config assumed to be relative to the HOME directory and copied to the HOME on the
from .copy import CopyInstruction remote. Directories are copied recursively. If absolute paths are used, they are
from .options.types import Options as SSHOptions copied as is.'''
from .options.utils import DELETE_ENV_VAR
from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args
@run_once @run_once
def ssh_exe() -> str: def option_text() -> str:
return shutil.which('ssh') or 'ssh' return '''
--glob
type=bool-set
def read_data_from_shared_memory(shm_name: str) -> Any: Interpret file arguments as glob patterns. Globbing is based on
with SharedMemory(shm_name, readonly=True) as shm: Based on standard wildcards with the addition that ``/**/`` matches any number of directories.
shm.unlink() See the :link:`detailed syntax <https://github.com/bmatcuk/doublestar#patterns>`.
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode) --dest
if mode != stat.S_IREAD | stat.S_IWRITE: The destination on the remote host to copy to. Relative paths are resolved
raise ValueError('Incorrect permissions on pwfile') relative to HOME on the remote host. When this option is not specified, the
return json.loads(shm.read_data_with_size()) local file path is used as the remote destination (with the HOME directory
getting automatically replaced by the remote HOME). Note that environment
variables and ~ are not expanded.
# See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
quote_pat = re.compile('([\\`"])')
--exclude
type=list
def quote_env_val(x: str, literal_quote: bool = False) -> str: A glob pattern. Files with names matching this pattern are excluded from being
if literal_quote: transferred. Useful when adding directories. Can
return as_str_literal(x) be specified multiple times, if any of the patterns match the file will be
x = quote_pat.sub(r'\\\1', x) excluded. If the pattern includes a :code:`/` then it will match against the full
x = x.replace('$(', r'\$(') # prevent execution with $() path, not just the filename. In such patterns you can use :code:`/**/` to match zero
return f'"{x}"' or more directories. For example, to exclude a directory and everything under it use
:code:`**/directory_name`.
See the :link:`detailed syntax <https://github.com/bmatcuk/doublestar#patterns>` for
def serialize_env(literal_env: Dict[str, str], env: Dict[str, str], base_env: Dict[str, str], for_python: bool = False) -> bytes: how wildcards match.
lines = []
literal_quote = True
--symlink-strategy
if for_python: default=preserve
def a(k: str, val: str = '', prefix: str = 'export') -> None: choices=preserve,resolve,keep-path
if val: Control what happens if the specified path is a symlink. The default is to preserve
lines.append(f'{prefix} {json.dumps((k, val, literal_quote))}') the symlink, re-creating it on the remote machine. Setting this to :code:`resolve`
else: will cause the symlink to be followed and its target used as the file/directory to copy.
lines.append(f'{prefix} {json.dumps((k,))}') The value of :code:`keep-path` is the same as :code:`resolve` except that the remote
else: file path is derived from the symlink's path instead of the path of the symlink's target.
def a(k: str, val: str = '', prefix: str = 'export') -> None: Note that this option does not apply to symlinks encountered while recursively copying directories,
if val: those are always preserved.
lines.append(f'{prefix} {shlex.quote(k)}={quote_env_val(val, literal_quote)}') '''
else:
lines.append(f'{prefix} {shlex.quote(k)}')
for k, v in literal_env.items(): definition = Definition(
a(k, v) '!kittens.ssh',
)
literal_quote = False
for k in sorted(env): agr = definition.add_group
v = env[k] egr = definition.end_group
if v == DELETE_ENV_VAR: opt = definition.add_option
a(k, prefix='unset')
elif v == '_kitty_copy_env_var_': agr('bootstrap', 'Host bootstrap configuration') # {{{
q = base_env.get(k)
if q is not None: opt('hostname', '*', long_text='''
a(k, q) The hostname that the following options apply to. A glob pattern to match
else: multiple hosts can be used. Multiple hostnames can also be specified, separated
a(k, v) by spaces. The hostname can include an optional username in the form
return '\n'.join(lines).encode('utf-8') :code:`user@host`. When not specified options apply to all hosts, until the
first hostname specification is found. Note that matching of hostname is done
against the name you specify on the command line to connect to the remote host.
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str], compression: str = 'gz', literal_env: Dict[str, str] = {}) -> bytes: If you wish to include the same basic configuration for many different hosts,
you can do so with the :ref:`include <include>` directive.
def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo: ''')
tarinfo.uname = tarinfo.gname = ''
tarinfo.uid = tarinfo.gid = 0 opt('interpreter', 'sh', long_text='''
# some distro's like nix mess with installed file permissions so ensure The interpreter to use on the remote host. Must be either a POSIX complaint
# files are at least readable and writable by owning user shell or a :program:`python` executable. If the default :program:`sh` is not
tarinfo.mode |= stat.S_IWUSR | stat.S_IRUSR available or broken, using an alternate interpreter can be useful.
return tarinfo ''')
def add_data_as_file(tf: tarfile.TarFile, arcname: str, data: Union[str, bytes]) -> tarfile.TarInfo: opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text='''
ans = tarfile.TarInfo(arcname) The location on the remote host where the files needed for this kitten are
ans.mtime = 0 installed. Relative paths are resolved with respect to :code:`$HOME`.
ans.type = tarfile.REGTYPE ''')
if isinstance(data, str):
data = data.encode('utf-8') opt('+copy', '', add_to_default=False, ctype='CopyInstruction', long_text=f'''
ans.size = len(data) {copy_message} For example::
normalize_tarinfo(ans)
tf.addfile(ans, io.BytesIO(data)) copy .vimrc .zshrc .config/some-dir
return ans
Use :code:`--dest` to copy a file to some other destination on the remote host::
def filter_from_globs(*pats: str) -> Callable[[tarfile.TarInfo], Optional[tarfile.TarInfo]]:
def filter(tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]: copy --dest some-other-name some-file
for junk_dir in ('.DS_Store', '__pycache__'):
for pat in (f'*/{junk_dir}', f'*/{junk_dir}/*'): Glob patterns can be specified to copy multiple files, with :code:`--glob`::
if fnmatch.fnmatch(tarinfo.name, pat):
return None copy --glob images/*.png
for pat in pats:
if fnmatch.fnmatch(tarinfo.name, pat): Files can be excluded when copying with :code:`--exclude`::
return None
return normalize_tarinfo(tarinfo) copy --glob --exclude *.jpg --exclude *.bmp images/*
return filter
Files whose remote name matches the exclude pattern will not be copied.
from kitty.shell_integration import get_effective_ksi_env_var For more details, see :ref:`ssh_copy_command`.
if ssh_opts.shell_integration == 'inherited': ''')
ksi = get_effective_ksi_env_var(kitty_opts()) egr() # }}}
else:
from kitty.options.types import Options agr('shell', 'Login shell environment') # {{{
from kitty.options.utils import shell_integration
ksi = get_effective_ksi_env_var(Options({'shell_integration': shell_integration(ssh_opts.shell_integration)})) opt('shell_integration', 'inherited', long_text='''
Control the shell integration on the remote host. See :ref:`shell_integration`
env = { for details on how this setting works. The special value :code:`inherited` means
'TERM': os.environ.get('TERM') or kitty_opts().term, use the setting from :file:`kitty.conf`. This setting is useful for overriding
'COLORTERM': 'truecolor', integration on a per-host basis.
} ''')
env.update(ssh_opts.env)
for q in ('KITTY_WINDOW_ID', 'WINDOWID'): opt('login_shell', '', long_text='''
val = os.environ.get(q) The login shell to execute on the remote host. By default, the remote user
if val is not None: account's login shell is used.
env[q] = val ''')
env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR
env['KITTY_SSH_KITTEN_DATA_DIR'] = ssh_opts.remote_dir opt('+env', '', add_to_default=False, ctype='EnvInstruction', long_text='''
if ssh_opts.login_shell: Specify the environment variables to be set on the remote host. Using the
env['KITTY_LOGIN_SHELL'] = ssh_opts.login_shell name with an equal sign (e.g. :code:`env VAR=`) will set it to the empty string.
if ssh_opts.cwd: Specifying only the name (e.g. :code:`env VAR`) will remove the variable from
env['KITTY_LOGIN_CWD'] = ssh_opts.cwd the remote shell environment. The special value :code:`_kitty_copy_env_var_`
if ssh_opts.remote_kitty != 'no': will cause the value of the variable to be copied from the local environment.
env['KITTY_REMOTE'] = ssh_opts.remote_kitty The definitions are processed alphabetically. Note that environment variables
if os.environ.get('KITTY_PUBLIC_KEY'): are expanded recursively, for example::
env.pop('KITTY_PUBLIC_KEY', None)
literal_env['KITTY_PUBLIC_KEY'] = os.environ['KITTY_PUBLIC_KEY'] env VAR1=a
env_script = serialize_env(literal_env, env, base_env, for_python=compression != 'gz') env VAR2=${HOME}/${VAR1}/b
buf = io.BytesIO()
with tarfile.open(mode=f'w:{compression}', fileobj=buf, encoding='utf-8') as tf: The value of :code:`VAR2` will be :code:`<path to home directory>/a/b`.
rd = ssh_opts.remote_dir.rstrip('/') ''')
for ci in ssh_opts.copy.values():
tf.add(ci.local_path, arcname=ci.arcname, filter=filter_from_globs(*ci.exclude_patterns)) opt('cwd', '', long_text='''
add_data_as_file(tf, 'data.sh', env_script) The working directory on the remote host to change to. Environment variables in
if compression == 'gz': this value are expanded. The default is empty so no changing is done, which
tf.add(f'{shell_integration_dir}/ssh/bootstrap-utils.sh', arcname='bootstrap-utils.sh', filter=normalize_tarinfo) usually means the HOME directory is used.
if ksi: ''')
arcname = 'home/' + rd + '/shell-integration'
tf.add(shell_integration_dir, arcname=arcname, filter=filter_from_globs( opt('color_scheme', '', long_text='''
f'{arcname}/ssh/*', # bootstrap files are sent as command line args Specify a color scheme to use when connecting to the remote host. If this option
f'{arcname}/zsh/kitty.zsh', # present for legacy compat not needed by ssh kitten ends with :code:`.conf`, it is assumed to be the name of a config file to load
)) from the kitty config directory, otherwise it is assumed to be the name of a
if ssh_opts.remote_kitty != 'no': color theme to load via the :doc:`themes kitten </kittens/themes>`. Note that
arcname = 'home/' + rd + '/kitty' only colors applying to the text/background are changed, other config settings
add_data_as_file(tf, arcname + '/version', str_version.encode('ascii')) in the .conf files/themes are ignored.
tf.add(shell_integration_dir + '/ssh/kitty', arcname=arcname + '/bin/kitty', filter=normalize_tarinfo) ''')
tf.add(shell_integration_dir + '/ssh/kitten', arcname=arcname + '/bin/kitten', filter=normalize_tarinfo)
tf.add(f'{terminfo_dir}/kitty.terminfo', arcname='home/.terminfo/kitty.terminfo', filter=normalize_tarinfo) opt('remote_kitty', 'if-needed', choices=('if-needed', 'no', 'yes'), long_text='''
tf.add(glob.glob(f'{terminfo_dir}/*/xterm-kitty')[0], arcname='home/.terminfo/x/xterm-kitty', filter=normalize_tarinfo) Make :program:`kitty` available on the remote host. Useful to run kittens such
return buf.getvalue() as the :doc:`icat kitten </kittens/icat>` to display images or the
:doc:`transfer file kitten </kittens/transfer>` to transfer files. Only works if
the remote host has an architecture for which :link:`pre-compiled kitty binaries
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: <https://github.com/kovidgoyal/kitty/releases>` are available. Note that kitty
yield b'\nKITTY_DATA_START\n' # to discard leading data is not actually copied to the remote host, instead a small bootstrap script is
try: copied which will download and run kitty when kitty is first executed on the
msg = standard_b64decode(msg).decode('utf-8') remote host. A value of :code:`if-needed` means kitty is installed only if not
md = dict(x.split('=', 1) for x in msg.split(':')) already present in the system-wide PATH. A value of :code:`yes` means that kitty
pw = md['pw'] is installed even if already present, and the installed kitty takes precedence.
pwfilename = md['pwfile'] Finally, :code:`no` means no kitty is installed on the remote host. The
rq_id = md['id'] installed kitty can be updated by running: :code:`kitty +update-kitty` on the
except Exception: remote host.
traceback.print_exc() ''')
yield b'invalid ssh data request message\n' egr() # }}}
else:
try: agr('ssh', 'SSH configuration') # {{{
env_data = read_data_from_shared_memory(pwfilename)
if pw != env_data['pw']: opt('share_connections', 'yes', option_type='to_bool', long_text='''
raise ValueError('Incorrect password') Within a single kitty instance, all connections to a particular server can be
if rq_id != request_id: shared. This reduces startup latency for subsequent connections and means that
raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window') you have to enter the password only once. Under the hood, it uses SSH
except Exception as e: ControlMasters and these are automatically cleaned up by kitty when it quits.
traceback.print_exc() You can map a shortcut to :ac:`close_shared_ssh_connections` to disconnect all
yield f'{e}\n'.encode('utf-8') active shared connections.
else: ''')
yield b'OK\n'
ssh_opts = SSHOptions(env_data['opts']) opt('askpass', 'unless-set', choices=('unless-set', 'ssh', 'native'), long_text='''
ssh_opts.copy = {k: CopyInstruction(*v) for k, v in ssh_opts.copy.items()} Control the program SSH uses to ask for passwords or confirmation of host keys
encoded_data = memoryview(env_data['tarfile'].encode('ascii')) etc. The default is to use kitty's native :program:`askpass`, unless the
# macOS has a 255 byte limit on its input queue as per man stty. :envvar:`SSH_ASKPASS` environment variable is set. Set this option to
# Not clear if that applies to canonical mode input as well, but :code:`ssh` to not interfere with the normal ssh askpass mechanism at all, which
# better to be safe. typically means that ssh will prompt at the terminal. Set it to :code:`native`
line_sz = 254 to always use kitty's native, built-in askpass implementation. Note that not
while encoded_data: using the kitty askpass implementation means that SSH might need to use the
yield encoded_data[:line_sz] terminal before the connection is established, so the kitten cannot use the
yield b'\n' terminal to send data without an extra roundtrip, adding to initial connection
encoded_data = encoded_data[line_sz:] latency.
yield b'KITTY_DATA_END\n' ''')
egr() # }}}
def safe_remove(x: str) -> None:
with suppress(OSError): def main(args: List[str]) -> Optional[str]:
os.remove(x) raise SystemExit('This should be run as kitten unicode_input')
def prepare_script(ans: str, replacements: Dict[str, str], script_type: str) -> str:
for k in ('EXEC_CMD', 'EXPORT_HOME_CMD'):
replacements[k] = replacements.get(k, '')
def sub(m: 're.Match[str]') -> str:
return replacements[m.group()]
return re.sub('|'.join(fr'\b{k}\b' for k in replacements), sub, ans)
def prepare_exec_cmd(remote_args: Sequence[str], is_python: bool) -> str:
# ssh simply concatenates multiple commands using a space see
# line 1129 of ssh.c and on the remote side sshd.c runs the
# concatenated command as shell -c cmd
if is_python:
return standard_b64encode(' '.join(remote_args).encode('utf-8')).decode('ascii')
args = ' '.join(c.replace("'", """'"'"'""") for c in remote_args)
return f"""unset KITTY_SHELL_INTEGRATION; exec "$login_shell" -c '{args}'"""
def prepare_export_home_cmd(ssh_opts: SSHOptions, is_python: bool) -> str:
home = ssh_opts.env.get('HOME')
if home == '_kitty_copy_env_var_':
home = os.environ.get('HOME')
if home:
if is_python:
return standard_b64encode(home.encode('utf-8')).decode('ascii')
else:
return f'export HOME={quote_env_val(home)}; cd "$HOME"'
return ''
def bootstrap_script(
ssh_opts: SSHOptions, script_type: str = 'sh', remote_args: Sequence[str] = (),
test_script: str = '', request_id: Optional[str] = None, cli_hostname: str = '', cli_uname: str = '',
request_data: bool = False, echo_on: bool = True, literal_env: Dict[str, str] = {}
) -> Tuple[str, Dict[str, str], str]:
if request_id is None:
request_id = os.environ['KITTY_PID'] + '-' + os.environ['KITTY_WINDOW_ID']
is_python = script_type == 'py'
export_home_cmd = prepare_export_home_cmd(ssh_opts, is_python) if 'HOME' in ssh_opts.env else ''
exec_cmd = prepare_exec_cmd(remote_args, is_python) if remote_args else ''
with open(os.path.join(shell_integration_dir, 'ssh', f'bootstrap.{script_type}')) as f:
ans = f.read()
pw = secrets.token_hex()
tfd = standard_b64encode(make_tarfile(ssh_opts, dict(os.environ), 'gz' if script_type == 'sh' else 'bz2', literal_env=literal_env)).decode('ascii')
data = {'pw': pw, 'opts': ssh_opts._asdict(), 'hostname': cli_hostname, 'uname': cli_uname, 'tarfile': tfd}
shm_name = create_shared_memory(data, prefix=f'kssh-{os.getpid()}-')
sensitive_data = {'REQUEST_ID': request_id, 'DATA_PASSWORD': pw, 'PASSWORD_FILENAME': shm_name}
replacements = {
'EXPORT_HOME_CMD': export_home_cmd,
'EXEC_CMD': exec_cmd, 'TEST_SCRIPT': test_script,
'REQUEST_DATA': '1' if request_data else '0', 'ECHO_ON': '1' if echo_on else '0',
}
sd = replacements.copy()
if request_data:
sd.update(sensitive_data)
replacements.update(sensitive_data)
return prepare_script(ans, sd, script_type), replacements, shm_name
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
expecting_port = expecting_identity = False
expecting_option_val = False
expecting_hostname = False
expecting_extra_val = ''
host_name = identity_file = found_ssh = ''
found_extra_args: List[Tuple[str, str]] = []
for i, arg in enumerate(args):
if not found_ssh:
if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
found_ssh = arg
continue
if expecting_hostname:
host_name = arg
continue
if arg.startswith('-') and not expecting_option_val:
if arg in boolean_ssh_args:
continue
if arg == '--':
expecting_hostname = True
if arg.startswith('-p'):
if arg[2:].isdigit():
with suppress(Exception):
port = int(arg[2:])
continue
elif arg == '-p':
expecting_port = True
elif arg.startswith('-i'):
if arg == '-i':
expecting_identity = True
else:
identity_file = arg[2:]
continue
if arg.startswith('--') and extra_args:
matching_ex = is_extra_arg(arg, extra_args)
if matching_ex:
if '=' in arg:
exval = arg.partition('=')[-1]
found_extra_args.append((matching_ex, exval))
continue
expecting_extra_val = matching_ex
expecting_option_val = True
continue
if expecting_option_val:
if expecting_port:
with suppress(Exception):
port = int(arg)
expecting_port = False
elif expecting_identity:
identity_file = arg
elif expecting_extra_val:
found_extra_args.append((expecting_extra_val, arg))
expecting_extra_val = ''
expecting_option_val = False
continue
if not host_name:
host_name = arg
if not host_name:
return None
if host_name.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(host_name)
if purl.hostname:
host_name = purl.hostname
if purl.username:
host_name = f'{purl.username}@{host_name}'
if port is None and purl.port:
port = purl.port
if identity_file:
if not os.path.isabs(identity_file):
identity_file = os.path.expanduser(identity_file)
if not os.path.isabs(identity_file):
identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))
class InvalidSSHArgs(ValueError):
def __init__(self, msg: str = ''):
super().__init__(msg)
self.err_msg = msg
def system_exit(self) -> None:
if self.err_msg:
print(self.err_msg, file=sys.stderr)
os.execlp(ssh_exe(), 'ssh')
def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
ssh_args = []
server_args: List[str] = []
expecting_option_val = False
passthrough = False
stop_option_processing = False
found_extra_args: List[str] = []
expecting_extra_val = ''
for argument in args:
if len(server_args) > 1 or stop_option_processing:
server_args.append(argument)
continue
if argument.startswith('-') and not expecting_option_val:
if argument == '--':
stop_option_processing = True
continue
if extra_args:
matching_ex = is_extra_arg(argument, extra_args)
if matching_ex:
if '=' in argument:
exval = argument.partition('=')[-1]
found_extra_args.extend((matching_ex, exval))
else:
expecting_extra_val = matching_ex
expecting_option_val = True
continue
# could be a multi-character option
all_args = argument[1:]
for i, arg in enumerate(all_args):
arg = f'-{arg}'
if arg in passthrough_args:
passthrough = True
if arg in boolean_ssh_args:
ssh_args.append(arg)
continue
if arg in other_ssh_args:
ssh_args.append(arg)
rest = all_args[i+1:]
if rest:
ssh_args.append(rest)
else:
expecting_option_val = True
break
raise InvalidSSHArgs(f'unknown option -- {arg[1:]}')
continue
if expecting_option_val:
if expecting_extra_val:
found_extra_args.extend((expecting_extra_val, argument))
expecting_extra_val = ''
else:
ssh_args.append(argument)
expecting_option_val = False
continue
server_args.append(argument)
if not server_args:
raise InvalidSSHArgs()
return ssh_args, server_args, passthrough, tuple(found_extra_args)
def wrap_bootstrap_script(sh_script: str, interpreter: str) -> List[str]:
# sshd will execute the command we pass it by join all command line
# arguments with a space and passing it as a single argument to the users
# login shell with -c. If the user has a non POSIX login shell it might
# have different escaping semantics and syntax, so the command it should
# execute has to be as simple as possible, basically of the form
# interpreter -c unwrap_script escaped_bootstrap_script
# The unwrap_script is responsible for unescaping the bootstrap script and
# executing it.
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
if is_python:
es = standard_b64encode(sh_script.encode('utf-8')).decode('ascii')
unwrap_script = '''"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"'''
else:
# We cant rely on base64 being available on the remote system, so instead
# we quote the bootstrap script by replacing ' and \ with \v and \f
# also replacing \n and ! with \r and \b for tcsh
# finally surrounding with '
es = "'" + sh_script.replace("'", '\v').replace('\\', '\f').replace('\n', '\r').replace('!', '\b') + "'"
unwrap_script = r"""'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' """
# exec is supported by all sh like shells, and fish and csh
return ['exec', interpreter, '-c', unwrap_script, es]
def get_remote_command(
remote_args: List[str], ssh_opts: SSHOptions, cli_hostname: str = '', cli_uname: str = '',
echo_on: bool = True, request_data: bool = False, literal_env: Dict[str, str] = {}
) -> Tuple[List[str], Dict[str, str], str]:
interpreter = ssh_opts.interpreter
q = os.path.basename(interpreter).lower()
is_python = 'python' in q
sh_script, replacements, shm_name = bootstrap_script(
ssh_opts, script_type='py' if is_python else 'sh', remote_args=remote_args, literal_env=literal_env,
cli_hostname=cli_hostname, cli_uname=cli_uname, echo_on=echo_on, request_data=request_data)
return wrap_bootstrap_script(sh_script, interpreter), replacements, shm_name
def connection_sharing_args(kitty_pid: int) -> List[str]:
rd = runtime_dir()
# Bloody OpenSSH generates a 40 char hash and in creating the socket
# appends a 27 char temp suffix to it. Socket max path length is approx
# ~104 chars. macOS has no system runtime dir so we use a cache dir in
# /Users/WHY_DOES_ANYONE_USE_MACOS/Library/Caches/APPLE_ARE_IDIOTIC
if len(rd) > 35 and os.path.isdir('/tmp'):
idiotic_design = f'/tmp/kssh-rdir-{os.getuid()}'
try:
os.symlink(rd, idiotic_design)
except FileExistsError:
try:
dest = os.readlink(idiotic_design)
except OSError as e:
raise ValueError(f'The {idiotic_design} symlink could not be created as something with that name exists already') from e
else:
if dest != rd:
with tempfile.TemporaryDirectory(dir='/tmp') as tdir:
tlink = os.path.join(tdir, 'sigh')
os.symlink(rd, tlink)
os.rename(tlink, idiotic_design)
rd = idiotic_design
cp = os.path.join(rd, ssh_control_master_template.format(kitty_pid=kitty_pid, ssh_placeholder='%C'))
ans: List[str] = [
'-o', 'ControlMaster=auto',
'-o', f'ControlPath={cp}',
'-o', 'ControlPersist=yes',
'-o', 'ServerAliveInterval=60',
'-o', 'ServerAliveCountMax=5',
'-o', 'TCPKeepAlive=no',
]
return ans
@contextmanager
def restore_terminal_state() -> Iterator[bool]:
with open(os.ctermid()) as f:
val = termios.tcgetattr(f.fileno())
print(end=SAVE_PRIVATE_MODE_VALUES)
print(end=set_mode(Mode.HANDLE_TERMIOS_SIGNALS), flush=True)
try:
yield bool(val[3] & termios.ECHO)
finally:
termios.tcsetattr(f.fileno(), termios.TCSAFLUSH, val)
print(end=RESTORE_PRIVATE_MODE_VALUES, flush=True)
def dcs_to_kitty(payload: Union[bytes, str], type: str = 'ssh') -> bytes:
if isinstance(payload, str):
payload = payload.encode('utf-8')
payload = standard_b64encode(payload)
ans = b'\033P@kitty-' + type.encode('ascii') + b'|' + payload
tmux = running_in_tmux()
if tmux:
cp = subprocess.run([tmux, 'set', '-p', 'allow-passthrough', 'on'])
if cp.returncode != 0:
raise SystemExit(cp.returncode)
ans = b'\033Ptmux;\033' + ans + b'\033\033\\\033\\'
else:
ans += b'\033\\'
return ans
@run_once
def ssh_version() -> Tuple[int, int]:
o = subprocess.check_output([ssh_exe(), '-V'], stderr=subprocess.STDOUT).decode()
m = re.match(r'OpenSSH_(\d+).(\d+)', o)
if m is None:
raise ValueError(f'Invalid version string for OpenSSH: {o}')
return int(m.group(1)), int(m.group(2))
@contextmanager
def drain_potential_tty_garbage(p: 'subprocess.Popen[bytes]', data_request: str) -> Iterator[None]:
with open(os.open(os.ctermid(), os.O_CLOEXEC | os.O_RDWR | os.O_NOCTTY), 'wb') as tty:
if data_request:
turn_off_echo(tty.fileno())
tty.write(dcs_to_kitty(data_request))
tty.flush()
try:
yield
finally:
# discard queued input data on tty in case data transmission was
# interrupted due to SSH failure, avoids spewing garbage to screen
from uuid import uuid4
canary = uuid4().hex.encode('ascii')
turn_off_echo(tty.fileno())
tty.write(dcs_to_kitty(canary + b'\n\r', type='echo'))
tty.flush()
data = b''
give_up_at = time.monotonic() + 2
tty_fd = tty.fileno()
while time.monotonic() < give_up_at and canary not in data:
with suppress(KeyboardInterrupt):
rd, wr, err = select([tty_fd], [], [tty_fd], max(0, give_up_at - time.monotonic()))
if err or not rd:
break
q = os.read(tty_fd, io.DEFAULT_BUFFER_SIZE)
if not q:
break
data += q
def change_colors(color_scheme: str) -> bool:
if not color_scheme:
return False
from kittens.themes.collection import NoCacheFound, load_themes, text_as_opts
from kittens.themes.main import colors_as_escape_codes
if color_scheme.endswith('.conf'):
conf_file = resolve_abs_or_config_path(color_scheme)
try:
with open(conf_file) as f:
opts = text_as_opts(f.read())
except FileNotFoundError:
raise SystemExit(f'Failed to find the color conf file: {expandvars(conf_file)}')
else:
try:
themes = load_themes(-1)
except NoCacheFound:
themes = load_themes()
cs = expandvars(color_scheme)
try:
theme = themes[cs]
except KeyError:
raise SystemExit(f'Failed to find the color theme: {cs}')
opts = theme.kitty_opts
raw = colors_as_escape_codes(opts)
print(save_colors(), sep='', end=raw, flush=True)
return True
def add_cloned_env(shm_name: str) -> Dict[str, str]:
try:
return cast(Dict[str, str], read_data_from_shared_memory(shm_name))
except FileNotFoundError:
pass
return {}
def run_ssh(ssh_args: List[str], server_args: List[str], found_extra_args: Tuple[str, ...]) -> NoReturn:
cmd = [ssh_exe()] + ssh_args
hostname, remote_args = server_args[0], server_args[1:]
if not remote_args:
cmd.append('-t')
insertion_point = len(cmd)
cmd.append('--')
cmd.append(hostname)
uname = getuser()
if hostname.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(hostname)
hostname_for_match = purl.hostname or hostname[6:].split('/', 1)[0]
uname = purl.username or uname
elif '@' in hostname and hostname[0] != '@':
uname, hostname_for_match = hostname.split('@', 1)
else:
hostname_for_match = hostname
hostname_for_match = hostname_for_match.split('@', 1)[-1].split(':', 1)[0]
overrides: List[str] = []
literal_env: Dict[str, str] = {}
pat = re.compile(r'^([a-zA-Z0-9_]+)[ \t]*=')
for i, a in enumerate(found_extra_args):
if i % 2 == 1:
aq = pat.sub(r'\1 ', a.lstrip())
key = aq.split(maxsplit=1)[0]
if key == 'clone_env':
literal_env = add_cloned_env(aq.split(maxsplit=1)[1])
elif key != 'hostname':
overrides.append(aq)
if overrides:
overrides.insert(0, f'hostname {uname}@{hostname_for_match}')
host_opts = init_config(hostname_for_match, uname, overrides)
if host_opts.share_connections:
cmd[insertion_point:insertion_point] = connection_sharing_args(int(os.environ['KITTY_PID']))
use_kitty_askpass = host_opts.askpass == 'native' or (host_opts.askpass == 'unless-set' and 'SSH_ASKPASS' not in os.environ)
need_to_request_data = True
if use_kitty_askpass:
sentinel = os.path.join(cache_dir(), 'openssh-is-new-enough-for-askpass')
sentinel_exists = os.path.exists(sentinel)
if sentinel_exists or ssh_version() >= (8, 4):
if not sentinel_exists:
open(sentinel, 'w').close()
# SSH_ASKPASS_REQUIRE was introduced in 8.4 release on 2020-09-27
need_to_request_data = False
os.environ['SSH_ASKPASS_REQUIRE'] = 'force'
os.environ['SSH_ASKPASS'] = os.path.join(shell_integration_dir, 'ssh', 'askpass.py')
if need_to_request_data and host_opts.share_connections:
cp = subprocess.run(cmd[:1] + ['-O', 'check'] + cmd[1:], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
if cp.returncode == 0:
# we will use the master connection so SSH does not need to use the tty
need_to_request_data = False
with restore_terminal_state() as echo_on:
rcmd, replacements, shm_name = get_remote_command(
remote_args, host_opts, hostname_for_match, uname, echo_on, request_data=need_to_request_data, literal_env=literal_env)
cmd += rcmd
colors_changed = change_colors(host_opts.color_scheme)
try:
p = subprocess.Popen(cmd)
except FileNotFoundError:
raise SystemExit('Could not find the ssh executable, is it in your PATH?')
else:
rq = '' if need_to_request_data else 'id={REQUEST_ID}:pwfile={PASSWORD_FILENAME}:pw={DATA_PASSWORD}'.format(**replacements)
with drain_potential_tty_garbage(p, rq):
raise SystemExit(p.wait())
finally:
if colors_changed:
print(end=restore_colors(), flush=True)
def main(args: List[str]) -> None:
args = args[1:]
if args and args[0] == 'use-python':
args = args[1:] # backwards compat from when we had a python implementation
try:
ssh_args, server_args, passthrough, found_extra_args = parse_ssh_args(args, extra_args=('--kitten',))
except InvalidSSHArgs as e:
e.system_exit()
if passthrough:
if found_extra_args:
raise SystemExit(f'The SSH kitten cannot work with the options: {", ".join(passthrough_args)}')
os.execlp(ssh_exe(), 'ssh', *args)
if not os.environ.get('KITTY_WINDOW_ID') or not os.environ.get('KITTY_PID'):
raise SystemExit('The SSH kitten is meant to run inside a kitty window')
if not sys.stdin.isatty():
raise SystemExit('The SSH kitten is meant for interactive use only, STDIN must be a terminal')
try:
run_ssh(ssh_args, server_args, found_extra_args)
except KeyboardInterrupt:
sys.excepthook = lambda *a: None
raise
if __name__ == '__main__': if __name__ == '__main__':
main(sys.argv) main([])
elif __name__ == '__wrapper_of__': elif __name__ == '__wrapper_of__':
cd = sys.cli_docs # type: ignore cd = getattr(sys, 'cli_docs')
cd['wrapper_of'] = 'ssh' cd['wrapper_of'] = 'ssh'
elif __name__ == '__conf__': elif __name__ == '__conf__':
from .options.definition import definition setattr(sys, 'options_definition', definition)
sys.options_definition = definition # type: ignore elif __name__ == '__extra_cli_parsers__':
setattr(sys, 'extra_cli_parsers', {'copy': option_text()})

View File

@ -1,153 +0,0 @@
#!/usr/bin/env python
# vim:fileencoding=utf-8
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
# After editing this file run ./gen-config.py to apply the changes
from kitty.conf.types import Definition
copy_message = '''\
Copy files and directories from local to remote hosts. The specified files are
assumed to be relative to the HOME directory and copied to the HOME on the
remote. Directories are copied recursively. If absolute paths are used, they are
copied as is.'''
definition = Definition(
'kittens.ssh',
)
agr = definition.add_group
egr = definition.end_group
opt = definition.add_option
agr('bootstrap', 'Host bootstrap configuration') # {{{
opt('hostname', '*', option_type='hostname', long_text='''
The hostname that the following options apply to. A glob pattern to match
multiple hosts can be used. Multiple hostnames can also be specified, separated
by spaces. The hostname can include an optional username in the form
:code:`user@host`. When not specified options apply to all hosts, until the
first hostname specification is found. Note that matching of hostname is done
against the name you specify on the command line to connect to the remote host.
If you wish to include the same basic configuration for many different hosts,
you can do so with the :ref:`include <include>` directive.
''')
opt('interpreter', 'sh', long_text='''
The interpreter to use on the remote host. Must be either a POSIX complaint
shell or a :program:`python` executable. If the default :program:`sh` is not
available or broken, using an alternate interpreter can be useful.
''')
opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text='''
The location on the remote host where the files needed for this kitten are
installed. Relative paths are resolved with respect to :code:`$HOME`.
''')
opt('+copy', '', option_type='copy', add_to_default=False, long_text=f'''
{copy_message} For example::
copy .vimrc .zshrc .config/some-dir
Use :code:`--dest` to copy a file to some other destination on the remote host::
copy --dest some-other-name some-file
Glob patterns can be specified to copy multiple files, with :code:`--glob`::
copy --glob images/*.png
Files can be excluded when copying with :code:`--exclude`::
copy --glob --exclude *.jpg --exclude *.bmp images/*
Files whose remote name matches the exclude pattern will not be copied.
For more details, see :ref:`ssh_copy_command`.
''')
egr() # }}}
agr('shell', 'Login shell environment') # {{{
opt('shell_integration', 'inherited', long_text='''
Control the shell integration on the remote host. See :ref:`shell_integration`
for details on how this setting works. The special value :code:`inherited` means
use the setting from :file:`kitty.conf`. This setting is useful for overriding
integration on a per-host basis.
''')
opt('login_shell', '', long_text='''
The login shell to execute on the remote host. By default, the remote user
account's login shell is used.
''')
opt('+env', '', option_type='env', add_to_default=False, long_text='''
Specify the environment variables to be set on the remote host. Using the
name with an equal sign (e.g. :code:`env VAR=`) will set it to the empty string.
Specifying only the name (e.g. :code:`env VAR`) will remove the variable from
the remote shell environment. The special value :code:`_kitty_copy_env_var_`
will cause the value of the variable to be copied from the local environment.
The definitions are processed alphabetically. Note that environment variables
are expanded recursively, for example::
env VAR1=a
env VAR2=${HOME}/${VAR1}/b
The value of :code:`VAR2` will be :code:`<path to home directory>/a/b`.
''')
opt('cwd', '', long_text='''
The working directory on the remote host to change to. Environment variables in
this value are expanded. The default is empty so no changing is done, which
usually means the HOME directory is used.
''')
opt('color_scheme', '', long_text='''
Specify a color scheme to use when connecting to the remote host. If this option
ends with :code:`.conf`, it is assumed to be the name of a config file to load
from the kitty config directory, otherwise it is assumed to be the name of a
color theme to load via the :doc:`themes kitten </kittens/themes>`. Note that
only colors applying to the text/background are changed, other config settings
in the .conf files/themes are ignored.
''')
opt('remote_kitty', 'if-needed', choices=('if-needed', 'no', 'yes'), long_text='''
Make :program:`kitty` available on the remote host. Useful to run kittens such
as the :doc:`icat kitten </kittens/icat>` to display images or the
:doc:`transfer file kitten </kittens/transfer>` to transfer files. Only works if
the remote host has an architecture for which :link:`pre-compiled kitty binaries
<https://github.com/kovidgoyal/kitty/releases>` are available. Note that kitty
is not actually copied to the remote host, instead a small bootstrap script is
copied which will download and run kitty when kitty is first executed on the
remote host. A value of :code:`if-needed` means kitty is installed only if not
already present in the system-wide PATH. A value of :code:`yes` means that kitty
is installed even if already present, and the installed kitty takes precedence.
Finally, :code:`no` means no kitty is installed on the remote host. The
installed kitty can be updated by running: :code:`kitty +update-kitty` on the
remote host.
''')
egr() # }}}
agr('ssh', 'SSH configuration') # {{{
opt('share_connections', 'yes', option_type='to_bool', long_text='''
Within a single kitty instance, all connections to a particular server can be
shared. This reduces startup latency for subsequent connections and means that
you have to enter the password only once. Under the hood, it uses SSH
ControlMasters and these are automatically cleaned up by kitty when it quits.
You can map a shortcut to :ac:`close_shared_ssh_connections` to disconnect all
active shared connections.
''')
opt('askpass', 'unless-set', choices=('unless-set', 'ssh', 'native'), long_text='''
Control the program SSH uses to ask for passwords or confirmation of host keys
etc. The default is to use kitty's native :program:`askpass`, unless the
:envvar:`SSH_ASKPASS` environment variable is set. Set this option to
:code:`ssh` to not interfere with the normal ssh askpass mechanism at all, which
typically means that ssh will prompt at the terminal. Set it to :code:`native`
to always use kitty's native, built-in askpass implementation. Note that not
using the kitty askpass implementation means that SSH might need to use the
terminal before the connection is established, so the kitten cannot use the
terminal to send data without an extra roundtrip, adding to initial connection
latency.
''')
egr() # }}}

View File

@ -1,90 +0,0 @@
# generated by gen-config.py DO NOT edit
# isort: skip_file
import typing
from kittens.ssh.options.utils import copy, env, hostname
from kitty.conf.utils import merge_dicts, to_bool
class Parser:
def askpass(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
val = val.lower()
if val not in self.choices_for_askpass:
raise ValueError(f"The value {val} is not a valid choice for askpass")
ans["askpass"] = val
choices_for_askpass = frozenset(('unless-set', 'ssh', 'native'))
def color_scheme(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['color_scheme'] = str(val)
def copy(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
for k, v in copy(val, ans["copy"]):
ans["copy"][k] = v
def cwd(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['cwd'] = str(val)
def env(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
for k, v in env(val, ans["env"]):
ans["env"][k] = v
def hostname(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
hostname(val, ans)
def interpreter(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['interpreter'] = str(val)
def login_shell(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['login_shell'] = str(val)
def remote_dir(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['remote_dir'] = str(val)
def remote_kitty(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
val = val.lower()
if val not in self.choices_for_remote_kitty:
raise ValueError(f"The value {val} is not a valid choice for remote_kitty")
ans["remote_kitty"] = val
choices_for_remote_kitty = frozenset(('if-needed', 'no', 'yes'))
def share_connections(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['share_connections'] = to_bool(val)
def shell_integration(self, val: str, ans: typing.Dict[str, typing.Any]) -> None:
ans['shell_integration'] = str(val)
def create_result_dict() -> typing.Dict[str, typing.Any]:
return {
'copy': {},
'env': {},
}
actions: typing.FrozenSet[str] = frozenset(())
def merge_result_dicts(defaults: typing.Dict[str, typing.Any], vals: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
ans = {}
for k, v in defaults.items():
if isinstance(v, dict):
ans[k] = merge_dicts(v, vals.get(k, {}))
elif k in actions:
ans[k] = v + vals.get(k, [])
else:
ans[k] = vals.get(k, v)
return ans
parser = Parser()
def parse_conf_item(key: str, val: str, ans: typing.Dict[str, typing.Any]) -> bool:
func = getattr(parser, key, None)
if func is not None:
func(val, ans)
return True
return False

View File

@ -1,93 +0,0 @@
# generated by gen-config.py DO NOT edit
# isort: skip_file
import typing
import kittens.ssh.copy
if typing.TYPE_CHECKING:
choices_for_askpass = typing.Literal['unless-set', 'ssh', 'native']
choices_for_remote_kitty = typing.Literal['if-needed', 'no', 'yes']
else:
choices_for_askpass = str
choices_for_remote_kitty = str
option_names = ( # {{{
'askpass',
'color_scheme',
'copy',
'cwd',
'env',
'hostname',
'interpreter',
'login_shell',
'remote_dir',
'remote_kitty',
'share_connections',
'shell_integration') # }}}
class Options:
askpass: choices_for_askpass = 'unless-set'
color_scheme: str = ''
cwd: str = ''
hostname: str = '*'
interpreter: str = 'sh'
login_shell: str = ''
remote_dir: str = '.local/share/kitty-ssh-kitten'
remote_kitty: choices_for_remote_kitty = 'if-needed'
share_connections: bool = True
shell_integration: str = 'inherited'
copy: typing.Dict[str, kittens.ssh.copy.CopyInstruction] = {}
env: typing.Dict[str, str] = {}
config_paths: typing.Tuple[str, ...] = ()
config_overrides: typing.Tuple[str, ...] = ()
def __init__(self, options_dict: typing.Optional[typing.Dict[str, typing.Any]] = None) -> None:
if options_dict is not None:
null = object()
for key in option_names:
val = options_dict.get(key, null)
if val is not null:
setattr(self, key, val)
@property
def _fields(self) -> typing.Tuple[str, ...]:
return option_names
def __iter__(self) -> typing.Iterator[str]:
return iter(self._fields)
def __len__(self) -> int:
return len(self._fields)
def _copy_of_val(self, name: str) -> typing.Any:
ans = getattr(self, name)
if isinstance(ans, dict):
ans = ans.copy()
elif isinstance(ans, list):
ans = ans[:]
return ans
def _asdict(self) -> typing.Dict[str, typing.Any]:
return {k: self._copy_of_val(k) for k in self}
def _replace(self, **kw: typing.Any) -> "Options":
ans = Options()
for name in self:
setattr(ans, name, self._copy_of_val(name))
for name, val in kw.items():
setattr(ans, name, val)
return ans
def __getitem__(self, key: typing.Union[int, str]) -> typing.Any:
k = option_names[key] if isinstance(key, int) else key
try:
return getattr(self, k)
except AttributeError:
pass
raise KeyError(f"No option named: {k}")
defaults = Options()
defaults.copy = {}
defaults.env = {}

View File

@ -1,56 +0,0 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
from typing import Any, Dict, Iterable, Optional, Tuple
from ..copy import CopyInstruction, parse_copy_instructions
DELETE_ENV_VAR = '_delete_this_env_var_'
def env(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, str]]:
val = val.strip()
if val:
if '=' in val:
key, v = val.split('=', 1)
key, v = key.strip(), v.strip()
if key:
yield key, v
else:
yield val, DELETE_ENV_VAR
def copy(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]:
yield from parse_copy_instructions(val, current_val)
def init_results_dict(ans: Dict[str, Any]) -> Dict[str, Any]:
ans['hostname'] = '*'
ans['per_host_dicts'] = {}
return ans
def get_per_hosts_dict(results_dict: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
ans: Dict[str, Dict[str, Any]] = results_dict.get('per_host_dicts', {}).copy()
h = results_dict['hostname']
hd = {k: v for k, v in results_dict.items() if k != 'per_host_dicts'}
ans[h] = hd
return ans
first_seen_positions: Dict[str, int] = {}
def hostname(val: str, dict_with_parse_results: Optional[Dict[str, Any]] = None) -> str:
if dict_with_parse_results is not None:
ch = dict_with_parse_results['hostname']
if val != ch:
from .parse import create_result_dict
phd = get_per_hosts_dict(dict_with_parse_results)
dict_with_parse_results.clear()
dict_with_parse_results.update(phd.pop(val, create_result_dict()))
dict_with_parse_results['per_host_dicts'] = phd
dict_with_parse_results['hostname'] = val
if val not in first_seen_positions:
first_seen_positions[val] = len(first_seen_positions)
return val

View File

@ -4,9 +4,12 @@
import os import os
import subprocess import subprocess
from typing import Any, Dict, List, Sequence, Set, Tuple import traceback
from contextlib import suppress
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple
from kitty.types import run_once from kitty.types import run_once
from kitty.utils import SSHConnectionData
@run_once @run_once
@ -94,6 +97,57 @@ def create_shared_memory(data: Any, prefix: str) -> str:
return shm.name return shm.name
def read_data_from_shared_memory(shm_name: str) -> Any:
import json
import stat
from kitty.shm import SharedMemory
with SharedMemory(shm_name, readonly=True) as shm:
shm.unlink()
if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
raise ValueError('Incorrect owner on pwfile')
mode = stat.S_IMODE(shm.stats.st_mode)
if mode != stat.S_IREAD | stat.S_IWRITE:
raise ValueError('Incorrect permissions on pwfile')
return json.loads(shm.read_data_with_size())
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
from base64 import standard_b64decode
yield b'\nKITTY_DATA_START\n' # to discard leading data
try:
msg = standard_b64decode(msg).decode('utf-8')
md = dict(x.split('=', 1) for x in msg.split(':'))
pw = md['pw']
pwfilename = md['pwfile']
rq_id = md['id']
except Exception:
traceback.print_exc()
yield b'invalid ssh data request message\n'
else:
try:
env_data = read_data_from_shared_memory(pwfilename)
if pw != env_data['pw']:
raise ValueError('Incorrect password')
if rq_id != request_id:
raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
except Exception as e:
traceback.print_exc()
yield f'{e}\n'.encode('utf-8')
else:
yield b'OK\n'
encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
# macOS has a 255 byte limit on its input queue as per man stty.
# Not clear if that applies to canonical mode input as well, but
# better to be safe.
line_sz = 254
while encoded_data:
yield encoded_data[:line_sz]
yield b'\n'
encoded_data = encoded_data[line_sz:]
yield b'KITTY_DATA_END\n'
def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None: def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None:
patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv) patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
@ -183,3 +237,86 @@ def set_server_args_in_cmdline(
ans.insert(i, '-t') ans.insert(i, '-t')
break break
argv[:] = ans + server_args argv[:] = ans + server_args
def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
expecting_port = expecting_identity = False
expecting_option_val = False
expecting_hostname = False
expecting_extra_val = ''
host_name = identity_file = found_ssh = ''
found_extra_args: List[Tuple[str, str]] = []
for i, arg in enumerate(args):
if not found_ssh:
if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
found_ssh = arg
continue
if expecting_hostname:
host_name = arg
continue
if arg.startswith('-') and not expecting_option_val:
if arg in boolean_ssh_args:
continue
if arg == '--':
expecting_hostname = True
if arg.startswith('-p'):
if arg[2:].isdigit():
with suppress(Exception):
port = int(arg[2:])
continue
elif arg == '-p':
expecting_port = True
elif arg.startswith('-i'):
if arg == '-i':
expecting_identity = True
else:
identity_file = arg[2:]
continue
if arg.startswith('--') and extra_args:
matching_ex = is_extra_arg(arg, extra_args)
if matching_ex:
if '=' in arg:
exval = arg.partition('=')[-1]
found_extra_args.append((matching_ex, exval))
continue
expecting_extra_val = matching_ex
expecting_option_val = True
continue
if expecting_option_val:
if expecting_port:
with suppress(Exception):
port = int(arg)
expecting_port = False
elif expecting_identity:
identity_file = arg
elif expecting_extra_val:
found_extra_args.append((expecting_extra_val, arg))
expecting_extra_val = ''
expecting_option_val = False
continue
if not host_name:
host_name = arg
if not host_name:
return None
if host_name.startswith('ssh://'):
from urllib.parse import urlparse
purl = urlparse(host_name)
if purl.hostname:
host_name = purl.hostname
if purl.username:
host_name = f'{purl.username}@{host_name}'
if port is None and purl.port:
port = purl.port
if identity_file:
if not os.path.isabs(identity_file):
identity_file = os.path.expanduser(identity_file)
if not os.path.isabs(identity_file):
identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))

View File

@ -13,7 +13,7 @@ LaunchCLIOptions = AskCLIOptions = ClipboardCLIOptions = DiffCLIOptions = CLIOpt
HintsCLIOptions = IcatCLIOptions = PanelCLIOptions = ResizeCLIOptions = CLIOptions HintsCLIOptions = IcatCLIOptions = PanelCLIOptions = ResizeCLIOptions = CLIOptions
ErrorCLIOptions = UnicodeCLIOptions = RCOptions = RemoteFileCLIOptions = CLIOptions ErrorCLIOptions = UnicodeCLIOptions = RCOptions = RemoteFileCLIOptions = CLIOptions
QueryTerminalCLIOptions = BroadcastCLIOptions = ShowKeyCLIOptions = CLIOptions QueryTerminalCLIOptions = BroadcastCLIOptions = ShowKeyCLIOptions = CLIOptions
ThemesCLIOptions = TransferCLIOptions = CopyCLIOptions = CLIOptions ThemesCLIOptions = TransferCLIOptions = CLIOptions
def generate_stub() -> None: def generate_stub() -> None:
@ -78,9 +78,6 @@ def generate_stub() -> None:
from kittens.transfer.main import option_text as OPTIONS from kittens.transfer.main import option_text as OPTIONS
do(OPTIONS(), 'TransferCLIOptions') do(OPTIONS(), 'TransferCLIOptions')
from kittens.ssh.copy import option_text as OPTIONS
do(OPTIONS(), 'CopyCLIOptions')
from kitty.rc.base import all_command_names, command_for_name from kitty.rc.base import all_command_names, command_for_name
for cmd_name in all_command_names(): for cmd_name in all_command_names():
cmd = command_for_name(cmd_name) cmd = command_for_name(cmd_name)

View File

@ -9,7 +9,7 @@ import re
import textwrap import textwrap
from typing import Any, Callable, Dict, Iterator, List, Set, Tuple, Union, get_type_hints from typing import Any, Callable, Dict, Iterator, List, Set, Tuple, Union, get_type_hints
from kitty.conf.types import Definition, MultiOption, Option, unset from kitty.conf.types import Definition, MultiOption, Option, ParserFuncType, unset
from kitty.types import _T from kitty.types import _T
@ -442,6 +442,121 @@ def write_output(loc: str, defn: Definition) -> None:
f.write(f'{c}\n') f.write(f'{c}\n')
def go_type_data(parser_func: ParserFuncType, ctype: str) -> Tuple[str, str]:
if ctype:
return f'*{ctype}', f'Parse{ctype}(val)'
p = parser_func.__name__
if p == 'int':
return 'int64', 'strconv.ParseInt(val, 10, 64)'
if p == 'str':
return 'string', 'val, nil'
if p == 'float':
return 'float64', 'strconv.ParseFloat(val, 10, 64)'
if p == 'to_bool':
return 'bool', 'config.StringToBool(val), nil'
th = get_type_hints(parser_func)
rettype = th['return']
return {int: 'int64', str: 'string', float: 'float64'}[rettype], f'{p}(val)'
def gen_go_code(defn: Definition) -> str:
lines = ['import "fmt"', 'import "strconv"', 'import "kitty/tools/config"',
'var _ = fmt.Println', 'var _ = config.StringToBool', 'var _ = strconv.Atoi']
a = lines.append
choices = {}
go_types = {}
go_parsers = {}
defaults = {}
multiopts = {''}
for option in sorted(defn.iter_all_options(), key=lambda a: natural_keys(a.name)):
name = option.name.capitalize()
if isinstance(option, MultiOption):
go_types[name], go_parsers[name] = go_type_data(option.parser_func, option.ctype)
multiopts.add(name)
else:
defaults[name] = option.parser_func(option.defval_as_string)
if option.choices:
choices[name] = option.choices
go_types[name] = f'{name}_Choice_Type'
go_parsers[name] = f'Parse_{name}(val)'
continue
go_types[name], go_parsers[name] = go_type_data(option.parser_func, option.ctype)
for oname in choices:
a(f'type {go_types[oname]} int')
a('type Config struct {')
for name, gotype in go_types.items():
if name in multiopts:
a(f'{name} []{gotype}')
else:
a(f'{name} {gotype}')
a('}')
def cval(x: str) -> str:
return x.replace('-', '_')
a('func NewConfig() *Config {')
a('return &Config{')
for name, pname in go_parsers.items():
if name in multiopts:
continue
d = defaults[name]
if not d:
continue
if isinstance(d, str):
dval = f'{name}_{cval(d)}' if name in choices else f'`{d}`'
elif isinstance(d, bool):
dval = repr(d).lower()
else:
dval = repr(d)
a(f'{name}: {dval},')
a('}''}')
for oname, choice_vals in choices.items():
a('const (')
for i, c in enumerate(choice_vals):
c = cval(c)
if i == 0:
a(f'{oname}_{c} {oname}_Choice_Type = iota')
else:
a(f'{oname}_{c}')
a(')')
a(f'func (x {oname}_Choice_Type) String() string'' {')
a('switch x {')
a('default: return ""')
for c in choice_vals:
a(f'case {oname}_{cval(c)}: return "{c}"')
a('}''}')
a(f'func {go_parsers[oname].split("(")[0]}(val string) (ans {go_types[oname]}, err error) ''{')
a('switch val {')
for c in choice_vals:
a(f'case "{c}": return {oname}_{cval(c)}, nil')
vals = ', '.join(choice_vals)
a(f'default: return ans, fmt.Errorf("%#v is not a valid value for %s. Valid values are: %s", val, "{c}", "{vals}")')
a('}''}')
a('func (c *Config) Parse(key, val string) (err error) {')
a('switch key {')
a('default: return fmt.Errorf("Unknown configuration key: %#v", key)')
for oname, pname in go_parsers.items():
ol = oname.lower()
is_multiple = oname in multiopts
a(f'case "{ol}":')
if is_multiple:
a(f'var temp_val []{go_types[oname]}')
else:
a(f'var temp_val {go_types[oname]}')
a(f'temp_val, err = {pname}')
a(f'if err != nil {{ return fmt.Errorf("Failed to parse {ol} = %#v with error: %w", val, err) }}')
if is_multiple:
a(f'c.{oname} = append(c.{oname}, temp_val...)')
else:
a(f'c.{oname} = temp_val')
a('}')
a('return}')
return '\n'.join(lines)
def main() -> None: def main() -> None:
# To use run it as: # To use run it as:
# kitty +runpy 'from kitty.conf.generate import main; main()' /path/to/kitten/file.py # kitty +runpy 'from kitty.conf.generate import main; main()' /path/to/kitten/file.py

View File

@ -854,12 +854,14 @@ def store_multiple(val: str, current_val: Container[str]) -> Iterable[Tuple[str,
yield val, val yield val, val
allowed_shell_integration_values = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'})
def shell_integration(x: str) -> FrozenSet[str]: def shell_integration(x: str) -> FrozenSet[str]:
s = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'})
q = frozenset(x.lower().split()) q = frozenset(x.lower().split())
if not q.issubset(s): if not q.issubset(allowed_shell_integration_values):
log_error(f'Invalid shell integration options: {q - s}, ignoring') log_error(f'Invalid shell integration options: {q - allowed_shell_integration_values}, ignoring')
return q & s or frozenset({'invalid'}) return q & allowed_shell_integration_values or frozenset({'invalid'})
return q return q

View File

@ -959,7 +959,7 @@ class Window:
def handle_remote_file(self, netloc: str, remote_path: str) -> None: def handle_remote_file(self, netloc: str, remote_path: str) -> None:
from kittens.remote_file.main import is_ssh_kitten_sentinel from kittens.remote_file.main import is_ssh_kitten_sentinel
from kittens.ssh.main import get_connection_data from kittens.ssh.utils import get_connection_data
from .utils import SSHConnectionData from .utils import SSHConnectionData
args = self.ssh_kitten_cmdline() args = self.ssh_kitten_cmdline()
@ -1156,7 +1156,7 @@ class Window:
self.write_to_child(data) self.write_to_child(data)
def handle_remote_ssh(self, msg: str) -> None: def handle_remote_ssh(self, msg: str) -> None:
from kittens.ssh.main import get_ssh_data from kittens.ssh.utils import get_ssh_data
for line in get_ssh_data(msg, f'{os.getpid()}-{self.id}'): for line in get_ssh_data(msg, f'{os.getpid()}-{self.id}'):
self.write_to_child(line) self.write_to_child(line)

View File

@ -112,7 +112,7 @@ class Callbacks:
self.current_clone_data += rest self.current_clone_data += rest
def handle_remote_ssh(self, msg): def handle_remote_ssh(self, msg):
from kittens.ssh.main import get_ssh_data from kittens.ssh.utils import get_ssh_data
if self.pty: if self.pty:
for line in get_ssh_data(msg, "testing"): for line in get_ssh_data(msg, "testing"):
self.pty.write_to_child(line) self.pty.write_to_child(line)

View File

@ -67,7 +67,7 @@ class TestBuild(BaseTest):
q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH q = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
return mode & q == q return mode & q == q
for x in ('kitty', 'kitten', 'askpass.py'): for x in ('kitty', 'kitten'):
x = os.path.join(shell_integration_dir, 'ssh', x) x = os.path.join(shell_integration_dir, 'ssh', x)
self.assertTrue(is_executable(x), f'{x} is not executable') self.assertTrue(is_executable(x), f'{x} is not executable')
if getattr(sys, 'frozen', False): if getattr(sys, 'frozen', False):

30
kitty_tests/shm.py Normal file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2023, Kovid Goyal <kovid at kovidgoyal.net>
import os
import subprocess
from kitty.constants import kitten_exe
from kitty.fast_data_types import shm_unlink
from kitty.shm import SharedMemory
from . import BaseTest
class SHMTest(BaseTest):
def test_shm_with_kitten(self):
data = os.urandom(333)
with SharedMemory(size=363) as shm:
shm.write_data_with_size(data)
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'read', shm.name], stdout=subprocess.PIPE)
self.assertEqual(cp.returncode, 0)
self.assertEqual(cp.stdout, data)
self.assertRaises(FileNotFoundError, shm_unlink, shm.name)
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'write'], input=data, stdout=subprocess.PIPE)
self.assertEqual(cp.returncode, 0)
name = cp.stdout.decode().strip()
with SharedMemory(name=name, unlink_on_exit=True) as shm:
q = shm.read_data_with_size()
self.assertEqual(data, q)

View File

@ -3,18 +3,16 @@
import glob import glob
import json
import os import os
import shutil import shutil
import subprocess
import tempfile import tempfile
from contextlib import suppress from contextlib import suppress
from functools import lru_cache from functools import lru_cache
from kittens.ssh.config import load_config from kittens.ssh.utils import get_connection_data
from kittens.ssh.main import bootstrap_script, get_connection_data, wrap_bootstrap_script from kitty.constants import is_macos, kitten_exe, runtime_dir
from kittens.ssh.options.types import Options as SSHOptions
from kittens.ssh.options.utils import DELETE_ENV_VAR
from kittens.transfer.utils import set_paths
from kitty.constants import is_macos, runtime_dir
from kitty.fast_data_types import CURSOR_BEAM, shm_unlink from kitty.fast_data_types import CURSOR_BEAM, shm_unlink
from kitty.utils import SSHConnectionData from kitty.utils import SSHConnectionData
@ -58,32 +56,6 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two'))) t('ssh --kitten=one -p 12 --kitten two -ix main', identity_file='x', port=12, extra_args=(('--kitten', 'one'), ('--kitten', 'two')))
self.assertTrue(runtime_dir()) self.assertTrue(runtime_dir())
def test_ssh_config_parsing(self):
def parse(conf, hostname='unmatched_host', username=''):
return load_config(overrides=conf.splitlines(), hostname=hostname, username=username)
self.ae(parse('').env, {})
self.ae(parse('env a=b').env, {'a': 'b'})
conf = 'env a=b\nhostname 2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'})
self.ae(parse('env a=').env, {'a': ''})
self.ae(parse('env a').env, {'a': '_delete_this_env_var_'})
conf = 'env a=b\nhostname test@2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '2').env, {'a': 'b'})
self.ae(parse(conf, '2', 'test').env, {'a': 'c', 'b': 'b'})
conf = 'env a=b\nhostname 1 2\nenv a=c\nenv b=b'
self.ae(parse(conf).env, {'a': 'b'})
self.ae(parse(conf, '1').env, {'a': 'c', 'b': 'b'})
self.ae(parse(conf, '2').env, {'a': 'c', 'b': 'b'})
def test_ssh_bootstrap_sh_cmd_limit(self):
# dropbear has a 9000 bytes maximum command length limit
sh_script, _, _ = bootstrap_script(SSHOptions({'interpreter': 'sh'}), script_type='sh', remote_args=[], request_id='123-123')
rcmd = wrap_bootstrap_script(sh_script, 'sh')
self.assertLessEqual(sum(len(x) for x in rcmd), 9000)
@property @property
@lru_cache() @lru_cache()
def all_possible_sh(self): def all_possible_sh(self):
@ -98,11 +70,13 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
f.write(simple_data) f.write(simple_data)
for sh in self.all_possible_sh: for sh in self.all_possible_sh:
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as remote_home, tempfile.TemporaryDirectory() as local_home, set_paths(home=local_home): with self.subTest(sh=sh), tempfile.TemporaryDirectory() as remote_home, tempfile.TemporaryDirectory() as local_home:
tuple(map(touch, 'simple-file g.1 g.2'.split())) tuple(map(touch, 'simple-file g.1 g.2'.split()))
os.makedirs(f'{local_home}/d1/d2/d3') os.makedirs(f'{local_home}/d1/d2/d3')
touch('d1/d2/x') touch('d1/d2/x')
touch('d1/d2/w.exclude') touch('d1/d2/w.exclude')
os.mkdir(f'{local_home}/d1/r')
touch('d1/r/noooo')
os.symlink('d2/x', f'{local_home}/d1/y') os.symlink('d2/x', f'{local_home}/d1/y')
os.symlink('simple-file', f'{local_home}/s1') os.symlink('simple-file', f'{local_home}/s1')
os.symlink('simple-file', f'{local_home}/s2') os.symlink('simple-file', f'{local_home}/s2')
@ -110,15 +84,13 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
conf = '''\ conf = '''\
copy simple-file copy simple-file
copy s1 copy s1
copy --symlink-strategy=keep-name s2 copy --symlink-strategy=keep-path s2
copy --dest=a/sfa simple-file copy --dest=a/sfa simple-file
copy --glob g.* copy --glob g.*
copy --exclude */w.* d1 copy --exclude **/w.* --exclude **/r d1
''' '''
copy = load_config(overrides=filter(None, conf.splitlines())).copy
self.check_bootstrap( self.check_bootstrap(
sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', sh, remote_home, test_script='env; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf, home=local_home,
ssh_opts={'copy': copy}
) )
tname = '.terminfo' tname = '.terminfo'
if os.path.exists('/usr/share/misc/terminfo.cdb'): if os.path.exists('/usr/share/misc/terminfo.cdb'):
@ -148,17 +120,18 @@ copy --exclude */w.* d1
self.ae(len(glob.glob(f'{remote_home}/{tname}/*/xterm-kitty')), 2) self.ae(len(glob.glob(f'{remote_home}/{tname}/*/xterm-kitty')), 2)
def test_ssh_env_vars(self): def test_ssh_env_vars(self):
tset = '$A-$(echo no)-`echo no2` !Q5 "something\nelse"' tset = '$A-$(echo no)-`echo no2` !Q5 "something else"'
for sh in self.all_possible_sh: for sh in self.all_possible_sh:
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir: with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
os.mkdir(os.path.join(tdir, 'cwd')) os.mkdir(os.path.join(tdir, 'cwd'))
conf = f'''
cwd $HOME/cwd
env A=AAA
env TSET={tset}
env COLORTERM
'''
pty = self.check_bootstrap( pty = self.check_bootstrap(
sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', sh, tdir, test_script='env; pwd; exit 0', SHELL_INTEGRATION_VALUE='', conf=conf
ssh_opts={'cwd': '$HOME/cwd', 'env': {
'A': 'AAA',
'TSET': tset,
'COLORTERM': DELETE_ENV_VAR,
}}
) )
pty.wait_till(lambda: 'TSET={}'.format(tset.replace('$A', 'AAA')) in pty.screen_contents()) pty.wait_till(lambda: 'TSET={}'.format(tset.replace('$A', 'AAA')) in pty.screen_contents())
self.assertNotIn('COLORTERM', pty.screen_contents()) self.assertNotIn('COLORTERM', pty.screen_contents())
@ -240,34 +213,34 @@ copy --exclude */w.* d1
self.assertEqual(pty.screen.cursor.shape, 0) self.assertEqual(pty.screen.cursor.shape, 0)
self.assertNotIn(b'\x1b]133;', pty.received_bytes) self.assertNotIn(b'\x1b]133;', pty.received_bytes)
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', ssh_opts=None, launcher='sh'): def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', test_script='', pre_data='', conf='', launcher='sh', home=''):
ssh_opts = ssh_opts or {}
if login_shell: if login_shell:
ssh_opts['login_shell'] = login_shell conf += f'\nlogin_shell {login_shell}'
if 'python' in sh: if 'python' in sh:
if test_script.startswith('env;'): if test_script.startswith('env;'):
test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})' test_script = f'os.execlp("sh", "sh", "-c", {test_script!r})'
test_script = f'print("UNTAR_DONE", flush=True); {test_script}' test_script = f'print("UNTAR_DONE", flush=True); {test_script}'
else: else:
test_script = f'echo "UNTAR_DONE"; {test_script}' test_script = f'echo "UNTAR_DONE"; {test_script}'
ssh_opts['shell_integration'] = SHELL_INTEGRATION_VALUE or 'disabled' conf += '\nshell_integration ' + (SHELL_INTEGRATION_VALUE or 'disabled')
script, replacements, shm_name = bootstrap_script( conf += '\ninterpreter ' + sh
SSHOptions(ssh_opts), script_type='py' if 'python' in sh else 'sh', request_id="testing", test_script=test_script, env = os.environ.copy()
request_data=True if home:
) env['HOME'] = home
cp = subprocess.run([kitten_exe(), '__pytest__', 'ssh', test_script], env=env, stdout=subprocess.PIPE, input=conf.encode('utf-8'))
self.assertEqual(cp.returncode, 0)
self.rdata = json.loads(cp.stdout)
del cp
try: try:
env = basic_shell_env(home_dir) env = basic_shell_env(home_dir)
# Avoid generating unneeded completion scripts # Avoid generating unneeded completion scripts
os.makedirs(os.path.join(home_dir, '.local', 'share', 'fish', 'generated_completions'), exist_ok=True) os.makedirs(os.path.join(home_dir, '.local', 'share', 'fish', 'generated_completions'), exist_ok=True)
# prevent newuser-install from running # prevent newuser-install from running
open(os.path.join(home_dir, '.zshrc'), 'w').close() open(os.path.join(home_dir, '.zshrc'), 'w').close()
cmd = wrap_bootstrap_script(script, sh) pty = self.create_pty([launcher, '-c', ' '.join(self.rdata['cmd'])], cwd=home_dir, env=env)
pty = self.create_pty([launcher, '-c', ' '.join(cmd)], cwd=home_dir, env=env)
pty.turn_off_echo() pty.turn_off_echo()
del cmd
if pre_data: if pre_data:
pty.write_buf = pre_data.encode('utf-8') pty.write_buf = pre_data.encode('utf-8')
del script
def check_untar_or_fail(): def check_untar_or_fail():
q = pty.screen_contents() q = pty.screen_contents()
@ -284,4 +257,4 @@ copy --exclude */w.* d1
return pty return pty
finally: finally:
with suppress(FileNotFoundError): with suppress(FileNotFoundError):
shm_unlink(shm_name) shm_unlink(self.rdata['shm_name'])

View File

@ -1,5 +1,5 @@
[tool.mypy] [tool.mypy]
files = 'kitty,kittens,glfw,*.py,docs/conf.py,shell-integration/ssh/askpass.py' files = 'kitty,kittens,glfw,*.py,docs/conf.py'
no_implicit_optional = true no_implicit_optional = true
sqlite_cache = true sqlite_cache = true
cache_fine_grained = true cache_fine_grained = true

View File

@ -1459,7 +1459,7 @@ def package(args: Options, bundle_type: str) -> None:
if path.endswith('.so'): if path.endswith('.so'):
return True return True
q = path.split(os.sep)[-2:] q = path.split(os.sep)[-2:]
if len(q) == 2 and q[0] == 'ssh' and q[1] in ('askpass.py', 'kitty', 'kitten'): if len(q) == 2 and q[0] == 'ssh' and q[1] in ('kitty', 'kitten'):
return True return True
return False return False

View File

@ -1,46 +0,0 @@
#!/usr/bin/env -S kitty +launch
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import json
import os
import sys
import time
from kitty.shm import SharedMemory
msg = sys.argv[-1]
prompt = os.environ.get('SSH_ASKPASS_PROMPT', '')
is_confirm = prompt == 'confirm'
is_fingerprint_check = '(yes/no/[fingerprint])' in msg
q = {
'message': msg,
'type': 'confirm' if is_confirm else 'get_line',
'is_password': not is_fingerprint_check,
}
data = json.dumps(q)
with SharedMemory(
size=len(data) + 1 + SharedMemory.num_bytes_for_size, unlink_on_exit=True, prefix=f'askpass-{os.getpid()}-') as shm, \
open(os.ctermid(), 'wb') as tty:
shm.write(b'\0')
shm.write_data_with_size(data)
shm.flush()
with open(os.ctermid(), 'wb') as f:
f.write(f'\x1bP@kitty-ask|{shm.name}\x1b\\'.encode('ascii'))
f.flush()
while True:
# TODO: Replace sleep() with a mutex and condition variable created in the shared memory
time.sleep(0.05)
shm.seek(0)
if shm.read(1) == b'\x01':
break
response = json.loads(shm.read_data_with_size())
if is_confirm:
response = 'yes' if response else 'no'
elif is_fingerprint_check:
if response.lower() in ('y', 'yes'):
response = 'yes'
if response.lower() in ('n', 'no'):
response = 'no'
if response:
print(response, flush=True)

View File

@ -24,7 +24,7 @@ exec_kitty() {
is_wrapped_kitten() { is_wrapped_kitten() {
wrapped_kittens="clipboard icat unicode_input" wrapped_kittens="clipboard icat unicode_input ssh"
[ -n "$1" ] && { [ -n "$1" ] && {
case " $wrapped_kittens " in case " $wrapped_kittens " in
*" $1 "*) printf "%s" "$1" ;; *" $1 "*) printf "%s" "$1" ;;

View File

@ -33,9 +33,11 @@ type Command struct {
ArgCompleter CompletionFunc ArgCompleter CompletionFunc
// Stop completion processing at this arg num // Stop completion processing at this arg num
StopCompletingAtArg int StopCompletingAtArg int
// Consider all args as non-options args // Consider all args as non-options args when parsing for completion
OnlyArgsAllowed bool OnlyArgsAllowed bool
// Specialised arg aprsing // Pass through all args, useful for wrapper commands
IgnoreAllArgs bool
// Specialised arg parsing
ParseArgsForCompletion func(cmd *Command, args []string, completions *Completions) ParseArgsForCompletion func(cmd *Command, args []string, completions *Completions)
SubCommandGroups []*CommandGroup SubCommandGroups []*CommandGroup

View File

@ -13,6 +13,10 @@ func (self *Command) parse_args(ctx *Context, args []string) error {
args_to_parse := make([]string, len(args)) args_to_parse := make([]string, len(args))
copy(args_to_parse, args) copy(args_to_parse, args)
ctx.SeenCommands = append(ctx.SeenCommands, self) ctx.SeenCommands = append(ctx.SeenCommands, self)
if self.IgnoreAllArgs {
self.Args = args
return nil
}
var expecting_arg_for *Option var expecting_arg_for *Option
options_allowed := true options_allowed := true

View File

@ -66,8 +66,8 @@ func complete_plus_open(completions *cli.Completions, word string, arg_num int)
} }
func complete_themes(completions *cli.Completions, word string, arg_num int) { func complete_themes(completions *cli.Completions, word string, arg_num int) {
kitty, err := utils.KittyExe() kitty := utils.KittyExe()
if err == nil { if kitty != "" {
out, err := exec.Command(kitty, "+runpy", "from kittens.themes.collection import *; print_theme_names()").Output() out, err := exec.Command(kitty, "+runpy", "from kittens.themes.collection import *; print_theme_names()").Output()
if err == nil { if err == nil {
mg := completions.AddMatchGroup("Themes") mg := completions.AddMatchGroup("Themes")

View File

@ -14,7 +14,6 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"kitty/tools/tui/graphics" "kitty/tools/tui/graphics"
"kitty/tools/utils" "kitty/tools/utils"
@ -24,12 +23,13 @@ import (
var _ = fmt.Print var _ = fmt.Print
var find_exe_lock sync.Once var MagickExe = (&utils.Once[string]{Run: func() string {
var magick_exe string = "" ans := utils.Which("magick")
if ans == "" {
func find_magick_exe() { ans = utils.Which("magick", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin")
magick_exe = utils.Which("magick") }
} return ans
}}).Get
func run_magick(path string, cmd []string) ([]byte, error) { func run_magick(path string, cmd []string) ([]byte, error) {
c := exec.Command(cmd[0], cmd[1:]...) c := exec.Command(cmd[0], cmd[1:]...)
@ -154,10 +154,9 @@ func parse_identify_record(ans *IdentifyRecord, raw *IdentifyOutput) (err error)
} }
func Identify(path string) (ans []IdentifyRecord, err error) { func Identify(path string) (ans []IdentifyRecord, err error) {
find_exe_lock.Do(find_magick_exe)
cmd := []string{"identify"} cmd := []string{"identify"}
if magick_exe != "" { if MagickExe() != "" {
cmd = []string{magick_exe, cmd[0]} cmd = []string{MagickExe(), cmd[0]}
} }
q := `{"fmt":"%m","canvas":"%g","transparency":"%A","gap":"%T","index":"%p","size":"%wx%h",` + q := `{"fmt":"%m","canvas":"%g","transparency":"%A","gap":"%T","index":"%p","size":"%wx%h",` +
`"dpi":"%xx%y","dispose":"%D","orientation":"%[EXIF:Orientation]"},` `"dpi":"%xx%y","dispose":"%D","orientation":"%[EXIF:Orientation]"},`
@ -227,10 +226,9 @@ func check_resize(frame *image_frame) error {
} }
func Render(path string, ro *RenderOptions, frames []IdentifyRecord) (ans []*image_frame, err error) { func Render(path string, ro *RenderOptions, frames []IdentifyRecord) (ans []*image_frame, err error) {
find_exe_lock.Do(find_magick_exe)
cmd := []string{"convert"} cmd := []string{"convert"}
if magick_exe != "" { if MagickExe() != "" {
cmd = []string{magick_exe, cmd[0]} cmd = []string{MagickExe(), cmd[0]}
} }
ans = make([]*image_frame, 0, len(frames)) ans = make([]*image_frame, 0, len(frames))
defer func() { defer func() {

View File

@ -3,12 +3,22 @@
package main package main
import ( import (
"os"
"kitty/tools/cli" "kitty/tools/cli"
"kitty/tools/cmd/completion" "kitty/tools/cmd/completion"
"kitty/tools/cmd/ssh"
"kitty/tools/cmd/tool" "kitty/tools/cmd/tool"
) )
func main() { func main() {
krm := os.Getenv("KITTY_KITTEN_RUN_MODULE")
os.Unsetenv("KITTY_KITTEN_RUN_MODULE")
switch krm {
case "ssh_askpass":
ssh.RunSSHAskpass()
return
}
root := cli.NewRootCommand() root := cli.NewRootCommand()
root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)" root.ShortDescription = "Fast, statically compiled implementations for various kittens (command line tools for use with kitty)"
root.Usage = "command [command options] [command args]" root.Usage = "command [command options] [command args]"

22
tools/cmd/pytest/main.go Normal file
View File

@ -0,0 +1,22 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package pytest
import (
"fmt"
"kitty/tools/cli"
"kitty/tools/cmd/ssh"
"kitty/tools/utils/shm"
)
var _ = fmt.Print
func EntryPoint(root *cli.Command) {
root = root.AddSubCommand(&cli.Command{
Name: "__pytest__",
Hidden: true,
})
shm.TestEntryPoint(root)
ssh.TestEntryPoint(root)
}

118
tools/cmd/ssh/askpass.go Normal file
View File

@ -0,0 +1,118 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"encoding/binary"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"kitty/tools/cli"
"kitty/tools/tty"
"kitty/tools/utils/shm"
)
var _ = fmt.Print
func fatal(err error) {
cli.ShowError(err)
os.Exit(1)
}
func trigger_ask(name string) {
term, err := tty.OpenControllingTerm()
if err != nil {
fatal(err)
}
defer term.Close()
_, err = term.WriteString("\x1bP@kitty-ask|" + name + "\x1b\\")
if err != nil {
fatal(err)
}
}
func RunSSHAskpass() {
msg := os.Args[len(os.Args)-1]
prompt := os.Getenv("SSH_ASKPASS_PROMPT")
is_confirm := prompt == "confirm"
q_type := "get_line"
if is_confirm {
q_type = "confirm"
}
is_fingerprint_check := strings.Contains(msg, "(yes/no/[fingerprint])")
q := map[string]any{
"message": msg,
"type": q_type,
"is_password": !is_fingerprint_check,
}
data, err := json.Marshal(q)
if err != nil {
fatal(err)
}
shm, err := shm.CreateTemp("askpass-*", uint64(len(data)+32))
if err != nil {
fatal(fmt.Errorf("Failed to create SHM file with error: %w", err))
}
defer shm.Close()
defer shm.Unlink()
shm.Slice()[0] = 0
binary.BigEndian.PutUint32(shm.Slice()[1:], uint32(len(data)))
copy(shm.Slice()[5:], data)
err = shm.Flush()
if err != nil {
fatal(fmt.Errorf("Failed to flush SHM file with error: %w", err))
}
trigger_ask(shm.Name())
buf := []byte{0}
for {
time.Sleep(50 * time.Millisecond)
_, err = shm.Seek(0, os.SEEK_SET)
if err != nil {
fatal(fmt.Errorf("Failed to seek into SHM file while waiting for response with error: %w", err))
}
_, err = shm.Read(buf)
if err != nil {
fatal(fmt.Errorf("Failed to read from SHM file while waiting for response with error: %w", err))
}
if buf[0] == 1 {
break
}
}
data, err = shm.ReadWithSize()
if err != nil {
fatal(fmt.Errorf("Failed to read response data from SHM file with error: %w", err))
}
response := ""
if is_confirm {
var ok bool
err = json.Unmarshal(data, &ok)
if err != nil {
fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err))
}
response = "no"
if ok {
response = "yes"
}
} else {
err = json.Unmarshal(data, &response)
if err != nil {
fatal(fmt.Errorf("Failed to parse response data: %#v with error: %w", string(data), err))
}
if is_fingerprint_check {
response = strings.ToLower(response)
if response == "y" {
response = "yes"
} else if response == "n" {
response = "no"
}
}
}
if response != "" {
fmt.Println(response)
}
}

409
tools/cmd/ssh/config.go Normal file
View File

@ -0,0 +1,409 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"archive/tar"
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"time"
"kitty/tools/config"
"kitty/tools/utils"
"kitty/tools/utils/paths"
"kitty/tools/utils/shlex"
"github.com/bmatcuk/doublestar"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
type EnvInstruction struct {
key, val string
delete_on_remote, copy_from_local, literal_quote bool
}
func quote_for_sh(val string, literal_quote bool) string {
if literal_quote {
return utils.QuoteStringForSH(val)
}
// See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
b := strings.Builder{}
b.Grow(len(val) + 16)
b.WriteRune('"')
runes := []rune(val)
for i, ch := range runes {
if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') {
// special chars are escaped
// $( is escaped to prevent execution
b.WriteRune('\\')
}
b.WriteRune(ch)
}
b.WriteRune('"')
return b.String()
}
func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string {
var unset func() string
var export func(string) string
if for_python {
dumps := func(x ...any) string {
ans, _ := json.Marshal(x)
return utils.UnsafeBytesToString(ans)
}
export = func(val string) string {
if val == "" {
return fmt.Sprintf("export %s", dumps(self.key))
}
return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote))
}
unset = func() string {
return fmt.Sprintf("unset %s", dumps(self.key))
}
} else {
kq := utils.QuoteStringForSH(self.key)
unset = func() string {
return fmt.Sprintf("unset %s", kq)
}
export = func(val string) string {
return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote))
}
}
if self.delete_on_remote {
return unset()
}
if self.copy_from_local {
val, found := get_local_env(self.key)
if !found {
return ""
}
return export(val)
}
return export(self.val)
}
func final_env_instructions(for_python bool, get_local_env func(string) (string, bool), env ...*EnvInstruction) string {
seen := make(map[string]int, len(env))
ans := make([]string, 0, len(env))
for _, ei := range env {
q := ei.Serialize(for_python, get_local_env)
if q != "" {
if pos, found := seen[ei.key]; found {
ans[pos] = q
} else {
seen[ei.key] = len(ans)
ans = append(ans, q)
}
}
}
return strings.Join(ans, "\n")
}
type CopyInstruction struct {
local_path, arcname string
exclude_patterns []string
}
func ParseEnvInstruction(spec string) (ans []*EnvInstruction, err error) {
const COPY_FROM_LOCAL string = "_kitty_copy_env_var_"
ei := &EnvInstruction{}
found := false
ei.key, ei.val, found = strings.Cut(spec, "=")
ei.key = strings.TrimSpace(ei.key)
if found {
ei.val = strings.TrimSpace(ei.val)
if ei.val == COPY_FROM_LOCAL {
ei.val = ""
ei.copy_from_local = true
}
} else {
ei.delete_on_remote = true
}
if ei.key == "" {
err = fmt.Errorf("The env directive must not be empty")
}
ans = []*EnvInstruction{ei}
return
}
var paths_ctx *paths.Ctx
func resolve_file_spec(spec string, is_glob bool) ([]string, error) {
if paths_ctx == nil {
paths_ctx = &paths.Ctx{}
}
ans := os.ExpandEnv(paths_ctx.ExpandHome(spec))
if !filepath.IsAbs(ans) {
ans = paths_ctx.AbspathFromHome(ans)
}
if is_glob {
files, err := doublestar.Glob(ans)
if err != nil {
return nil, fmt.Errorf("%s is not a valid glob pattern with error: %w", spec, err)
}
if len(files) == 0 {
return nil, fmt.Errorf("%s matches no files", spec)
}
return files, nil
}
err := unix.Access(ans, unix.R_OK)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("%s does not exist", spec)
}
return nil, fmt.Errorf("Cannot read from: %s with error: %w", spec, err)
}
return []string{ans}, nil
}
func get_arcname(loc, dest, home string) (arcname string) {
if dest != "" {
arcname = dest
} else {
arcname = filepath.Clean(loc)
if filepath.HasPrefix(arcname, home) {
ra, err := filepath.Rel(home, arcname)
if err == nil {
arcname = ra
}
}
}
prefix := "home/"
if strings.HasPrefix(arcname, "/") {
prefix = "root"
}
return prefix + arcname
}
func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) {
args, err := shlex.Split("copy " + spec)
if err != nil {
return nil, err
}
opts, args, err := parse_copy_args(args)
if err != nil {
return nil, err
}
locations := make([]string, 0, len(args))
for _, arg := range args {
locs, err := resolve_file_spec(arg, opts.Glob)
if err != nil {
return nil, err
}
locations = append(locations, locs...)
}
if len(locations) == 0 {
return nil, fmt.Errorf("No files to copy specified")
}
if len(locations) > 1 && opts.Dest != "" {
return nil, fmt.Errorf("Specifying a remote location with more than one file is not supported")
}
home := paths_ctx.HomePath()
ans = make([]*CopyInstruction, 0, len(locations))
for _, loc := range locations {
ci := CopyInstruction{local_path: loc, exclude_patterns: opts.Exclude}
if opts.SymlinkStrategy != "preserve" {
ci.local_path, err = filepath.EvalSymlinks(loc)
if err != nil {
return nil, fmt.Errorf("Failed to resolve symlinks in %#v with error: %w", loc, err)
}
}
if opts.SymlinkStrategy == "resolve" {
ci.arcname = get_arcname(ci.local_path, opts.Dest, home)
} else {
ci.arcname = get_arcname(loc, opts.Dest, home)
}
ans = append(ans, &ci)
}
return
}
type file_unique_id struct {
dev, inode uint64
}
func excluded(pattern, path string) bool {
if !strings.ContainsRune(pattern, '/') {
path = filepath.Base(path)
}
if matched, err := doublestar.PathMatch(pattern, path); matched && err == nil {
return true
}
return false
}
func get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string, local_path, arcname string, exclude_patterns []string) error {
s, err := os.Lstat(local_path)
if err != nil {
return err
}
u, ok := s.Sys().(unix.Stat_t)
cb := func(h *tar.Header, data []byte, arcname string) error {
h.Name = arcname
if h.Typeflag == tar.TypeDir {
h.Name = strings.TrimRight(h.Name, "/") + "/"
}
h.Size = int64(len(data))
h.Mode = int64(s.Mode().Perm())
h.ModTime = s.ModTime()
h.Format = tar.FormatPAX
if ok {
h.AccessTime = time.Unix(0, u.Atim.Nano())
h.ChangeTime = time.Unix(0, u.Ctim.Nano())
}
return callback(h, data)
}
// we only copy regular files, directories and symlinks
switch s.Mode().Type() {
case fs.ModeSymlink:
target, err := os.Readlink(local_path)
if err != nil {
return err
}
err = cb(&tar.Header{
Typeflag: tar.TypeSymlink,
Linkname: target,
}, nil, arcname)
if err != nil {
return err
}
case fs.ModeDir:
local_path = filepath.Clean(local_path)
type entry struct {
path, arcname string
}
stack := []entry{{local_path, arcname}}
for len(stack) > 0 {
x := stack[0]
stack = stack[1:]
entries, err := os.ReadDir(x.path)
if err != nil {
if x.path == local_path {
return err
}
continue
}
err = cb(&tar.Header{Typeflag: tar.TypeDir}, nil, x.arcname)
if err != nil {
return err
}
for _, e := range entries {
entry_path := filepath.Join(x.path, e.Name())
aname := path.Join(x.arcname, e.Name())
ok := true
for _, pat := range exclude_patterns {
if excluded(pat, entry_path) {
ok = false
break
}
}
if !ok {
continue
}
if e.IsDir() {
stack = append(stack, entry{entry_path, aname})
} else {
err = get_file_data(callback, seen, entry_path, aname, exclude_patterns)
if err != nil {
return err
}
}
}
}
case 0: // Regular file
fid := file_unique_id{dev: u.Dev, inode: u.Ino}
if prev, ok := seen[fid]; ok { // Hard link
err = cb(&tar.Header{Typeflag: tar.TypeLink, Linkname: prev}, nil, arcname)
if err != nil {
return err
}
}
seen[fid] = arcname
data, err := os.ReadFile(local_path)
if err != nil {
return err
}
err = cb(&tar.Header{Typeflag: tar.TypeReg}, data, arcname)
if err != nil {
return err
}
}
return nil
}
func (ci *CopyInstruction) get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string) (err error) {
ep := ci.exclude_patterns
for _, folder_name := range []string{"__pycache__", ".DS_Store"} {
ep = append(ep, "**/"+folder_name, "**/"+folder_name+"/**")
}
return get_file_data(callback, seen, ci.local_path, ci.arcname, ep)
}
type ConfigSet struct {
all_configs []*Config
}
func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config {
matcher := func(q *Config) bool {
for _, pat := range strings.Split(q.Hostname, " ") {
upat := "*"
if strings.Contains(pat, "@") {
upat, pat, _ = strings.Cut(pat, "@")
}
var host_matched, user_matched bool
if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil {
host_matched = true
}
if matched, err := filepath.Match(upat, username_to_match); matched && err == nil {
user_matched = true
}
if host_matched && user_matched {
return true
}
}
return false
}
for _, c := range utils.Reversed(cs.all_configs) {
if matcher(c) {
return c
}
}
return cs.all_configs[0]
}
func (self *ConfigSet) line_handler(key, val string) error {
c := self.all_configs[len(self.all_configs)-1]
if key == "hostname" {
c = NewConfig()
self.all_configs = append(self.all_configs, c)
}
return c.Parse(key, val)
}
func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, []config.ConfigLine, error) {
ans := &ConfigSet{all_configs: []*Config{NewConfig()}}
p := config.ConfigParser{LineHandler: ans.line_handler}
if len(paths) == 0 {
paths = []string{filepath.Join(utils.ConfigDir(), "ssh.conf")}
}
paths = utils.Filter(paths, func(x string) bool { return x != "" })
err := p.ParseFiles(paths...)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, nil, err
}
if len(overrides) > 0 {
err = p.ParseOverrides(overrides...)
if err != nil {
return nil, nil, err
}
}
return config_for_hostname(hostname_to_match, username_to_match, ans), p.BadLines(), nil
}

View File

@ -0,0 +1,110 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"kitty/tools/utils"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestSSHConfigParsing(t *testing.T) {
tdir := t.TempDir()
hostname := "unmatched"
username := ""
conf := ""
for_python := false
cf := filepath.Join(tdir, "ssh.conf")
rt := func(expected_env ...string) {
os.WriteFile(cf, []byte(conf), 0o600)
c, bad_lines, err := load_config(hostname, username, nil, cf)
if err != nil {
t.Fatal(err)
}
if len(bad_lines) != 0 {
t.Fatalf("Bad config line: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err)
}
actual := final_env_instructions(for_python, func(key string) (string, bool) {
if key == "LOCAL_ENV" {
return "LOCAL_VAL", true
}
return "", false
}, c.Env...)
if expected_env == nil {
expected_env = []string{}
}
diff := cmp.Diff(expected_env, utils.Splitlines(actual))
if diff != "" {
t.Fatalf("Unexpected env for\nhostname: %#v\nusername: %#v\nconf: %s\n%s", hostname, username, conf, diff)
}
}
rt()
conf = "env a=b"
rt(`export 'a'="b"`)
conf = "env a=b\nhostname 2\nenv a=c\nenv b=b"
rt(`export 'a'="b"`)
hostname = "2"
rt(`export 'a'="c"`, `export 'b'="b"`)
conf = "env a="
rt(`export 'a'=""`)
conf = "env a"
rt(`unset 'a'`)
conf = "env a=b\nhostname test@2\nenv a=c\nenv b=b"
hostname = "unmatched"
rt(`export 'a'="b"`)
hostname = "2"
rt(`export 'a'="b"`)
username = "test"
rt(`export 'a'="c"`, `export 'b'="b"`)
conf = "env a=b\nhostname 1 2\nenv a=c\nenv b=b"
username = ""
hostname = "unmatched"
rt(`export 'a'="b"`)
hostname = "1"
rt(`export 'a'="c"`, `export 'b'="b"`)
hostname = "2"
rt(`export 'a'="c"`, `export 'b'="b"`)
for_python = true
rt(`export ["a","c",false]`, `export ["b","b",false]`)
conf = "env a="
rt(`export ["a"]`)
conf = "env a"
rt(`unset ["a"]`)
conf = "env LOCAL_ENV=_kitty_copy_env_var_"
rt(`export ["LOCAL_ENV","LOCAL_VAL",false]`)
ci, err := ParseCopyInstruction("--exclude moose --dest=target " + cf)
if err != nil {
t.Fatal(err)
}
diff := cmp.Diff("home/target", ci[0].arcname)
if diff != "" {
t.Fatalf("Incorrect arcname:\n%s", diff)
}
diff = cmp.Diff(cf, ci[0].local_path)
if diff != "" {
t.Fatalf("Incorrect local_path:\n%s", diff)
}
diff = cmp.Diff([]string{"moose"}, ci[0].exclude_patterns)
if diff != "" {
t.Fatalf("Incorrect excludes:\n%s", diff)
}
ci, err = ParseCopyInstruction("--glob " + filepath.Join(filepath.Dir(cf), "*.conf"))
if err != nil {
t.Fatal(err)
}
diff = cmp.Diff(cf, ci[0].local_path)
if diff != "" {
t.Fatalf("Incorrect local_path:\n%s", diff)
}
if len(ci) != 1 {
t.Fatal(ci)
}
}

69
tools/cmd/ssh/data.go Normal file
View File

@ -0,0 +1,69 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"archive/tar"
_ "embed"
"errors"
"fmt"
"io"
"kitty/tools/utils"
"regexp"
"strings"
)
var _ = fmt.Print
//go:embed data_generated.bin
var embedded_data string
type Entry struct {
metadata *tar.Header
data []byte
}
type Container map[string]Entry
var Data = (&utils.Once[Container]{Run: func() Container {
tr := tar.NewReader(utils.ReaderForCompressedEmbeddedData(embedded_data))
ans := make(Container, 64)
for {
hdr, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
panic(err)
}
data, err := utils.ReadAll(tr, int(hdr.Size))
if err != nil {
panic(err)
}
ans[hdr.Name] = Entry{hdr, data}
}
return ans
}}).Get
func (self Container) files_matching(prefix string, exclude_patterns ...string) []string {
ans := make([]string, 0, len(self))
patterns := make([]*regexp.Regexp, len(exclude_patterns))
for i, exp := range exclude_patterns {
patterns[i] = regexp.MustCompile(exp)
}
for name := range self {
if strings.HasPrefix(name, prefix) {
excluded := false
for _, pat := range patterns {
if matched := pat.FindString(name); matched != "" {
excluded = true
break
}
}
if !excluded {
ans = append(ans, name)
}
}
}
return ans
}

801
tools/cmd/ssh/main.go Normal file
View File

@ -0,0 +1,801 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"archive/tar"
"bytes"
"compress/gzip"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"kitty"
"net/url"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"kitty/tools/cli"
"kitty/tools/themes"
"kitty/tools/tty"
"kitty/tools/tui"
"kitty/tools/tui/loop"
"kitty/tools/utils"
"kitty/tools/utils/secrets"
"kitty/tools/utils/shm"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
func get_destination(hostname string) (username, hostname_for_match string) {
u, err := user.Current()
if err == nil {
username = u.Username
}
hostname_for_match = hostname
if strings.HasPrefix(hostname, "ssh://") {
p, err := url.Parse(hostname)
if err == nil {
hostname_for_match = p.Hostname()
if p.User.Username() != "" {
username = p.User.Username()
}
}
} else if strings.Contains(hostname, "@") && hostname[0] != '@' {
username, hostname_for_match, _ = strings.Cut(hostname, "@")
}
if strings.Contains(hostname, "@") && hostname[0] != '@' {
_, hostname_for_match, _ = strings.Cut(hostname_for_match, "@")
}
hostname_for_match, _, _ = strings.Cut(hostname_for_match, ":")
return
}
func read_data_from_shared_memory(shm_name string) ([]byte, error) {
data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error {
s, err := f.Stat()
if err != nil {
return fmt.Errorf("Failed to stat SHM file with error: %w", err)
}
if stat, ok := s.Sys().(unix.Stat_t); ok {
if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) {
return fmt.Errorf("Incorrect owner on SHM file")
}
}
if s.Mode().Perm() != 0o600 {
return fmt.Errorf("Incorrect permissions on SHM file")
}
return nil
})
return data, err
}
func add_cloned_env(val string) (ans map[string]string, err error) {
data, err := read_data_from_shared_memory(val)
if err != nil {
return nil, err
}
err = json.Unmarshal(data, &ans)
return ans, err
}
func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string, ferr error) {
literal_env = make(map[string]string)
overrides = make([]string, 0, 4)
for i, a := range found_extra_args {
if i%2 == 0 {
continue
}
if key, val, found := strings.Cut(a, "="); found {
if key == "clone_env" {
le, err := add_cloned_env(val)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return nil, nil, ferr
}
} else if le != nil {
literal_env = le
}
} else if key != "hostname" {
overrides = append(overrides, key+" "+val)
}
}
}
if len(overrides) > 0 {
overrides = append([]string{"hostname " + username + "@" + hostname_for_match}, overrides...)
}
return
}
func connection_sharing_args(kitty_pid int) ([]string, error) {
rd := utils.RuntimeDir()
// Bloody OpenSSH generates a 40 char hash and in creating the socket
// appends a 27 char temp suffix to it. Socket max path length is approx
// ~104 chars. And on idiotic Apple the path length to the runtime dir
// (technically the cache dir since Apple has no runtime dir and thinks it's
// a great idea to delete files in /tmp) is ~48 chars.
if len(rd) > 35 {
idiotic_design := fmt.Sprintf("/tmp/kssh-rdir-%d", os.Geteuid())
if err := utils.AtomicCreateSymlink(rd, idiotic_design); err != nil {
return nil, err
}
rd = idiotic_design
}
cp := strings.Replace(kitty.SSHControlMasterTemplate, "{kitty_pid}", strconv.Itoa(kitty_pid), 1)
cp = strings.Replace(cp, "{ssh_placeholder}", "%C", 1)
return []string{
"-o", "ControlMaster=auto",
"-o", "ControlPath=" + filepath.Join(rd, cp),
"-o", "ControlPersist=yes",
"-o", "ServerAliveInterval=60",
"-o", "ServerAliveCountMax=5",
"-o", "TCPKeepAlive=no",
}, nil
}
func set_askpass() (need_to_request_data bool) {
need_to_request_data = true
sentinel := filepath.Join(utils.CacheDir(), "openssh-is-new-enough-for-askpass")
_, err := os.Stat(sentinel)
sentinel_exists := err == nil
if sentinel_exists || GetSSHVersion().SupportsAskpassRequire() {
if !sentinel_exists {
os.WriteFile(sentinel, []byte{0}, 0o644)
}
need_to_request_data = false
}
exe, err := os.Executable()
if err == nil {
os.Setenv("SSH_ASKPASS", exe)
os.Setenv("KITTY_KITTEN_RUN_MODULE", "ssh_askpass")
if !need_to_request_data {
os.Setenv("SSH_ASKPASS_REQUIRE", "force")
}
} else {
need_to_request_data = true
}
return
}
type connection_data struct {
remote_args []string
host_opts *Config
hostname_for_match string
username string
echo_on bool
request_data bool
literal_env map[string]string
test_script string
shm_name string
script_type string
rcmd []string
replacements map[string]string
request_id string
bootstrap_script string
}
func get_effective_ksi_env_var(x string) string {
parts := strings.Split(strings.TrimSpace(strings.ToLower(x)), " ")
current := utils.NewSetWithItems(parts...)
if current.Has("disabled") {
return ""
}
allowed := utils.NewSetWithItems(kitty.AllowedShellIntegrationValues...)
if !current.IsSubsetOf(allowed) {
return RelevantKittyOpts().Shell_integration
}
return x
}
func serialize_env(cd *connection_data, get_local_env func(string) (string, bool)) (string, string) {
ksi := ""
if cd.host_opts.Shell_integration == "inherited" {
ksi = get_effective_ksi_env_var(RelevantKittyOpts().Shell_integration)
} else {
ksi = get_effective_ksi_env_var(cd.host_opts.Shell_integration)
}
env := make([]*EnvInstruction, 0, 8)
add_env := func(key, val string, fallback ...string) *EnvInstruction {
if val == "" && len(fallback) > 0 {
val = fallback[0]
}
if val != "" {
env = append(env, &EnvInstruction{key: key, val: val, literal_quote: true})
return env[len(env)-1]
}
return nil
}
add_non_literal_env := func(key, val string, fallback ...string) *EnvInstruction {
ans := add_env(key, val, fallback...)
if ans != nil {
ans.literal_quote = false
}
return ans
}
for k, v := range cd.literal_env {
add_env(k, v)
}
add_env("TERM", os.Getenv("TERM"), RelevantKittyOpts().Term)
add_env("COLORTERM", "truecolor")
env = append(env, cd.host_opts.Env...)
add_env("KITTY_WINDOW_ID", os.Getenv("KITTY_WINDOW_ID"))
add_env("WINDOWID", os.Getenv("WINDOWID"))
if ksi != "" {
add_env("KITTY_SHELL_INTEGRATION", ksi)
} else {
env = append(env, &EnvInstruction{key: "KITTY_SHELL_INTEGRATION", delete_on_remote: true})
}
add_non_literal_env("KITTY_SSH_KITTEN_DATA_DIR", cd.host_opts.Remote_dir)
add_non_literal_env("KITTY_LOGIN_SHELL", cd.host_opts.Login_shell)
add_non_literal_env("KITTY_LOGIN_CWD", cd.host_opts.Cwd)
if cd.host_opts.Remote_kitty != Remote_kitty_no {
add_env("KITTY_REMOTE", cd.host_opts.Remote_kitty.String())
}
add_env("KITTY_PUBLIC_KEY", os.Getenv("KITTY_PUBLIC_KEY"))
return final_env_instructions(cd.script_type == "py", get_local_env, env...), ksi
}
func make_tarfile(cd *connection_data, get_local_env func(string) (string, bool)) ([]byte, error) {
env_script, ksi := serialize_env(cd, get_local_env)
w := bytes.Buffer{}
w.Grow(64 * 1024)
gw, err := gzip.NewWriterLevel(&w, gzip.BestCompression)
if err != nil {
return nil, err
}
tw := tar.NewWriter(gw)
rd := strings.TrimRight(cd.host_opts.Remote_dir, "/")
seen := make(map[file_unique_id]string, 32)
add := func(h *tar.Header, data []byte) (err error) {
// some distro's like nix mess with installed file permissions so ensure
// files are at least readable and writable by owning user
h.Mode |= 0o600
err = tw.WriteHeader(h)
if err != nil {
return
}
if data != nil {
_, err := tw.Write(data)
if err != nil {
return err
}
}
return
}
for _, ci := range cd.host_opts.Copy {
err = ci.get_file_data(add, seen)
if err != nil {
return nil, err
}
}
type fe struct {
arcname string
data []byte
}
now := time.Now()
add_data := func(items ...fe) error {
for _, item := range items {
err := add(
&tar.Header{
Typeflag: tar.TypeReg, Name: item.arcname, Format: tar.FormatPAX, Size: int64(len(item.data)),
Mode: 0o644, ModTime: now, ChangeTime: now, AccessTime: now,
}, item.data)
if err != nil {
return err
}
}
return nil
}
add_entries := func(prefix string, items ...Entry) error {
for _, item := range items {
err := add(
&tar.Header{
Typeflag: item.metadata.Typeflag, Name: path.Join(prefix, path.Base(item.metadata.Name)), Format: tar.FormatPAX,
Size: int64(len(item.data)), Mode: item.metadata.Mode, ModTime: item.metadata.ModTime,
AccessTime: item.metadata.AccessTime, ChangeTime: item.metadata.ChangeTime,
}, item.data)
if err != nil {
return err
}
}
return nil
}
add_data(fe{"data.sh", utils.UnsafeStringToBytes(env_script)})
if cd.script_type == "sh" {
add_data(fe{"bootstrap-utils.sh", Data()[path.Join("shell-integration/ssh/bootstrap-utils.sh")].data})
}
if ksi != "" {
for _, fname := range Data().files_matching(
"shell-integration/",
"shell-integration/ssh/.+", // bootstrap files are sent as command line args
"shell-integration/zsh/kitty.zsh", // backward compat file not needed by ssh kitten
) {
arcname := path.Join("home/", rd, "/", path.Dir(fname))
err = add_entries(arcname, Data()[fname])
if err != nil {
return nil, err
}
}
}
if cd.host_opts.Remote_kitty != Remote_kitty_no {
arcname := path.Join("home/", rd, "/kitty")
err = add_data(fe{arcname + "/version", utils.UnsafeStringToBytes(kitty.VersionString)})
if err != nil {
return nil, err
}
for _, x := range []string{"kitty", "kitten"} {
err = add_entries(path.Join(arcname, "bin"), Data()[path.Join("shell-integration", "ssh", x)])
if err != nil {
return nil, err
}
}
}
err = add_entries(path.Join("home", ".terminfo"), Data()["terminfo/kitty.terminfo"])
if err == nil {
err = add_entries(path.Join("home", ".terminfo", "x"), Data()["terminfo/x/xterm-kitty"])
}
if err == nil {
err = tw.Close()
if err == nil {
err = gw.Close()
}
}
return w.Bytes(), err
}
func prepare_home_command(cd *connection_data) string {
is_python := cd.script_type == "py"
homevar := ""
for _, ei := range cd.host_opts.Env {
if ei.key == "HOME" && !ei.delete_on_remote {
if ei.copy_from_local {
homevar = os.Getenv("HOME")
} else {
homevar = ei.val
}
}
}
export_home_cmd := ""
if homevar != "" {
if is_python {
export_home_cmd = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(homevar))
} else {
export_home_cmd = fmt.Sprintf("export HOME=%s; cd \"$HOME\"", utils.QuoteStringForSH(homevar))
}
}
return export_home_cmd
}
func prepare_exec_cmd(cd *connection_data) string {
// ssh simply concatenates multiple commands using a space see
// line 1129 of ssh.c and on the remote side sshd.c runs the
// concatenated command as shell -c cmd
if cd.script_type == "py" {
return base64.RawStdEncoding.EncodeToString(utils.UnsafeStringToBytes(strings.Join(cd.remote_args, " ")))
}
args := make([]string, len(cd.remote_args))
for i, arg := range cd.remote_args {
args[i] = strings.ReplaceAll(arg, "'", "'\"'\"'")
}
return "unset KITTY_SHELL_INTEGRATION; exec \"$login_shell\" -c '" + strings.Join(args, " ") + "'"
}
var data_shm shm.MMap
func prepare_script(script string, replacements map[string]string) string {
if _, found := replacements["EXEC_CMD"]; !found {
replacements["EXEC_CMD"] = ""
}
if _, found := replacements["EXPORT_HOME_CMD"]; !found {
replacements["EXPORT_HOME_CMD"] = ""
}
keys := maps.Keys(replacements)
for i, key := range keys {
keys[i] = "\\b" + key + "\\b"
}
pat := regexp.MustCompile(strings.Join(keys, "|"))
return pat.ReplaceAllStringFunc(script, func(key string) string { return replacements[key] })
}
func bootstrap_script(cd *connection_data) (err error) {
if cd.request_id == "" {
cd.request_id = os.Getenv("KITTY_PID") + "-" + os.Getenv("KITTY_WINDOW_ID")
}
export_home_cmd := prepare_home_command(cd)
exec_cmd := ""
if len(cd.remote_args) > 0 {
exec_cmd = prepare_exec_cmd(cd)
}
pw, err := secrets.TokenHex()
if err != nil {
return err
}
tfd, err := make_tarfile(cd, os.LookupEnv)
if err != nil {
return err
}
data := map[string]string{
"tarfile": base64.StdEncoding.EncodeToString(tfd),
"pw": pw,
"hostname": cd.hostname_for_match, "username": cd.username,
}
encoded_data, err := json.Marshal(data)
if err == nil {
data_shm, err = shm.CreateTemp(fmt.Sprintf("kssh-%d-", os.Getpid()), uint64(len(encoded_data)+8))
if err == nil {
err = data_shm.WriteWithSize(encoded_data)
if err == nil {
err = data_shm.Flush()
}
}
}
if err != nil {
return err
}
cd.shm_name = data_shm.Name()
sensitive_data := map[string]string{"REQUEST_ID": cd.request_id, "DATA_PASSWORD": pw, "PASSWORD_FILENAME": cd.shm_name}
replacements := map[string]string{
"EXPORT_HOME_CMD": export_home_cmd,
"EXEC_CMD": exec_cmd,
"TEST_SCRIPT": cd.test_script,
}
add_bool := func(ok bool, key string) {
if ok {
replacements[key] = "1"
} else {
replacements[key] = "0"
}
}
add_bool(cd.request_data, "REQUEST_DATA")
add_bool(cd.echo_on, "ECHO_ON")
sd := maps.Clone(replacements)
if cd.request_data {
maps.Copy(sd, sensitive_data)
}
maps.Copy(replacements, sensitive_data)
cd.replacements = replacements
cd.bootstrap_script = utils.UnsafeBytesToString(Data()["shell-integration/ssh/bootstrap."+cd.script_type].data)
cd.bootstrap_script = prepare_script(cd.bootstrap_script, sd)
return err
}
func wrap_bootstrap_script(cd *connection_data) {
// sshd will execute the command we pass it by join all command line
// arguments with a space and passing it as a single argument to the users
// login shell with -c. If the user has a non POSIX login shell it might
// have different escaping semantics and syntax, so the command it should
// execute has to be as simple as possible, basically of the form
// interpreter -c unwrap_script escaped_bootstrap_script
// The unwrap_script is responsible for unescaping the bootstrap script and
// executing it.
encoded_script := ""
unwrap_script := ""
if cd.script_type == "py" {
encoded_script = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(cd.bootstrap_script))
unwrap_script = `"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"`
} else {
// We cant rely on base64 being available on the remote system, so instead
// we quote the bootstrap script by replacing ' and \ with \v and \f
// also replacing \n and ! with \r and \b for tcsh
// finally surrounding with '
encoded_script = "'" + strings.NewReplacer("'", "\v", "\\", "\f", "\n", "\r", "!", "\b").Replace(cd.bootstrap_script) + "'"
unwrap_script = `'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' `
}
cd.rcmd = []string{"exec", cd.host_opts.Interpreter, "-c", unwrap_script, encoded_script}
}
func get_remote_command(cd *connection_data) error {
interpreter := cd.host_opts.Interpreter
q := strings.ToLower(path.Base(interpreter))
is_python := strings.Contains(q, "python")
cd.script_type = "sh"
if is_python {
cd.script_type = "py"
}
err := bootstrap_script(cd)
if err != nil {
return err
}
wrap_bootstrap_script(cd)
return nil
}
func drain_potential_tty_garbage(term *tty.Term) {
err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho)
if err != nil {
return
}
canary, err := secrets.TokenBase64()
if err != nil {
return
}
dcs, err := tui.DCSToKitty("echo", canary+"\n\r")
if err != nil {
return
}
err = term.WriteAllString(dcs)
if err != nil {
return
}
q := utils.UnsafeStringToBytes(canary)
data := make([]byte, 0)
give_up_at := time.Now().Add(2 * time.Second)
buf := make([]byte, 0, 8192)
for !bytes.Contains(data, q) {
buf = buf[:cap(buf)]
timeout := give_up_at.Sub(time.Now())
if timeout < 0 {
break
}
n, err := term.ReadWithTimeout(buf, timeout)
if err != nil {
return
}
data = append(data, buf[:n]...)
}
}
func change_colors(color_scheme string) (ans string, err error) {
if color_scheme == "" {
return
}
var theme *themes.Theme
if !strings.HasSuffix(color_scheme, ".conf") {
cs := os.ExpandEnv(color_scheme)
tc, closer, err := themes.LoadThemes(-1)
if err != nil && errors.Is(err, themes.ErrNoCacheFound) {
tc, closer, err = themes.LoadThemes(time.Hour * 24)
}
if err != nil {
return "", err
}
defer closer.Close()
theme = tc.ThemeByName(cs)
if theme == nil {
return "", fmt.Errorf("No theme named %#v found", cs)
}
} else {
theme, err = themes.ThemeFromFile(utils.ResolveConfPath(color_scheme))
if err != nil {
return "", err
}
}
ans, err = theme.AsEscapeCodes()
if err == nil {
ans = "\033[#P" + ans
}
return
}
func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) {
go Data()
go RelevantKittyOpts()
defer func() {
if data_shm != nil {
data_shm.Close()
data_shm.Unlink()
}
}()
cmd := append([]string{SSHExe()}, ssh_args...)
cd := connection_data{remote_args: server_args[1:]}
hostname := server_args[0]
if len(cd.remote_args) == 0 {
cmd = append(cmd, "-t")
}
insertion_point := len(cmd)
cmd = append(cmd, "--", hostname)
uname, hostname_for_match := get_destination(hostname)
overrides, literal_env, err := parse_kitten_args(found_extra_args, uname, hostname_for_match)
if err != nil {
return 1, err
}
host_opts, bad_lines, err := load_config(hostname_for_match, uname, overrides)
if err != nil {
return 1, err
}
if len(bad_lines) > 0 {
for _, x := range bad_lines {
fmt.Fprintf(os.Stderr, "Ignoring bad config line: %s:%d with error: %s", filepath.Base(x.Src_file), x.Line_number, x.Err)
}
}
if host_opts.Share_connections {
kpid, err := strconv.Atoi(os.Getenv("KITTY_PID"))
if err != nil {
return 1, fmt.Errorf("Invalid KITTY_PID env var not an integer: %#v", os.Getenv("KITTY_PID"))
}
cpargs, err := connection_sharing_args(kpid)
if err != nil {
return 1, err
}
cmd = slices.Insert(cmd, insertion_point, cpargs...)
}
use_kitty_askpass := host_opts.Askpass == Askpass_native || (host_opts.Askpass == Askpass_unless_set && os.Getenv("SSH_ASKPASS") == "")
need_to_request_data := true
if use_kitty_askpass {
need_to_request_data = set_askpass()
}
if need_to_request_data && host_opts.Share_connections {
check_cmd := slices.Insert(cmd, 1, "-O", "check")
err = exec.Command(check_cmd[0], check_cmd[1:]...).Run()
if err == nil {
need_to_request_data = false
}
}
term, err := tty.OpenControllingTerm(tty.SetNoEcho)
if err != nil {
return 1, fmt.Errorf("Failed to open controlling terminal with error: %w", err)
}
cd.echo_on = term.WasEchoOnOriginally()
cd.host_opts, cd.literal_env = host_opts, literal_env
cd.request_data = need_to_request_data
cd.hostname_for_match, cd.username = hostname_for_match, uname
escape_codes_to_set_colors, err := change_colors(cd.host_opts.Color_scheme)
if err == nil {
err = term.WriteAllString(escape_codes_to_set_colors + loop.SAVE_PRIVATE_MODE_VALUES + loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet())
}
if err != nil {
return 1, err
}
restore_escape_codes := loop.RESTORE_PRIVATE_MODE_VALUES
if escape_codes_to_set_colors != "" {
restore_escape_codes += "\x1b[#Q"
}
defer func() {
term.WriteAllString(restore_escape_codes)
term.RestoreAndClose()
}()
err = get_remote_command(&cd)
if err != nil {
return 1, err
}
cmd = append(cmd, cd.rcmd...)
c := exec.Command(cmd[0], cmd[1:]...)
c.Stdin, c.Stdout, c.Stderr = os.Stdin, os.Stdout, os.Stderr
err = c.Start()
if err != nil {
return 1, err
}
if !cd.request_data {
rq := fmt.Sprintf("id=%s:pwfile=%s:pw=%s", cd.replacements["REQUEST_ID"], cd.replacements["PASSWORD_FILENAME"], cd.replacements["DATA_PASSWORD"])
err := term.ApplyOperations(tty.TCSANOW, tty.SetNoEcho)
if err == nil {
var dcs string
dcs, err = tui.DCSToKitty("ssh", rq)
if err == nil {
err = term.WriteAllString(dcs)
}
}
if err != nil {
c.Process.Kill()
c.Wait()
return 1, err
}
}
err = c.Wait()
drain_potential_tty_garbage(term)
if err != nil {
var exit_err *exec.ExitError
if errors.As(err, &exit_err) {
return exit_err.ExitCode(), nil
}
return 1, err
}
return 0, nil
}
func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
if len(args) > 0 {
switch args[0] {
case "use-python":
args = args[1:] // backwards compat from when we had a python implementation
case "-h", "--help":
cmd.ShowHelp()
return
}
}
ssh_args, server_args, passthrough, found_extra_args, err := ParseSSHArgs(args, "--kitten")
if err != nil {
var invargs *ErrInvalidSSHArgs
switch {
case errors.As(err, &invargs):
if invargs.Msg != "" {
fmt.Fprintln(os.Stderr, invargs.Msg)
}
return 1, unix.Exec(SSHExe(), []string{"ssh"}, os.Environ())
}
return 1, err
}
if passthrough {
if len(found_extra_args) > 0 {
return 1, fmt.Errorf("The SSH kitten cannot work with the options: %s", strings.Join(maps.Keys(PassthroughArgs()), " "))
}
return 1, unix.Exec(SSHExe(), append([]string{"ssh"}, args...), os.Environ())
}
if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" {
return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window")
}
if !tty.IsTerminal(os.Stdin.Fd()) {
return 1, fmt.Errorf("The SSH kitten is meant for interactive use only, STDIN must be a terminal")
}
return run_ssh(ssh_args, server_args, found_extra_args)
}
func EntryPoint(parent *cli.Command) {
create_cmd(parent, main)
}
func specialize_command(ssh *cli.Command) {
ssh.Usage = "arguments for the ssh command"
ssh.ShortDescription = "Truly convenient SSH"
ssh.HelpText = "The ssh kitten is a thin wrapper around the ssh command. It automatically enables shell integration on the remote host, re-uses existing connections to reduce latency, makes the kitty terminfo database available, etc. It's invocation is identical to the ssh command. For details on its usage, see :doc:`/kittens/ssh`."
ssh.IgnoreAllArgs = true
ssh.OnlyArgsAllowed = true
ssh.ArgCompleter = cli.CompletionForWrapper("ssh")
}
func test_integration_with_python(args []string) (rc int, err error) {
f, err := os.CreateTemp("", "*.conf")
if err != nil {
return 1, err
}
defer func() {
f.Close()
os.Remove(f.Name())
}()
_, err = io.Copy(f, os.Stdin)
if err != nil {
return 1, err
}
cd := &connection_data{
request_id: "testing", remote_args: []string{},
username: "testuser", hostname_for_match: "host.test", request_data: true,
test_script: args[0], echo_on: true,
}
opts, bad_lines, err := load_config(cd.hostname_for_match, cd.username, nil, f.Name())
if err == nil {
if len(bad_lines) > 0 {
return 1, fmt.Errorf("Bad config lines: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err)
}
cd.host_opts = opts
err = get_remote_command(cd)
}
if err != nil {
return 1, err
}
data, err := json.Marshal(map[string]any{"cmd": cd.rcmd, "shm_name": cd.shm_name})
if err == nil {
_, err = os.Stdout.Write(data)
os.Stdout.Close()
}
if err != nil {
return 1, err
}
return
}
func TestEntryPoint(root *cli.Command) {
root.AddSubCommand(&cli.Command{
Name: "ssh",
OnlyArgsAllowed: true,
Run: func(cmd *cli.Command, args []string) (rc int, err error) {
return test_integration_with_python(args)
},
})
}

154
tools/cmd/ssh/main_test.go Normal file
View File

@ -0,0 +1,154 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"encoding/binary"
"encoding/json"
"fmt"
"io/fs"
"kitty/tools/utils/shm"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
func TestCloneEnv(t *testing.T) {
env := map[string]string{"a": "1", "b": "2"}
data, err := json.Marshal(env)
if err != nil {
t.Fatal(err)
}
mmap, err := shm.CreateTemp("", 128)
if err != nil {
t.Fatal(err)
}
defer mmap.Unlink()
copy(mmap.Slice()[4:], data)
binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data)))
mmap.Close()
x, err := add_cloned_env(mmap.Name())
if err != nil {
t.Fatal(err)
}
diff := cmp.Diff(env, x)
if diff != "" {
t.Fatalf("Failed to deserialize env\n%s", diff)
}
}
func basic_connection_data(overrides ...string) *connection_data {
ans := &connection_data{
script_type: "sh", request_id: "123-123", remote_args: []string{},
username: "testuser", hostname_for_match: "host.test",
}
opts, bad_lines, err := load_config(ans.hostname_for_match, ans.username, overrides, "")
if err != nil {
panic(err)
}
if len(bad_lines) != 0 {
panic(fmt.Sprintf("Bad config lines: %s with error: %s", bad_lines[0].Line, bad_lines[0].Err))
}
ans.host_opts = opts
return ans
}
func TestSSHBootstrapScriptLimit(t *testing.T) {
cd := basic_connection_data()
err := get_remote_command(cd)
if err != nil {
t.Fatal(err)
}
total := 0
for _, x := range cd.rcmd {
total += len(x)
}
if total > 9000 {
t.Fatalf("Bootstrap script too large: %d bytes", total)
}
}
func TestSSHTarfile(t *testing.T) {
tdir := t.TempDir()
cd := basic_connection_data()
data, err := make_tarfile(cd, func(key string) (val string, found bool) { return })
if err != nil {
t.Fatal(err)
}
cmd := exec.Command("tar", "xpzf", "-", "-C", tdir)
cmd.Stderr = os.Stderr
inp, err := cmd.StdinPipe()
if err != nil {
t.Fatal(err)
}
err = cmd.Start()
if err != nil {
t.Fatal(err)
}
_, err = inp.Write(data)
if err != nil {
t.Fatal(err)
}
inp.Close()
err = cmd.Wait()
if err != nil {
t.Fatal(err)
}
seen := map[string]bool{}
err = filepath.WalkDir(tdir, func(name string, d fs.DirEntry, werr error) error {
if werr != nil {
return werr
}
rname, werr := filepath.Rel(tdir, name)
if werr != nil {
return werr
}
rname = strings.ReplaceAll(rname, "\\", "/")
if rname == "." {
return nil
}
fi, werr := d.Info()
if werr != nil {
return werr
}
if fi.Mode().Perm()&0o600 == 0 {
return fmt.Errorf("%s is not rw for its owner. Actual permissions: %s", rname, fi.Mode().String())
}
seen[rname] = true
return nil
})
if err != nil {
t.Fatal(err)
}
if !seen["data.sh"] {
t.Fatalf("data.sh missing")
}
for _, x := range []string{".terminfo/kitty.terminfo", ".terminfo/x/xterm-kitty"} {
if !seen["home/"+x] {
t.Fatalf("%s missing", x)
}
}
for _, x := range []string{"shell-integration/bash/kitty.bash", "shell-integration/fish/vendor_completions.d/kitty.fish"} {
if !seen[path.Join("home", cd.host_opts.Remote_dir, x)] {
t.Fatalf("%s missing", x)
}
}
for _, x := range []string{"kitty", "kitten"} {
p := filepath.Join(tdir, "home", cd.host_opts.Remote_dir, "kitty", "bin", x)
if err = unix.Access(p, unix.X_OK); err != nil {
t.Fatalf("Cannot execute %s with error: %s", x, err)
}
}
if seen[path.Join("home", cd.host_opts.Remote_dir, "shell-integration", "ssh", "kitten")] {
t.Fatalf("Contents of shell-integration/ssh not excluded")
}
}

248
tools/cmd/ssh/utils.go Normal file
View File

@ -0,0 +1,248 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"io"
"kitty"
"kitty/tools/config"
"kitty/tools/utils"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
)
var _ = fmt.Print
var SSHExe = (&utils.Once[string]{Run: func() string {
ans := utils.Which("ssh")
if ans != "" {
return ans
}
ans = utils.Which("ssh", "/usr/local/bin", "/opt/bin", "/opt/homebrew/bin", "/usr/bin", "/bin", "/usr/sbin", "/sbin")
if ans == "" {
ans = "ssh"
}
return ans
}}).Get
var SSHOptions = (&utils.Once[map[string]string]{Run: func() (ssh_options map[string]string) {
defer func() {
if ssh_options == nil {
ssh_options = map[string]string{
"4": "", "6": "", "A": "", "a": "", "C": "", "f": "", "G": "", "g": "", "K": "", "k": "",
"M": "", "N": "", "n": "", "q": "", "s": "", "T": "", "t": "", "V": "", "v": "", "X": "",
"x": "", "Y": "", "y": "", "B": "bind_interface", "b": "bind_address", "c": "cipher_spec",
"D": "[bind_address:]port", "E": "log_file", "e": "escape_char", "F": "configfile", "I": "pkcs11",
"i": "identity_file", "J": "[user@]host[:port]", "L": "address", "l": "login_name", "m": "mac_spec",
"O": "ctl_cmd", "o": "option", "p": "port", "Q": "query_option", "R": "address",
"S": "ctl_path", "W": "host:port", "w": "local_tun[:remote_tun]",
}
}
}()
cmd := exec.Command(SSHExe())
stderr, err := cmd.StderrPipe()
if err != nil {
return
}
if err := cmd.Start(); err != nil {
return
}
raw, err := io.ReadAll(stderr)
if err != nil {
return
}
text := utils.UnsafeBytesToString(raw)
ssh_options = make(map[string]string, 32)
for {
pos := strings.IndexByte(text, '[')
if pos < 0 {
break
}
num := 1
epos := pos
for num > 0 {
epos++
switch text[epos] {
case '[':
num += 1
case ']':
num -= 1
}
}
q := text[pos+1 : epos]
text = text[epos:]
if len(q) < 2 || !strings.HasPrefix(q, "-") {
continue
}
opt, desc, found := strings.Cut(q, " ")
if found {
ssh_options[opt[1:]] = desc
} else {
for _, ch := range opt[1:] {
ssh_options[string(ch)] = ""
}
}
}
return
}}).Get
func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) {
other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32)
for k, v := range SSHOptions() {
k = "-" + k
if v == "" {
boolean_ssh_args.Add(k)
} else {
other_ssh_args.Add(k)
}
}
return
}
func is_extra_arg(arg string, extra_args []string) string {
for _, x := range extra_args {
if arg == x || strings.HasPrefix(arg, x+"=") {
return x
}
}
return ""
}
type ErrInvalidSSHArgs struct {
Msg string
}
func (self *ErrInvalidSSHArgs) Error() string {
return self.Msg
}
func PassthroughArgs() map[string]bool {
return map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true}
}
func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, server_args []string, passthrough bool, found_extra_args []string, err error) {
if extra_args == nil {
extra_args = []string{}
}
if len(args) == 0 {
passthrough = true
return
}
passthrough_args := PassthroughArgs()
boolean_ssh_args, other_ssh_args := GetSSHCLI()
ssh_args, server_args, found_extra_args = make([]string, 0, 16), make([]string, 0, 16), make([]string, 0, 16)
expecting_option_val := false
stop_option_processing := false
expecting_extra_val := ""
for _, argument := range args {
if len(server_args) > 1 || stop_option_processing {
server_args = append(server_args, argument)
continue
}
if strings.HasPrefix(argument, "-") && !expecting_option_val {
if argument == "--" {
stop_option_processing = true
continue
}
if len(extra_args) > 0 {
matching_ex := is_extra_arg(argument, extra_args)
if matching_ex != "" {
_, exval, found := strings.Cut(argument, "=")
if found {
found_extra_args = append(found_extra_args, matching_ex, exval)
} else {
expecting_extra_val = matching_ex
expecting_option_val = true
}
continue
}
}
// could be a multi-character option
all_args := []rune(argument[1:])
for i, ch := range all_args {
arg := "-" + string(ch)
if passthrough_args[arg] {
passthrough = true
}
if boolean_ssh_args.Has(arg) {
ssh_args = append(ssh_args, arg)
continue
}
if other_ssh_args.Has(arg) {
ssh_args = append(ssh_args, arg)
if i+1 < len(all_args) {
ssh_args = append(ssh_args, string(all_args[i+1:]))
} else {
expecting_option_val = true
}
break
}
err = &ErrInvalidSSHArgs{Msg: "unknown option -- " + arg[1:]}
return
}
continue
}
if expecting_option_val {
if expecting_extra_val != "" {
found_extra_args = append(found_extra_args, expecting_extra_val, argument)
} else {
ssh_args = append(ssh_args, argument)
}
expecting_option_val = false
continue
}
server_args = append(server_args, argument)
}
if len(server_args) == 0 && !passthrough {
err = &ErrInvalidSSHArgs{Msg: ""}
}
return
}
type SSHVersion struct{ Major, Minor int }
func (self SSHVersion) SupportsAskpassRequire() bool {
return self.Major > 8 || (self.Major == 8 && self.Minor >= 4)
}
var GetSSHVersion = (&utils.Once[SSHVersion]{Run: func() SSHVersion {
b, err := exec.Command(SSHExe(), "-V").CombinedOutput()
if err != nil {
return SSHVersion{}
}
m := regexp.MustCompile(`OpenSSH_(\d+).(\d+)`).FindSubmatch(b)
if len(m) == 3 {
maj, _ := strconv.Atoi(utils.UnsafeBytesToString(m[1]))
min, _ := strconv.Atoi(utils.UnsafeBytesToString(m[2]))
return SSHVersion{Major: maj, Minor: min}
}
return SSHVersion{}
}}).Get
type KittyOpts struct {
Term, Shell_integration string
}
func read_relevant_kitty_opts(path string) KittyOpts {
ans := KittyOpts{Term: kitty.KittyConfigDefaults.Term, Shell_integration: kitty.KittyConfigDefaults.Shell_integration}
handle_line := func(key, val string) error {
switch key {
case "term":
ans.Term = strings.TrimSpace(val)
case "shell_integration":
ans.Shell_integration = strings.TrimSpace(val)
}
return nil
}
cp := config.ConfigParser{LineHandler: handle_line}
cp.ParseFiles(path)
return ans
}
var RelevantKittyOpts = (&utils.Once[KittyOpts]{Run: func() KittyOpts {
return read_relevant_kitty_opts(filepath.Join(utils.ConfigDir(), "kitty.conf"))
}}).Get

View File

@ -0,0 +1,68 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package ssh
import (
"fmt"
"os"
"path/filepath"
"testing"
"kitty/tools/utils/shlex"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestGetSSHOptions(t *testing.T) {
m := SSHOptions()
if m["w"] != "local_tun[:remote_tun]" {
t.Fatalf("Unexpected set of SSH options: %#v", m)
}
}
func TestParseSSHArgs(t *testing.T) {
split := func(x string) []string {
ans, err := shlex.Split(x)
if err != nil {
t.Fatal(err)
}
return ans
}
p := func(args, expected_ssh_args, expected_server_args, expected_extra_args string, expected_passthrough bool) {
ssh_args, server_args, passthrough, extra_args, err := ParseSSHArgs(split(args), "--kitten")
if err != nil {
t.Fatal(err)
}
check := func(a, b any) {
diff := cmp.Diff(a, b)
if diff != "" {
t.Fatalf("Unexpected value for args: %s\n%s", args, diff)
}
}
check(split(expected_ssh_args), ssh_args)
check(split(expected_server_args), server_args)
check(split(expected_extra_args), extra_args)
check(expected_passthrough, passthrough)
}
p(`localhost`, ``, `localhost`, ``, false)
p(`-- localhost`, ``, `localhost`, ``, false)
p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false)
p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false)
p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true)
}
func TestRelevantKittyOpts(t *testing.T) {
tdir := t.TempDir()
path := filepath.Join(tdir, "kitty.conf")
os.WriteFile(path, []byte("term XXX\nshell_integration changed\nterm abcd"), 0o600)
rko := read_relevant_kitty_opts(path)
if rko.Term != "abcd" {
t.Fatalf("Unexpected TERM: %s", RelevantKittyOpts().Term)
}
if rko.Shell_integration != "changed" {
t.Fatalf("Unexpected shell_integration: %s", RelevantKittyOpts().Shell_integration)
}
}

View File

@ -10,6 +10,8 @@ import (
"kitty/tools/cmd/clipboard" "kitty/tools/cmd/clipboard"
"kitty/tools/cmd/edit_in_kitty" "kitty/tools/cmd/edit_in_kitty"
"kitty/tools/cmd/icat" "kitty/tools/cmd/icat"
"kitty/tools/cmd/pytest"
"kitty/tools/cmd/ssh"
"kitty/tools/cmd/unicode_input" "kitty/tools/cmd/unicode_input"
"kitty/tools/cmd/update_self" "kitty/tools/cmd/update_self"
"kitty/tools/tui" "kitty/tools/tui"
@ -30,8 +32,12 @@ func KittyToolEntryPoints(root *cli.Command) {
clipboard.EntryPoint(root) clipboard.EntryPoint(root)
// icat // icat
icat.EntryPoint(root) icat.EntryPoint(root)
// ssh
ssh.EntryPoint(root)
// unicode_input // unicode_input
unicode_input.EntryPoint(root) unicode_input.EntryPoint(root)
// __pytest__
pytest.EntryPoint(root)
// __hold_till_enter__ // __hold_till_enter__
root.AddSubCommand(&cli.Command{ root.AddSubCommand(&cli.Command{
Name: "__hold_till_enter__", Name: "__hold_till_enter__",

192
tools/config/api.go Normal file
View File

@ -0,0 +1,192 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package config
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"kitty/tools/utils"
"os"
"path/filepath"
"strings"
)
var _ = fmt.Print
func StringToBool(x string) bool {
x = strings.ToLower(x)
return x == "y" || x == "yes" || x == "true"
}
type ConfigLine struct {
Src_file, Line string
Line_number int
Err error
}
type ConfigParser struct {
LineHandler func(key, val string) error
CommentsHandler func(line string) error
SourceHandler func(text, path string)
bad_lines []ConfigLine
seen_includes map[string]bool
override_env []string
}
type Scanner interface {
Scan() bool
Text() string
Err() error
}
func (self *ConfigParser) BadLines() []ConfigLine {
return self.bad_lines
}
func (self *ConfigParser) parse(scanner Scanner, name, base_path_for_includes string, depth int) error {
if self.seen_includes[name] { // avoid include loops
return nil
}
self.seen_includes[name] = true
recurse := func(r io.Reader, nname, base_path_for_includes string) error {
if depth > 32 {
return fmt.Errorf("Too many nested include directives while processing config file: %s", name)
}
escanner := bufio.NewScanner(r)
return self.parse(escanner, nname, base_path_for_includes, depth+1)
}
lnum := 0
make_absolute := func(path string) (string, error) {
if path == "" {
return "", fmt.Errorf("Empty include paths not allowed")
}
if !filepath.IsAbs(path) {
path = filepath.Join(base_path_for_includes, path)
}
return path, nil
}
for scanner.Scan() {
line := strings.TrimLeft(scanner.Text(), " ")
lnum++
if line == "" {
continue
}
if line[0] == '#' {
if self.CommentsHandler != nil {
err := self.CommentsHandler(line)
if err != nil {
self.bad_lines = append(self.bad_lines, ConfigLine{Src_file: name, Line: line, Line_number: lnum, Err: err})
}
}
continue
}
key, val, _ := strings.Cut(line, " ")
switch key {
default:
err := self.LineHandler(key, val)
if err != nil {
self.bad_lines = append(self.bad_lines, ConfigLine{Src_file: name, Line: line, Line_number: lnum, Err: err})
}
case "include", "globinclude", "envinclude":
var includes []string
switch key {
case "include":
aval, err := make_absolute(val)
if err == nil {
includes = []string{aval}
}
case "globinclude":
aval, err := make_absolute(val)
if err == nil {
matches, err := filepath.Glob(aval)
if err == nil {
includes = matches
}
}
case "envinclude":
env := self.override_env
if env == nil {
env = os.Environ()
}
for _, x := range env {
key, eval, _ := strings.Cut(x, "=")
is_match, err := filepath.Match(val, key)
if is_match && err == nil {
err := recurse(strings.NewReader(eval), "<env var: "+key+">", base_path_for_includes)
if err != nil {
return err
}
}
}
}
if len(includes) > 0 {
for _, incpath := range includes {
raw, err := os.ReadFile(incpath)
if err == nil {
err := recurse(bytes.NewReader(raw), incpath, filepath.Dir(incpath))
if err != nil {
return err
}
} else if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("Failed to process include %#v with error: %w", incpath, err)
}
}
}
}
}
return nil
}
func (self *ConfigParser) ParseFiles(paths ...string) error {
for _, path := range paths {
apath, err := filepath.Abs(path)
if err == nil {
path = apath
}
raw, err := os.ReadFile(path)
if err != nil {
return err
}
scanner := bufio.NewScanner(bytes.NewReader(raw))
self.seen_includes = make(map[string]bool)
err = self.parse(scanner, path, filepath.Dir(path), 0)
if err != nil {
return err
}
if self.SourceHandler != nil {
self.SourceHandler(utils.UnsafeBytesToString(raw), path)
}
}
return nil
}
type LinesScanner struct {
lines []string
}
func (self *LinesScanner) Scan() bool {
return len(self.lines) > 0
}
func (self *LinesScanner) Text() string {
ans := self.lines[0]
self.lines = self.lines[1:]
return ans
}
func (self *LinesScanner) Err() error {
return nil
}
func (self *ConfigParser) ParseOverrides(overrides ...string) error {
s := LinesScanner{lines: overrides}
self.seen_includes = make(map[string]bool)
return self.parse(&s, "<overrides>", utils.ConfigDir(), 0)
}

62
tools/config/api_test.go Normal file
View File

@ -0,0 +1,62 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package config
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestConfigParsing(t *testing.T) {
tdir := t.TempDir()
conf_file := filepath.Join(tdir, "a.conf")
os.Mkdir(filepath.Join(tdir, "sub"), 0o700)
os.WriteFile(conf_file, []byte(
`error main
# ignore me
a one
#: other
include sub/b.conf
b
include non-existent
globinclude sub/c?.conf
`), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/b.conf"), []byte("incb cool\ninclude a.conf"), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/c1.conf"), []byte("inc1 cool"), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/c2.conf"), []byte("inc2 cool\nenvinclude ENVINCLUDE"), 0o600)
os.WriteFile(filepath.Join(tdir, "sub/c.conf"), []byte("inc notcool\nerror sub"), 0o600)
var parsed_lines []string
pl := func(key, val string) error {
if key == "error" {
return fmt.Errorf("%s", val)
}
parsed_lines = append(parsed_lines, key+" "+val)
return nil
}
p := ConfigParser{LineHandler: pl, override_env: []string{"ENVINCLUDE=env cool\ninclude c.conf"}}
err := p.ParseFiles(conf_file)
if err != nil {
t.Fatal(err)
}
err = p.ParseOverrides("over one", "over two")
diff := cmp.Diff([]string{"a one", "incb cool", "b ", "inc1 cool", "inc2 cool", "env cool", "inc notcool", "over one", "over two"}, parsed_lines)
if diff != "" {
t.Fatalf("Unexpected parsed config values:\n%s", diff)
}
bad_lines := []string{}
for _, bl := range p.BadLines() {
bad_lines = append(bad_lines, fmt.Sprintf("%s: %d", filepath.Base(bl.Src_file), bl.Line_number))
}
diff = cmp.Diff([]string{"a.conf: 1", "c.conf: 2"}, bad_lines)
if diff != "" {
t.Fatalf("Unexpected bad lines:\n%s", diff)
}
}

430
tools/themes/collection.go Normal file
View File

@ -0,0 +1,430 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package themes
import (
"archive/zip"
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"kitty/tools/config"
"kitty/tools/utils"
"kitty/tools/utils/style"
"net/http"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/exp/maps"
)
var _ = fmt.Print
type JSONMetadata struct {
Etag string `json:"etag"`
Timestamp string `json:"timestamp"`
}
var ErrNoCacheFound = errors.New("No cache found and max cache age is negative")
func fetch_cached(name, url, cache_path string, max_cache_age time.Duration) (string, error) {
cache_path = filepath.Join(cache_path, name+".zip")
zf, err := zip.OpenReader(cache_path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return "", err
}
var jm JSONMetadata
if err == nil {
err = json.Unmarshal(utils.UnsafeStringToBytes(zf.Comment), &jm)
if max_cache_age < 0 {
return cache_path, nil
}
cache_age, err := utils.ISO8601Parse(jm.Timestamp)
if err == nil {
if time.Now().Before(cache_age.Add(max_cache_age)) {
return cache_path, nil
}
}
}
if max_cache_age < 0 {
return "", ErrNoCacheFound
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", err
}
if jm.Etag != "" {
req.Header.Add("If-None-Match", jm.Etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("Failed to download %s with error: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotModified {
return cache_path, nil
}
return "", fmt.Errorf("Failed to download %s with HTTP error: %s", url, resp.Status)
}
var tf, tf2 *os.File
tf, err = os.CreateTemp(filepath.Dir(cache_path), name+".temp-*")
if err == nil {
tf2, err = os.CreateTemp(filepath.Dir(cache_path), name+".temp-*")
}
defer func() {
if tf != nil {
tf.Close()
os.Remove(tf.Name())
tf = nil
}
if tf2 != nil {
tf2.Close()
os.Remove(tf2.Name())
tf2 = nil
}
}()
if err != nil {
return "", fmt.Errorf("Failed to create temp file in %s with error: %w", filepath.Dir(cache_path), err)
}
_, err = io.Copy(tf, resp.Body)
if err != nil {
return "", fmt.Errorf("Failed to download %s with error: %w", url, err)
}
r, err := zip.OpenReader(tf.Name())
if err != nil {
return "", fmt.Errorf("Failed to open downloaded zip file with error: %w", err)
}
w := zip.NewWriter(tf2)
jm.Etag = resp.Header.Get("ETag")
jm.Timestamp = utils.ISO8601Format(time.Now())
comment, _ := json.Marshal(jm)
w.SetComment(utils.UnsafeBytesToString(comment))
for _, file := range r.File {
err = w.Copy(file)
if err != nil {
return "", fmt.Errorf("Failed to copy zip file from source to destination archive")
}
}
err = w.Close()
if err != nil {
return "", err
}
tf2.Close()
err = os.Rename(tf2.Name(), cache_path)
if err != nil {
return "", fmt.Errorf("Failed to atomic rename temp file to %s with error: %w", cache_path, err)
}
tf2 = nil
return cache_path, nil
}
func FetchCached(max_cache_age time.Duration) (string, error) {
return fetch_cached("kitty-themes", "https://codeload.github.com/kovidgoyal/kitty-themes/zip/master", utils.CacheDir(), max_cache_age)
}
type ThemeMetadata struct {
Name string `json:"name"`
Filepath string `json:"file"`
Is_dark bool `json:"is_dark"`
Num_settings int `json:"num_settings"`
Blurb string `json:"blurb"`
License string `json:"license"`
Upstream string `json:"upstream"`
Author string `json:"author"`
}
func parse_theme_metadata(path string) (*ThemeMetadata, map[string]string, error) {
var in_metadata, in_blurb, finished_metadata bool
ans := ThemeMetadata{}
settings := map[string]string{}
read_is_dark := func(key, val string) (err error) {
settings[key] = val
if key == "background" {
val = strings.TrimSpace(val)
if val != "" {
bg, err := style.ParseColor(val)
if err == nil {
ans.Is_dark = utils.Max(bg.Red, bg.Green, bg.Green) < 115
}
}
}
return
}
read_metadata := func(line string) (err error) {
is_block := strings.HasPrefix(line, "## ")
if in_metadata && !is_block {
finished_metadata = true
}
if finished_metadata {
return
}
if !in_metadata && is_block {
in_metadata = true
}
if !in_metadata {
return
}
line = line[3:]
if in_blurb {
ans.Blurb += " " + line
return
}
key, val, found := strings.Cut(line, ":")
if !found {
return
}
key = strings.TrimSpace(strings.ToLower(key))
val = strings.TrimSpace(val)
switch key {
case "name":
ans.Name = val
case "author":
ans.Author = val
case "upstream":
ans.Upstream = val
case "blurb":
ans.Blurb = val
in_blurb = true
case "license":
ans.License = val
}
return
}
cp := config.ConfigParser{LineHandler: read_is_dark, CommentsHandler: read_metadata}
err := cp.ParseFiles(path)
if err != nil {
return nil, nil, err
}
ans.Num_settings = len(settings)
return &ans, settings, nil
}
type Theme struct {
metadata *ThemeMetadata
code string
settings map[string]string
zip_reader *zip.File
is_user_defined bool
}
func (self *Theme) load_code() (string, error) {
if self.zip_reader != nil {
f, err := self.zip_reader.Open()
self.zip_reader = nil
if err != nil {
return "", err
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return "", err
}
self.code = utils.UnsafeBytesToString(data)
}
return self.code, nil
}
func (self *Theme) Settings() (map[string]string, error) {
if self.zip_reader != nil {
code, err := self.load_code()
if err != nil {
return nil, err
}
self.settings = make(map[string]string, 64)
scanner := bufio.NewScanner(strings.NewReader(code))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line != "" && line[0] != '#' {
key, val, found := strings.Cut(line, " ")
if found {
self.settings[key] = val
}
}
}
}
return self.settings, nil
}
func (self *Theme) AsEscapeCodes() (string, error) {
settings, err := self.Settings()
if err != nil {
return "", err
}
w := strings.Builder{}
w.Grow(4096)
set_color := func(i int, sharp string) {
w.WriteByte(';')
w.WriteString(strconv.Itoa(i))
w.WriteByte(';')
w.WriteString(sharp)
}
set_default_color := func(name, defval string, num int) {
w.WriteString("\033]")
defer func() { w.WriteString("\033\\") }()
val, found := settings[name]
if !found {
val = defval
}
if val != "" {
rgba, err := style.ParseColor(val)
if err == nil {
w.WriteString(strconv.Itoa(num))
w.WriteByte(';')
w.WriteString(rgba.AsRGBSharp())
return
}
}
w.WriteByte('1')
w.WriteString(strconv.Itoa(num))
}
set_default_color("foreground", style.DefaultColors.Foreground, 10)
set_default_color("background", style.DefaultColors.Background, 11)
set_default_color("cursor", style.DefaultColors.Cursor, 12)
set_default_color("selection_background", style.DefaultColors.SelectionBg, 17)
set_default_color("selection_foreground", style.DefaultColors.SelectionFg, 19)
w.WriteString("\033]4")
for i := 0; i < 256; i++ {
key := "color" + strconv.Itoa(i)
val := settings[key]
if val != "" {
rgba, err := style.ParseColor(val)
if err == nil {
set_color(i, rgba.AsRGBSharp())
continue
}
}
rgba := style.RGBA{}
rgba.FromRGB(style.ColorTable[i])
set_color(i, rgba.AsRGBSharp())
}
w.WriteString("\033\\")
return w.String(), nil
}
type Themes struct {
name_map map[string]*Theme
index_map []string
}
var camel_case_pat = (&utils.Once[*regexp.Regexp]{Run: func() *regexp.Regexp {
return regexp.MustCompile(`([a-z])([A-Z])`)
}}).Get
func theme_name_from_file_name(fname string) string {
fname = fname[:len(fname)-len(path.Ext(fname))]
fname = strings.ReplaceAll(fname, "_", " ")
fname = camel_case_pat().ReplaceAllString(fname, "$1 $2")
return strings.Join(utils.Map(strings.Split(fname, " "), strings.Title), " ")
}
func (self *Themes) AddFromFile(path string) (*Theme, error) {
m, conf, err := parse_theme_metadata(path)
if err != nil {
return nil, err
}
if m.Name == "" {
m.Name = theme_name_from_file_name(filepath.Base(path))
}
t := Theme{metadata: m, is_user_defined: true, settings: conf}
self.name_map[m.Name] = &t
return &t, nil
}
func (self *Themes) add_from_dir(dirpath string) error {
entries, err := os.ReadDir(dirpath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
err = nil
}
return err
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(e.Name(), ".conf") {
if _, err = self.AddFromFile(filepath.Join(dirpath, e.Name())); err != nil {
return err
}
}
}
return nil
}
func (self *Themes) add_from_zip_file(zippath string) (io.Closer, error) {
r, err := zip.OpenReader(zippath)
if err != nil {
return nil, err
}
name_map := make(map[string]*zip.File, len(r.File))
var themes []ThemeMetadata
theme_dir := ""
for _, file := range r.File {
name_map[file.Name] = file
if path.Base(file.Name) == "themes.json" {
theme_dir = path.Dir(file.Name)
fr, err := file.Open()
if err != nil {
return nil, fmt.Errorf("Error while opening %s from the ZIP file: %w", file.Name, err)
}
defer fr.Close()
raw, err := io.ReadAll(fr)
if err != nil {
return nil, fmt.Errorf("Error while reading %s from the ZIP file: %w", file.Name, err)
}
err = json.Unmarshal(raw, &themes)
if err != nil {
return nil, fmt.Errorf("Error while decoding %s: %w", file.Name, err)
}
}
}
if theme_dir == "" {
return nil, fmt.Errorf("No themes.json found in ZIP file")
}
for _, theme := range themes {
key := path.Join(theme_dir, theme.Filepath)
f := name_map[key]
if f != nil {
t := Theme{metadata: &theme, zip_reader: f}
self.name_map[theme.Name] = &t
}
}
return r, nil
}
func (self *Themes) ThemeByName(name string) *Theme {
return self.name_map[name]
}
func LoadThemes(cache_age time.Duration) (ans *Themes, closer io.Closer, err error) {
zip_path, err := FetchCached(cache_age)
ans = &Themes{name_map: make(map[string]*Theme)}
if err != nil {
return nil, nil, err
}
if closer, err = ans.add_from_zip_file(zip_path); err != nil {
return nil, nil, err
}
if err = ans.add_from_dir(filepath.Join(utils.ConfigDir(), "themes")); err != nil {
return nil, nil, err
}
ans.index_map = maps.Keys(ans.name_map)
ans.index_map = utils.StableSortWithKey(ans.index_map, strings.ToLower)
return ans, closer, nil
}
func ThemeFromFile(path string) (*Theme, error) {
ans := &Themes{name_map: make(map[string]*Theme)}
return ans.AddFromFile(path)
}

View File

@ -0,0 +1,157 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package themes
import (
"archive/zip"
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestThemeCollections(t *testing.T) {
for fname, expected := range map[string]string{
"moose": "Moose",
"mooseCat": "Moose Cat",
"a_bC": "A B C",
} {
actual := theme_name_from_file_name(fname)
if diff := cmp.Diff(expected, actual); diff != "" {
t.Fatalf("Unexpected theme name for %s:\n%s", fname, diff)
}
}
tdir := t.TempDir()
pt := func(expected ThemeMetadata, lines ...string) {
os.WriteFile(filepath.Join(tdir, "temp.conf"), []byte(strings.Join(lines, "\n")), 0o600)
actual, _, err := parse_theme_metadata(filepath.Join(tdir, "temp.conf"))
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(&expected, actual); diff != "" {
t.Fatalf("Failed to parse:\n%s\n\n%s", strings.Join(lines, "\n"), diff)
}
}
pt(ThemeMetadata{Name: "XYZ", Blurb: "a b", Author: "A", Is_dark: true, Num_settings: 2},
"# some crap", " ", "## ", "## author: A", "## name: XYZ", "## blurb: a", "## b", "", "color red", "background black", "include inc.conf")
os.WriteFile(filepath.Join(tdir, "inc.conf"), []byte("background white"), 0o600)
pt(ThemeMetadata{Name: "XYZ", Blurb: "a b", Author: "A", Num_settings: 2},
"# some crap", " ", "## ", "## author: A", "## name: XYZ", "## blurb: a", "## b", "", "color red", "background black", "include inc.conf")
buf := bytes.Buffer{}
zw := zip.NewWriter(&buf)
fw, _ := zw.Create("x/themes.json")
fw.Write([]byte(`[
{
"author": "X Y",
"blurb": "A dark color scheme for the kitty terminal.",
"file": "themes/Alabaster_Dark.conf",
"is_dark": true,
"license": "MIT",
"name": "Alabaster Dark",
"num_settings": 30,
"upstream": "https://xxx.com"
},
{
"name": "Empty", "file": "empty.conf"
}
]`))
fw, _ = zw.Create("x/empty.conf")
fw.Write([]byte("empty"))
fw, _ = zw.Create("x/themes/Alabaster_Dark.conf")
fw.Write([]byte("alabaster"))
zw.Close()
received_etag := ""
request_count := 0
check_etag := true
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request_count++
received_etag = r.Header.Get("If-None-Match")
if check_etag && received_etag == `"xxx"` {
w.WriteHeader(http.StatusNotModified)
return
}
w.Header().Add("ETag", `"xxx"`)
w.Write(buf.Bytes())
}))
defer ts.Close()
_, err := fetch_cached("test", ts.URL, tdir, 0)
if err != nil {
t.Fatal(err)
}
r, err := zip.OpenReader(filepath.Join(tdir, "test.zip"))
if err != nil {
t.Fatal(err)
}
var jm JSONMetadata
err = json.Unmarshal([]byte(r.Comment), &jm)
if err != nil {
t.Fatal(err)
}
if jm.Etag != `"xxx"` {
t.Fatalf("Unexpected ETag: %#v", jm.Etag)
}
_, err = fetch_cached("test", ts.URL, tdir, time.Hour)
if err != nil {
t.Fatal(err)
}
if request_count != 1 {
t.Fatal("Cached zip file was not used")
}
before, _ := os.Stat(filepath.Join(tdir, "test.zip"))
_, err = fetch_cached("test", ts.URL, tdir, 0)
if err != nil {
t.Fatal(err)
}
if request_count != 2 {
t.Fatal("Cached zip file was incorrectly used")
}
if received_etag != `"xxx"` {
t.Fatalf("Got invalid ETag: %#v", received_etag)
}
after, _ := os.Stat(filepath.Join(tdir, "test.zip"))
if before.ModTime() != after.ModTime() {
t.Fatal("Cached zip file was incorrectly re-downloaded")
}
check_etag = false
_, err = fetch_cached("test", ts.URL, tdir, 0)
if err != nil {
t.Fatal(err)
}
after2, _ := os.Stat(filepath.Join(tdir, "test.zip"))
if after2.ModTime() != after.ModTime() {
t.Fatal("Cached zip file was incorrectly not re-downloaded")
}
coll := Themes{name_map: map[string]*Theme{}}
closer, err := coll.add_from_zip_file(filepath.Join(tdir, "test.zip"))
if err != nil {
t.Fatal(err)
}
defer closer.Close()
if code, err := coll.ThemeByName("Empty").load_code(); code != "empty" {
if err != nil {
t.Fatal(err)
}
t.Fatal("failed to load code for empty theme")
}
if code, err := coll.ThemeByName("Alabaster Dark").load_code(); code != "alabaster" {
if err != nil {
t.Fatal(err)
}
t.Fatal("failed to load code for alabaster theme")
}
}

View File

@ -148,6 +148,13 @@ func (self *Term) Close() error {
return err return err
} }
func (self *Term) WasEchoOnOriginally() bool {
if len(self.states) > 0 {
return self.states[0].Lflag&unix.ECHO != 0
}
return false
}
func (self *Term) Tcgetattr(ans *unix.Termios) error { func (self *Term) Tcgetattr(ans *unix.Termios) error {
return eintr_retry_noret(func() error { return Tcgetattr(self.Fd(), ans) }) return eintr_retry_noret(func() error { return Tcgetattr(self.Fd(), ans) })
} }
@ -256,13 +263,19 @@ func (self *Term) ReadWithTimeout(b []byte, d time.Duration) (n int, err error)
} }
num_ready, err := pselect() num_ready, err := pselect()
if err != nil { if err != nil {
return return 0, err
} }
if num_ready == 0 { if num_ready == 0 {
err = os.ErrDeadlineExceeded err = os.ErrDeadlineExceeded
return return 0, err
}
for {
n, err = self.Read(b)
if errors.Is(err, unix.EINTR) {
continue
}
return n, err
} }
return self.Read(b)
} }
func (self *Term) Read(b []byte) (int, error) { func (self *Term) Read(b []byte) (int, error) {
@ -273,10 +286,14 @@ func (self *Term) Write(b []byte) (int, error) {
return self.os_file.Write(b) return self.os_file.Write(b)
} }
func is_temporary_error(err error) bool {
return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite)
}
func (self *Term) WriteAll(b []byte) error { func (self *Term) WriteAll(b []byte) error {
for len(b) > 0 { for len(b) > 0 {
n, err := self.os_file.Write(b) n, err := self.os_file.Write(b)
if err != nil && !errors.Is(err, io.ErrShortWrite) { if err != nil && !is_temporary_error(err) {
return err return err
} }
b = b[n:] b = b[n:]
@ -284,6 +301,10 @@ func (self *Term) WriteAll(b []byte) error {
return nil return nil
} }
func (self *Term) WriteAllString(s string) error {
return self.WriteAll(utils.UnsafeStringToBytes(s))
}
func (self *Term) WriteString(b string) (int, error) { func (self *Term) WriteString(b string) (int, error) {
return self.os_file.WriteString(b) return self.os_file.WriteString(b)
} }

28
tools/tui/dcs_to_kitty.go Normal file
View File

@ -0,0 +1,28 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package tui
import (
"encoding/base64"
"fmt"
"kitty/tools/utils"
)
var _ = fmt.Print
func DCSToKitty(msgtype, payload string) (string, error) {
data := base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(payload))
ans := "\x1bP@kitty-" + msgtype + "|" + data
tmux := TmuxSocketAddress()
if tmux != "" {
err := TmuxAllowPassthrough()
if err != nil {
return "", err
}
ans = "\033Ptmux;\033" + ans + "\033\033\\\033\\"
} else {
ans += "\033\\"
}
return ans, nil
}

56
tools/tui/tmux.go Normal file
View File

@ -0,0 +1,56 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package tui
import (
"fmt"
"kitty/tools/utils"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"github.com/shirou/gopsutil/v3/process"
"golang.org/x/sys/unix"
)
var _ = fmt.Print
func tmux_socket_address() (socket string) {
socket = os.Getenv("TMUX")
if socket == "" {
return ""
}
addr, pid_str, found := strings.Cut(socket, ",")
if !found {
return ""
}
if unix.Access(addr, unix.R_OK|unix.W_OK) != nil {
return ""
}
pid, err := strconv.ParseInt(pid_str, 10, 32)
if err != nil {
return ""
}
p, err := process.NewProcess(int32(pid))
if err != nil {
return ""
}
cmd, err := p.CmdlineSlice()
if err != nil {
return ""
}
if len(cmd) > 0 && strings.ToLower(filepath.Base(cmd[0])) != "tmux" {
return ""
}
return socket
}
var TmuxSocketAddress = (&utils.Once[string]{Run: tmux_socket_address}).Get
func tmux_allow_passthrough() error {
return exec.Command("tmux", "set", "-p", "allow-passthrough", "on").Run()
}
var TmuxAllowPassthrough = (&utils.Once[error]{Run: tmux_allow_passthrough}).Get

View File

@ -4,11 +4,9 @@ package unicode_names
import ( import (
"bytes" "bytes"
"compress/zlib"
_ "embed" _ "embed"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -64,33 +62,8 @@ func parse_record(record []byte, mark uint16) {
var parse_once sync.Once var parse_once sync.Once
func read_all(r io.Reader, expected_size int) ([]byte, error) {
b := make([]byte, 0, expected_size)
for {
if len(b) == cap(b) {
// Add more capacity (let append pick how much).
b = append(b, 0)[:len(b)]
}
n, err := r.Read(b[len(b):cap(b)])
b = b[:len(b)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return b, err
}
}
}
func parse_data() { func parse_data() {
compressed := utils.UnsafeStringToBytes(unicode_name_data) raw := utils.ReadCompressedEmbeddedData(unicode_name_data)
uncompressed_size := binary.LittleEndian.Uint32(compressed)
r, _ := zlib.NewReader(bytes.NewReader(compressed[4:]))
defer r.Close()
raw, err := read_all(r, int(uncompressed_size))
if err != nil {
panic(err)
}
num_of_lines := binary.LittleEndian.Uint32(raw) num_of_lines := binary.LittleEndian.Uint32(raw)
raw = raw[4:] raw = raw[4:]
num_of_words := binary.LittleEndian.Uint32(raw) num_of_words := binary.LittleEndian.Uint32(raw)

View File

@ -12,6 +12,33 @@ import (
var _ = fmt.Print var _ = fmt.Print
func AtomicCreateSymlink(oldname, newname string) (err error) {
err = os.Symlink(oldname, newname)
if err == nil {
return nil
}
if !errors.Is(err, fs.ErrExist) {
return err
}
if et, err := os.Readlink(newname); err == nil && et == oldname {
return nil
}
for {
tempname := newname + RandomFilename()
err = os.Symlink(oldname, tempname)
if err == nil {
err = os.Rename(tempname, newname)
if err != nil {
os.Remove(tempname)
}
return err
}
if !errors.Is(err, fs.ErrExist) {
return err
}
}
}
func AtomicWriteFile(path string, data []byte, perm os.FileMode) (err error) { func AtomicWriteFile(path string, data []byte, perm os.FileMode) (err error) {
npath, err := filepath.EvalSymlinks(path) npath, err := filepath.EvalSymlinks(path)
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

46
tools/utils/embed.go Normal file
View File

@ -0,0 +1,46 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"bytes"
"compress/bzip2"
"encoding/binary"
"fmt"
"io"
)
var _ = fmt.Print
func ReadAll(r io.Reader, expected_size int) ([]byte, error) {
b := make([]byte, 0, expected_size)
for {
if len(b) == cap(b) {
// Add more capacity (let append pick how much).
b = append(b, 0)[:len(b)]
}
n, err := r.Read(b[len(b):cap(b)])
b = b[:len(b)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return b, err
}
}
}
func ReadCompressedEmbeddedData(raw string) []byte {
compressed := UnsafeStringToBytes(raw)
uncompressed_size := binary.LittleEndian.Uint32(compressed)
r := bzip2.NewReader(bytes.NewReader(compressed[4:]))
ans, err := ReadAll(r, int(uncompressed_size))
if err != nil {
panic(err)
}
return ans
}
func ReaderForCompressedEmbeddedData(raw string) io.Reader {
return bzip2.NewReader(bytes.NewReader(UnsafeStringToBytes(raw)[4:]))
}

166
tools/utils/iso8601.go Normal file
View File

@ -0,0 +1,166 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"strconv"
"strings"
"time"
)
var _ = fmt.Print
func is_digit(x byte) bool {
return '0' <= x && x <= '9'
}
// The following is copied from the Go standard library to implement date range validation logic
// equivalent to the behaviour of Go's time.Parse.
func isLeap(year int) bool {
return year%4 == 0 && (year%100 != 0 || year%400 == 0)
}
// daysInMonth is the number of days for non-leap years in each calendar month starting at 1
var daysInMonth = [13]int{0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}
func daysIn(m time.Month, year int) int {
if m == time.February && isLeap(year) {
return 29
}
return daysInMonth[int(m)]
}
func ISO8601Parse(raw string) (time.Time, error) {
orig := raw
raw = strings.TrimSpace(raw)
required_number := func(num_digits int) (int, error) {
if len(raw) < num_digits {
return 0, fmt.Errorf("Insufficient digits")
}
text := raw[:num_digits]
raw = raw[num_digits:]
ans, err := strconv.ParseUint(text, 10, 32)
return int(ans), err
}
optional_separator := func(x byte) bool {
if len(raw) > 0 && raw[0] == x {
raw = raw[1:]
}
return len(raw) > 0 && is_digit(raw[0])
}
errf := func(msg string) (time.Time, error) {
return time.Time{}, fmt.Errorf("Invalid ISO8601 timestamp: %#v. %s", orig, msg)
}
optional_separator('+')
year, err := required_number(4)
if err != nil {
return errf("timestamp does not start with a 4 digit year")
}
var month int = 1
var day int = 1
if optional_separator('-') {
month, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit month")
}
if optional_separator('-') {
day, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit day")
}
}
}
var hour, minute, second, nsec int
if len(raw) > 0 && (raw[0] == 'T' || raw[0] == ' ') {
raw = raw[1:]
hour, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit hour")
}
if optional_separator(':') {
minute, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit minute")
}
if optional_separator(':') {
second, err = required_number(2)
if err != nil {
return errf("timestamp does not have a valid 2 digit second")
}
}
}
if len(raw) > 0 && (raw[0] == '.' || raw[0] == ',') {
raw = raw[1:]
num_digits := 0
for len(raw) > num_digits && is_digit(raw[num_digits]) {
num_digits++
}
text := raw[:num_digits]
raw = raw[num_digits:]
extra := 9 - len(text)
if extra < 0 {
text = text[:9]
}
if text != "" {
n, err := strconv.ParseUint(text, 10, 64)
if err != nil {
return errf("timestamp does not have a valid nanosecond field")
}
nsec = int(n)
for ; extra > 0; extra-- {
nsec *= 10
}
}
}
}
switch {
case month < 1 || month > 12:
return errf("timestamp has invalid month value")
case day < 1 || day > 31 || day > daysIn(time.Month(month), year):
return errf("timestamp has invalid day value")
case hour < 0 || hour > 23:
return errf("timestamp has invalid hour value")
case minute < 0 || minute > 59:
return errf("timestamp has invalid minute value")
case second < 0 || second > 59:
return errf("timestamp has invalid second value")
}
loc := time.UTC
tzsign, tzhour, tzminute := 0, 0, 0
if len(raw) > 0 {
switch raw[0] {
case '+':
tzsign = 1
case '-':
tzsign = -1
}
}
if tzsign != 0 {
raw = raw[1:]
tzhour, err = required_number(2)
if err != nil {
return errf("timestamp has invalid timezone hour")
}
optional_separator(':')
tzminute, err = required_number(2)
if err != nil {
tzminute = 0
}
seconds := tzhour*3600 + tzminute*60
loc = time.FixedZone("", tzsign*seconds)
}
return time.Date(year, time.Month(month), day, hour, minute, second, nsec, loc), err
}
func ISO8601Format(x time.Time) string {
return x.Format(time.RFC3339Nano)
}

View File

@ -0,0 +1,40 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
func TestISO8601(t *testing.T) {
now := time.Now()
tt := func(raw string, expected time.Time) {
actual, err := ISO8601Parse(raw)
if err != nil {
t.Fatalf("Parsing: %#v failed with error: %s", raw, err)
}
if diff := cmp.Diff(expected, actual); diff != "" {
t.Fatalf("Parsing: %#v failed:\n%s", raw, diff)
}
}
tt(ISO8601Format(now), now)
tt("2023-02-08T07:24:09.551975+00:00", time.Date(2023, 2, 8, 7, 24, 9, 551975000, time.UTC))
tt("2023-02-08T07:24:09.551975Z", time.Date(2023, 2, 8, 7, 24, 9, 551975000, time.UTC))
tt("2023", time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
tt("2023-11-13", time.Date(2023, 11, 13, 0, 0, 0, 0, time.UTC))
tt("2023-11-13 07:23", time.Date(2023, 11, 13, 7, 23, 0, 0, time.UTC))
tt("2023-11-13 07:23:01", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC))
tt("2023-11-13 07:23:01.", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC))
tt("2023-11-13 07:23:01.0", time.Date(2023, 11, 13, 7, 23, 1, 0, time.UTC))
tt("2023-11-13 07:23:01.1", time.Date(2023, 11, 13, 7, 23, 1, 100000000, time.UTC))
tt("202311-13 07", time.Date(2023, 11, 13, 7, 0, 0, 0, time.UTC))
tt("20231113 0705", time.Date(2023, 11, 13, 7, 5, 0, 0, time.UTC))
}

View File

@ -11,12 +11,9 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
) )
var _ = fmt.Print var _ = fmt.Print
var user_mime_only_once sync.Once
var user_defined_mime_map map[string]string
func load_mime_file(filename string, mime_map map[string]string) error { func load_mime_file(filename string, mime_map map[string]string) error {
f, err := os.Open(filename) f, err := os.Open(filename)
@ -45,18 +42,19 @@ func load_mime_file(filename string, mime_map map[string]string) error {
return nil return nil
} }
func load_user_mime_maps() { var UserMimeMap = (&Once[map[string]string]{Run: func() map[string]string {
conf_path := filepath.Join(ConfigDir(), "mime.types") conf_path := filepath.Join(ConfigDir(), "mime.types")
err := load_mime_file(conf_path, user_defined_mime_map) ans := make(map[string]string, 32)
err := load_mime_file(conf_path, ans)
if err != nil && !errors.Is(err, fs.ErrNotExist) { if err != nil && !errors.Is(err, fs.ErrNotExist) {
fmt.Fprintln(os.Stderr, "Failed to parse", conf_path, "for MIME types with error:", err) fmt.Fprintln(os.Stderr, "Failed to parse", conf_path, "for MIME types with error:", err)
} }
} return ans
}}).Get
func GuessMimeType(filename string) string { func GuessMimeType(filename string) string {
user_mime_only_once.Do(load_user_mime_maps)
ext := filepath.Ext(filename) ext := filepath.Ext(filename)
mime_with_parameters := user_defined_mime_map[ext] mime_with_parameters := UserMimeMap()[ext]
if mime_with_parameters == "" { if mime_with_parameters == "" {
mime_with_parameters = mime.TypeByExtension(ext) mime_with_parameters = mime.TypeByExtension(ext)
} }

View File

@ -57,6 +57,14 @@ func Filter[T any](s []T, f func(x T) bool) []T {
return ans return ans
} }
func Map[T any](s []T, f func(x T) T) []T {
ans := make([]T, 0, len(s))
for _, x := range s {
ans = append(ans, f(x))
}
return ans
}
func Sort[T any](s []T, less func(a, b T) bool) []T { func Sort[T any](s []T, less func(a, b T) bool) []T {
sort.Slice(s, func(i, j int) bool { return less(s[i], s[j]) }) sort.Slice(s, func(i, j int) bool { return less(s[i], s[j]) })
return s return s

35
tools/utils/once.go Normal file
View File

@ -0,0 +1,35 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"sync"
"sync/atomic"
)
var _ = fmt.Print
type Once[T any] struct {
done uint32
mutex sync.Mutex
cached_val T
Run func() T
}
func (self *Once[T]) Get() T {
if atomic.LoadUint32(&self.done) == 0 {
self.do_slow()
}
return self.cached_val
}
func (self *Once[T]) do_slow() {
self.mutex.Lock()
defer self.mutex.Unlock()
if atomic.LoadUint32(&self.done) == 0 {
defer atomic.StoreUint32(&self.done, 1)
self.cached_val = self.Run()
}
}

24
tools/utils/once_test.go Normal file
View File

@ -0,0 +1,24 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package utils
import (
"fmt"
"testing"
)
var _ = fmt.Print
func TestOnce(t *testing.T) {
num := 0
var G = (&Once[string]{Run: func() string {
num++
return fmt.Sprintf("%d", num)
}}).Get
G()
G()
G()
if num != 1 {
t.Fatalf("num unexpectedly: %d", num)
}
}

View File

@ -3,13 +3,18 @@
package utils package utils
import ( import (
"crypto/rand"
"encoding/base32"
"fmt"
"io/fs" "io/fs"
not_rand "math/rand"
"os" "os"
"os/exec"
"os/user" "os/user"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -57,61 +62,57 @@ func Abspath(path string) string {
return path return path
} }
var config_dir, kitty_exe, cache_dir string var KittyExe = (&Once[string]{Run: func() string {
var kitty_exe_err error
var config_dir_once, kitty_exe_once, cache_dir_once sync.Once
func find_kitty_exe() {
exe, err := os.Executable() exe, err := os.Executable()
if err == nil { if err == nil {
kitty_exe = filepath.Join(filepath.Dir(exe), "kitty") return filepath.Join(filepath.Dir(exe), "kitty")
kitty_exe_err = unix.Access(kitty_exe, unix.X_OK)
} else {
kitty_exe_err = err
} }
} return ""
}}).Get
func KittyExe() (string, error) { var ConfigDir = (&Once[string]{Run: func() (config_dir string) {
kitty_exe_once.Do(find_kitty_exe) if kcd := os.Getenv("KITTY_CONFIG_DIRECTORY"); kcd != "" {
return kitty_exe, kitty_exe_err return Abspath(Expanduser(kcd))
} }
var locations []string
func find_config_dir() { seen := NewSet[string]()
if os.Getenv("KITTY_CONFIG_DIRECTORY") != "" { add := func(x string) {
config_dir = Abspath(Expanduser(os.Getenv("KITTY_CONFIG_DIRECTORY"))) x = Abspath(Expanduser(x))
} else { if !seen.Has(x) {
var locations []string seen.Add(x)
if os.Getenv("XDG_CONFIG_HOME") != "" { locations = append(locations, x)
locations = append(locations, os.Getenv("XDG_CACHE_HOME"))
} }
locations = append(locations, Expanduser("~/.config")) }
if runtime.GOOS == "darwin" { if xh := os.Getenv("XDG_CONFIG_HOME"); xh != "" {
locations = append(locations, Expanduser("~/Library/Preferences")) add(xh)
}
if dirs := os.Getenv("XDG_CONFIG_DIRS"); dirs != "" {
for _, candidate := range strings.Split(dirs, ":") {
add(candidate)
} }
for _, loc := range locations { }
if loc != "" { add("~/.config")
q := filepath.Join(loc, "kitty") if runtime.GOOS == "darwin" {
if _, err := os.Stat(filepath.Join(q, "kitty.conf")); err == nil { add("~/Library/Preferences")
config_dir = q }
break for _, loc := range locations {
} if loc != "" {
} q := filepath.Join(loc, "kitty")
} if _, err := os.Stat(filepath.Join(q, "kitty.conf")); err == nil {
for _, loc := range locations { config_dir = q
if loc != "" { return
config_dir = filepath.Join(loc, "kitty")
break
} }
} }
} }
} config_dir = os.Getenv("XDG_CONFIG_HOME")
if config_dir == "" {
config_dir = "~/.config"
}
config_dir = filepath.Join(Expanduser(config_dir), "kitty")
return
}}).Get
func ConfigDir() string { var CacheDir = (&Once[string]{Run: func() (cache_dir string) {
config_dir_once.Do(find_config_dir)
return config_dir
}
func find_cache_dir() {
candidate := "" candidate := ""
if edir := os.Getenv("KITTY_CACHE_DIRECTORY"); edir != "" { if edir := os.Getenv("KITTY_CACHE_DIRECTORY"); edir != "" {
candidate = Abspath(Expanduser(edir)) candidate = Abspath(Expanduser(edir))
@ -125,13 +126,71 @@ func find_cache_dir() {
candidate = filepath.Join(Expanduser(candidate), "kitty") candidate = filepath.Join(Expanduser(candidate), "kitty")
} }
os.MkdirAll(candidate, 0o755) os.MkdirAll(candidate, 0o755)
cache_dir = candidate return candidate
}}).Get
func macos_user_cache_dir() string {
// Sadly Go does not provide confstr() so we use this hack.
// Note that given a user generateduid and uid we can derive this by using
// the algorithm at https://github.com/ydkhatri/MacForensics/blob/master/darwin_path_generator.py
// but I cant find a good way to get the generateduid. Requires calling dscl in which case we might as well call getconf
// The data is in /var/db/dslocal/nodes/Default/users/<username>.plist but it needs root
// So instead we use various hacks to get it quickly, falling back to running /usr/bin/getconf
is_ok := func(m string) bool {
s, err := os.Stat(m)
if err != nil {
return false
}
stat, ok := s.Sys().(unix.Stat_t)
return ok && s.IsDir() && int(stat.Uid) == os.Geteuid() && s.Mode().Perm() == 0o700 && unix.Access(m, unix.X_OK|unix.W_OK|unix.R_OK) == nil
}
if tdir := strings.TrimRight(os.Getenv("TMPDIR"), "/"); filepath.Base(tdir) == "T" {
if m := filepath.Join(filepath.Dir(tdir), "C"); is_ok(m) {
return m
}
}
matches, err := filepath.Glob("/private/var/folders/*/*/C")
if err == nil {
for _, m := range matches {
if is_ok(m) {
return m
}
}
}
out, err := exec.Command("/usr/bin/getconf", "DARWIN_USER_CACHE_DIR").Output()
if err == nil {
return strings.TrimRight(strings.TrimSpace(UnsafeBytesToString(out)), "/")
}
return ""
} }
func CacheDir() string { var RuntimeDir = (&Once[string]{Run: func() (runtime_dir string) {
cache_dir_once.Do(find_cache_dir) var candidate string
return cache_dir if q := os.Getenv("KITTY_RUNTIME_DIRECTORY"); q != "" {
} candidate = q
} else if runtime.GOOS == "darwin" {
candidate = macos_user_cache_dir()
} else if q := os.Getenv("XDG_RUNTIME_DIR"); q != "" {
candidate = q
}
candidate = strings.TrimRight(candidate, "/")
if candidate == "" {
q := fmt.Sprintf("/run/user/%d", os.Geteuid())
if s, err := os.Stat(q); err == nil && s.IsDir() && unix.Access(q, unix.X_OK|unix.R_OK|unix.W_OK) == nil {
candidate = q
} else {
candidate = filepath.Join(CacheDir(), "run")
}
}
os.MkdirAll(candidate, 0o700)
if s, err := os.Stat(candidate); err == nil && s.Mode().Perm() != 0o700 {
os.Chmod(candidate, 0o700)
}
return candidate
}}).Get
type Walk_callback func(path, abspath string, d fs.DirEntry, err error) error type Walk_callback func(path, abspath string, d fs.DirEntry, err error) error
@ -205,3 +264,21 @@ func WalkWithSymlink(dirpath string, callback Walk_callback, transformers ...fun
seen: make(map[string]bool), real_callback: callback, transform_func: transform, needs_recurse_func: needs_symlink_recurse} seen: make(map[string]bool), real_callback: callback, transform_func: transform, needs_recurse_func: needs_symlink_recurse}
return sw.walk(dirpath) return sw.walk(dirpath)
} }
func RandomFilename() string {
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
_, err := rand.Read(b)
if err != nil {
return strconv.FormatUint(uint64(not_rand.Uint32()), 16)
}
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b)
}
func ResolveConfPath(path string) string {
cs := os.ExpandEnv(Expanduser(path))
if !filepath.IsAbs(cs) {
cs = filepath.Join(ConfigDir(), cs)
}
return cs
}

View File

@ -0,0 +1,65 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package paths
import (
"fmt"
"os"
"path/filepath"
"strings"
"kitty/tools/utils"
)
var _ = fmt.Print
type Ctx struct {
home, cwd string
}
func (ctx *Ctx) SetHome(val string) {
ctx.home = val
}
func (ctx *Ctx) SetCwd(val string) {
ctx.cwd = val
}
func (ctx *Ctx) HomePath() (ans string) {
ans = ctx.home
if ans == "" {
ans = utils.Expanduser("~")
}
return
}
func (ctx *Ctx) CwdPath() (ans string) {
ans = ctx.cwd
if ans == "" {
var err error
ans, err = os.Getwd()
if err != nil {
ans = "."
}
}
return
}
func abspath(path, base string) (ans string) {
return filepath.Join(base, path)
}
func (ctx *Ctx) Abspath(path string) (ans string) {
return abspath(path, ctx.CwdPath())
}
func (ctx *Ctx) AbspathFromHome(path string) (ans string) {
return abspath(path, ctx.HomePath())
}
func (ctx *Ctx) ExpandHome(path string) (ans string) {
if strings.HasPrefix(path, "~/") {
return ctx.AbspathFromHome(path)
}
return path
}

View File

@ -0,0 +1,42 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package secrets
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
)
var _ = fmt.Print
const DEFAULT_NUM_OF_BYTES_FOR_TOKEN = 32
func TokenBytes(nbytes ...int) ([]byte, error) {
if len(nbytes) == 0 {
nbytes = []int{DEFAULT_NUM_OF_BYTES_FOR_TOKEN}
}
buf := make([]byte, nbytes[0])
_, err := rand.Read(buf)
if err != nil {
return nil, err
}
return buf, nil
}
func TokenHex(nbytes ...int) (string, error) {
b, err := TokenBytes(nbytes...)
if err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
func TokenBase64(nbytes ...int) (string, error) {
b, err := TokenBytes(nbytes...)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(b), nil
}

View File

@ -4,6 +4,8 @@ package utils
import ( import (
"fmt" "fmt"
"golang.org/x/exp/maps"
) )
var _ = fmt.Print var _ = fmt.Print
@ -22,6 +24,10 @@ func (self *Set[T]) AddItems(val ...T) {
} }
} }
func (self *Set[T]) String() string {
return fmt.Sprintf("%#v", maps.Keys(self.items))
}
func (self *Set[T]) Remove(val T) { func (self *Set[T]) Remove(val T) {
delete(self.items, val) delete(self.items, val)
} }
@ -68,6 +74,25 @@ func (self *Set[T]) Intersect(other *Set[T]) (ans *Set[T]) {
return return
} }
func (self *Set[T]) Subtract(other *Set[T]) (ans *Set[T]) {
ans = NewSet[T](self.Len())
for x := range self.items {
if !other.Has(x) {
ans.items[x] = struct{}{}
}
}
return ans
}
func (self *Set[T]) IsSubsetOf(other *Set[T]) bool {
for x := range self.items {
if !other.Has(x) {
return false
}
}
return true
}
func NewSet[T comparable](capacity ...int) (ans *Set[T]) { func NewSet[T comparable](capacity ...int) (ans *Set[T]) {
if len(capacity) == 0 { if len(capacity) == 0 {
ans = &Set[T]{items: make(map[T]struct{}, 8)} ans = &Set[T]{items: make(map[T]struct{}, 8)}
@ -76,3 +101,9 @@ func NewSet[T comparable](capacity ...int) (ans *Set[T]) {
} }
return return
} }
func NewSetWithItems[T comparable](items ...T) (ans *Set[T]) {
ans = NewSet[T](len(items))
ans.AddItems(items...)
return ans
}

View File

@ -3,15 +3,16 @@
package shm package shm
import ( import (
"crypto/rand" "encoding/binary"
"encoding/base32"
"errors" "errors"
"fmt" "fmt"
not_rand "math/rand" "io"
"io/fs"
"os" "os"
"strconv"
"strings" "strings"
"kitty/tools/cli"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -43,15 +44,6 @@ func prefix_and_suffix(pattern string) (prefix, suffix string, err error) {
return prefix, suffix, nil return prefix, suffix, nil
} }
func next_random() string {
b := make([]byte, 8)
_, err := rand.Read(b)
if err != nil {
return strconv.FormatUint(uint64(not_rand.Uint32()), 16)
}
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b)
}
type MMap interface { type MMap interface {
Close() error Close() error
Unlink() error Unlink() error
@ -59,6 +51,13 @@ type MMap interface {
Name() string Name() string
IsFileSystemBacked() bool IsFileSystemBacked() bool
FileSystemName() string FileSystemName() string
Stat() (fs.FileInfo, error)
Flush() error
Seek(offset int64, whence int) (int64, error)
Read(b []byte) (int, error)
ReadWithSize() ([]byte, error)
Write(p []byte) (n int, err error)
WriteWithSize([]byte) error
} }
type AccessFlags int type AccessFlags int
@ -109,3 +108,78 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) {
} }
return return
} }
func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) {
p := buf
for len(p) > 0 {
n, err := f.Read(p)
p = p[n:]
if err != nil {
if len(p) == 0 && errors.Is(err, io.EOF) {
err = nil
}
return buf[:len(buf)-len(p)], err
}
}
return buf, nil
}
func read_with_size(f *os.File) ([]byte, error) {
szbuf := []byte{0, 0, 0, 0}
szbuf, err := read_till_buf_full(f, szbuf)
if err != nil {
return nil, err
}
size := int(binary.BigEndian.Uint32(szbuf))
return read_till_buf_full(f, make([]byte, size))
}
func write_with_size(f *os.File, b []byte) error {
szbuf := []byte{0, 0, 0, 0}
binary.BigEndian.PutUint32(szbuf, uint32(len(b)))
_, err := f.Write(szbuf)
if err == nil {
_, err = f.Write(b)
}
return err
}
func test_integration_with_python(args []string) (rc int, err error) {
switch args[0] {
default:
return 1, fmt.Errorf("Unknown test type: %s", args[0])
case "read":
data, err := ReadWithSizeAndUnlink(args[1])
if err != nil {
return 1, err
}
_, err = os.Stdout.Write(data)
if err != nil {
return 1, err
}
case "write":
data, err := io.ReadAll(os.Stdin)
if err != nil {
return 1, err
}
mmap, err := CreateTemp("shmtest-", uint64(len(data)+4))
if err != nil {
return 1, err
}
mmap.WriteWithSize(data)
mmap.Close()
fmt.Println(mmap.Name())
}
return 0, nil
}
func TestEntryPoint(root *cli.Command) {
root.AddSubCommand(&cli.Command{
Name: "shm",
OnlyArgsAllowed: true,
Run: func(cmd *cli.Command, args []string) (rc int, err error) {
return test_integration_with_python(args)
},
})
}

View File

@ -8,10 +8,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"kitty/tools/utils"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"kitty/tools/utils"
"golang.org/x/sys/unix"
) )
var _ = fmt.Print var _ = fmt.Print
@ -39,6 +42,10 @@ func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, speci
return &file_based_mmap{f: f, region: region, special_name: special_name}, nil return &file_based_mmap{f: f, region: region, special_name: special_name}, nil
} }
func (self *file_based_mmap) Stat() (fs.FileInfo, error) {
return self.f.Stat()
}
func (self *file_based_mmap) Name() string { func (self *file_based_mmap) Name() string {
if self.special_name != "" { if self.special_name != "" {
return self.special_name return self.special_name
@ -46,6 +53,30 @@ func (self *file_based_mmap) Name() string {
return filepath.Base(self.f.Name()) return filepath.Base(self.f.Name())
} }
func (self *file_based_mmap) Flush() error {
return unix.Msync(self.region, unix.MS_SYNC)
}
func (self *file_based_mmap) Seek(offset int64, whence int) (int64, error) {
return self.f.Seek(offset, whence)
}
func (self *file_based_mmap) Read(b []byte) (int, error) {
return self.f.Read(b)
}
func (self *file_based_mmap) Write(b []byte) (int, error) {
return self.f.Write(b)
}
func (self *file_based_mmap) WriteWithSize(b []byte) error {
return write_with_size(self.f, b)
}
func (self *file_based_mmap) ReadWithSize() ([]byte, error) {
return read_with_size(self.f)
}
func (self *file_based_mmap) FileSystemName() string { func (self *file_based_mmap) FileSystemName() string {
return self.f.Name() return self.f.Name()
} }
@ -92,7 +123,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
var f *os.File var f *os.File
try := 0 try := 0
for { for {
name := prefix + next_random() + suffix name := prefix + utils.RandomFilename() + suffix
path := file_path_from_name(name) path := file_path_from_name(name)
f, err = os.OpenFile(path, os.O_EXCL|os.O_CREATE|os.O_RDWR, 0600) f, err = os.OpenFile(path, os.O_EXCL|os.O_CREATE|os.O_RDWR, 0600)
if err != nil { if err != nil {
@ -113,7 +144,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
return file_mmap(f, size, WRITE, true, special_name) return file_mmap(f, size, WRITE, true, special_name)
} }
func Open(name string, size uint64) (MMap, error) { func open(name string) (*os.File, error) {
ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0) ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
@ -123,5 +154,29 @@ func Open(name string, size uint64) (MMap, error) {
} }
return nil, err return nil, err
} }
return ans, nil
}
func Open(name string, size uint64) (MMap, error) {
ans, err := open(name)
if err != nil {
return nil, err
}
return file_mmap(ans, size, READ, false, name) return file_mmap(ans, size, READ, false, name)
} }
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
f, err := open(name)
if err != nil {
return nil, err
}
defer f.Close()
defer os.Remove(f.Name())
for _, cb := range file_callback {
err = cb(f)
if err != nil {
return nil, err
}
}
return read_with_size(f)
}

View File

@ -11,6 +11,8 @@ import (
"strings" "strings"
"unsafe" "unsafe"
"kitty/tools/utils"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -86,11 +88,38 @@ func syscall_mmap(f *os.File, size uint64, access AccessFlags, truncate bool) (M
func (self *syscall_based_mmap) Name() string { func (self *syscall_based_mmap) Name() string {
return self.f.Name() return self.f.Name()
} }
func (self *syscall_based_mmap) Stat() (fs.FileInfo, error) {
return self.f.Stat()
}
func (self *syscall_based_mmap) Flush() error {
return unix.Msync(self.region, unix.MS_SYNC)
}
func (self *syscall_based_mmap) Slice() []byte { func (self *syscall_based_mmap) Slice() []byte {
return self.region return self.region
} }
func (self *syscall_based_mmap) Seek(offset int64, whence int) (int64, error) {
return self.f.Seek(offset, whence)
}
func (self *syscall_based_mmap) Read(b []byte) (int, error) {
return self.f.Read(b)
}
func (self *syscall_based_mmap) Write(b []byte) (int, error) {
return self.f.Write(b)
}
func (self *syscall_based_mmap) WriteWithSize(b []byte) error {
return write_with_size(self.f, b)
}
func (self *syscall_based_mmap) ReadWithSize() ([]byte, error) {
return read_with_size(self.f)
}
func (self *syscall_based_mmap) Close() (err error) { func (self *syscall_based_mmap) Close() (err error) {
if self.region != nil { if self.region != nil {
self.f.Close() self.f.Close()
@ -124,7 +153,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
var f *os.File var f *os.File
try := 0 try := 0
for { for {
name := prefix + next_random() + suffix name := prefix + utils.RandomFilename() + suffix
if len(name) > SHM_NAME_MAX { if len(name) > SHM_NAME_MAX {
return nil, ErrPatternTooLong return nil, ErrPatternTooLong
} }
@ -151,3 +180,19 @@ func Open(name string, size uint64) (MMap, error) {
} }
return syscall_mmap(ans, size, READ, false) return syscall_mmap(ans, size, READ, false)
} }
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
f, err := shm_open(name, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
defer f.Close()
defer shm_unlink(f.Name())
for _, cb := range file_callback {
err = cb(f)
if err != nil {
return nil, err
}
}
return read_with_size(f)
}

View File

@ -23,6 +23,10 @@ func TestSHM(t *testing.T) {
} }
copy(mm.Slice(), data) copy(mm.Slice(), data)
err = mm.Flush()
if err != nil {
t.Fatalf("Failed to msync() with error: %v", err)
}
err = mm.Close() err = mm.Close()
if err != nil { if err != nil {
t.Fatalf("Failed to close with error: %v", err) t.Fatalf("Failed to close with error: %v", err)

View File

@ -57,6 +57,10 @@ type RGBA struct {
Red, Green, Blue, Inverse_alpha uint8 Red, Green, Blue, Inverse_alpha uint8
} }
func (self RGBA) AsRGBSharp() string {
return fmt.Sprintf("#%02x%02x%02x", self.Red, self.Green, self.Blue)
}
func (self *RGBA) parse_rgb_strings(r string, g string, b string) bool { func (self *RGBA) parse_rgb_strings(r string, g string, b string) bool {
var rv, gv, bv uint64 var rv, gv, bv uint64
var err error var err error
@ -77,6 +81,12 @@ func (self *RGBA) AsRGB() uint32 {
return uint32(self.Blue) | (uint32(self.Green) << 8) | (uint32(self.Red) << 16) return uint32(self.Blue) | (uint32(self.Green) << 8) | (uint32(self.Red) << 16)
} }
func (self *RGBA) FromRGB(col uint32) {
self.Red = uint8((col >> 16) & 0xff)
self.Green = uint8((col >> 8) & 0xff)
self.Blue = uint8((col) & 0xff)
}
type color_type struct { type color_type struct {
is_numbered bool is_numbered bool
val RGBA val RGBA

View File

@ -13,15 +13,18 @@ import (
var _ = fmt.Print var _ = fmt.Print
func Which(cmd string) string { func Which(cmd string, paths ...string) string {
if strings.Contains(cmd, string(os.PathSeparator)) { if strings.Contains(cmd, string(os.PathSeparator)) {
return "" return ""
} }
path := os.Getenv("PATH") if len(paths) == 0 {
if path == "" { path := os.Getenv("PATH")
return "" if path == "" {
return ""
}
paths = strings.Split(path, string(os.PathListSeparator))
} }
for _, dir := range strings.Split(path, string(os.PathListSeparator)) { for _, dir := range paths {
q := filepath.Join(dir, cmd) q := filepath.Join(dir, cmd)
if unix.Access(q, unix.X_OK) == nil { if unix.Access(q, unix.X_OK) == nil {
s, err := os.Stat(q) s, err := os.Stat(q)