Implement setting of env vars

This commit is contained in:
Kovid Goyal 2022-02-27 14:47:17 +05:30
parent c6f37afeff
commit ae6665493a
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 95 additions and 44 deletions

View File

@ -3,6 +3,7 @@
import atexit
import io
import json
import os
import re
import shlex
@ -17,15 +18,36 @@ from typing import (
)
from kitty.constants import cache_dir, shell_integration_dir, terminfo_dir
from kitty.fast_data_types import get_options
from kitty.short_uuid import uuid4
from kitty.types import run_once
from kitty.utils import SSHConnectionData
from .completion import complete, ssh_options
from .options.types import Options as SSHOptions
from .options.utils import DELETE_ENV_VAR
def make_tarfile(ssh_opts: SSHOptions) -> bytes:
def serialize_env(env: Dict[str, str], base_env: Dict[str, str]) -> bytes:
lines = []
def a(k: str, val: str) -> None:
lines.append(f'export {k}={shlex.quote(val)}')
for k in sorted(env):
v = env[k]
if v is DELETE_ENV_VAR:
lines.append(f'unset {shlex.quote(k)}')
elif v == '_kitty_copy_env_var_':
q = base_env.get(k)
if q is not None:
a(k, q)
else:
a(k, v)
return '\n'.join(lines).encode('utf-8')
def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str]) -> bytes:
def normalize_tarinfo(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo:
tarinfo.uname = tarinfo.gname = 'kitty'
@ -48,9 +70,6 @@ def make_tarfile(ssh_opts: SSHOptions) -> bytes:
return None
return normalize_tarinfo(tarinfo)
buf = io.BytesIO()
with tarfile.open(mode='w:bz2', fileobj=buf, encoding='utf-8') as tf:
rd = ssh_opts.remote_dir.rstrip('/')
from kitty.shell_integration import get_effective_ksi_env_var
if ssh_opts.shell_integration == 'inherit':
ksi = get_effective_ksi_env_var()
@ -58,9 +77,24 @@ def make_tarfile(ssh_opts: SSHOptions) -> bytes:
from kitty.options.types import Options
from kitty.options.utils import shell_integration
ksi = get_effective_ksi_env_var(Options({'shell_integration': shell_integration(ssh_opts.shell_integration)}))
env = {
'TERM': get_options().term,
'COLORTERM': 'truecolor',
}
for q in ('KITTY_WINDOW_ID', 'WINDOWID'):
val = os.environ.get(q)
if val is not None:
env[q] = val
env.update(ssh_opts.env)
env['KITTY_SHELL_INTEGRATION'] = ksi or DELETE_ENV_VAR
env_script = serialize_env(env, base_env)
buf = io.BytesIO()
with tarfile.open(mode='w:bz2', fileobj=buf, encoding='utf-8') as tf:
rd = ssh_opts.remote_dir.rstrip('/')
add_data_as_file(tf, rd + '/settings/env-vars.sh', env_script)
if ksi:
tf.add(shell_integration_dir, arcname=rd + '/shell-integration', filter=filter_files)
add_data_as_file(tf, rd + '/settings/ksi_env_var', ksi)
tf.add(terminfo_dir, arcname='.terminfo', filter=filter_files)
return buf.getvalue()
@ -92,16 +126,19 @@ def get_ssh_data(msg: str, ssh_opts: Optional[Dict[str, SSHOptions]] = None) ->
yield fmt_prefix('!invalid ssh data request message')
else:
try:
with open(os.path.join(cache_dir(), pwfilename)) as f:
with open(os.path.join(cache_dir(), pwfilename), 'rb') as f:
os.unlink(f.name)
if pw != f.read():
env_data = json.load(f)
if pw != env_data['pw']:
raise ValueError('Incorrect password')
except Exception:
import traceback
traceback.print_exc()
yield fmt_prefix('!incorrect ssh data password')
else:
resolved_ssh_opts = options_for_host(hostname, ssh_opts)
try:
data = make_tarfile(resolved_ssh_opts)
data = make_tarfile(resolved_ssh_opts, env_data['env'])
except Exception:
yield fmt_prefix('!error while gathering ssh data')
else:
@ -119,8 +156,9 @@ def safe_remove(x: str) -> None:
def prepare_script(ans: str, replacements: Dict[str, str]) -> str:
pw = uuid4()
with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', dir=cache_dir(), delete=False) as tf:
tf.write(pw.encode('utf-8'))
with tempfile.NamedTemporaryFile(prefix='ssh-kitten-pw-', suffix='.json', dir=cache_dir(), delete=False) as tf:
data = {'pw': pw, 'env': dict(os.environ)}
tf.write(json.dumps(data).encode('utf-8'))
atexit.register(safe_remove, tf.name)
replacements['DATA_PASSWORD'] = pw
replacements['PASSWORD_FILENAME'] = os.path.basename(tf.name)

View File

@ -28,13 +28,13 @@ to SSH to connect to it.
opt('remote_dir', '.local/share/kitty-ssh-kitten', long_text='''
The location on the remote computer where the files needed for this kitten
are installed. The location is relative to the HOME directory.
are installed. The location is relative to the HOME directory for relative paths.
''')
opt('shell_integration', 'inherit', long_text='''
Control the shell integration on the remote host. See ref:`shell_integration`
for details on how this setting works. The special value :code:`inherit` means
use the setting from kitty.conf. This setting is mainly useful for overriding
use the setting from kitty.conf. This setting is useful for overriding
integration on a per-host basis.''')
opt('+env', '', option_type='env', add_to_default=False, long_text='''
@ -47,6 +47,7 @@ environment variables can refer to each other, so if you use::
The value of MYVAR2 will be :code:`a/<path to home directory>/b`. Using
:code:`VAR=` will set it to the empty string and using just :code:`VAR`
will delete the variable from the child process' environment. The definitions
are processed alphabetically.
are processed alphabetically. The special value :code:`_kitty_copy_env_var_`
will cause the value of the variable to be copied from the local machine.
''')
egr() # }}}

View File

@ -97,7 +97,7 @@ class Callbacks:
def handle_remote_ssh(self, msg):
from kittens.ssh.main import get_ssh_data
if self.pty:
for line in get_ssh_data(msg):
for line in get_ssh_data(msg, {'*': self.pty.ssh_opts} if self.pty.ssh_opts else None):
self.pty.write_to_child(line)
def handle_remote_echo(self, msg):
@ -159,9 +159,9 @@ class BaseTest(TestCase):
s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c)
return s
def create_pty(self, argv, cols=80, lines=25, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None):
def create_pty(self, argv, cols=80, lines=25, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None, ssh_opts=None):
self.set_options(options)
return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env)
return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env, ssh_opts)
def assertEqualAttributes(self, c1, c2):
x1, y1, c1.x, c1.y = c1.x, c1.y, 0, 0
@ -174,7 +174,12 @@ class BaseTest(TestCase):
class PTY:
def __init__(self, argv, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None, env=None):
def __init__(self, argv, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None, env=None, ssh_opts=None):
if ssh_opts:
from kittens.ssh.options.types import Options as SSHOptions
self.ssh_opts = SSHOptions(ssh_opts or {})
else:
self.ssh_opts = None
if isinstance(argv, str):
argv = shlex.split(argv)
pid, self.master_fd = fork()

View File

@ -6,9 +6,11 @@ import os
import shlex
import shutil
import tempfile
from functools import lru_cache
from kittens.ssh.config import load_config, options_for_host
from kittens.ssh.main import bootstrap_script, get_connection_data
from kittens.ssh.options.utils import DELETE_ENV_VAR
from kitty.constants import is_macos
from kitty.fast_data_types import CURSOR_BEAM
from kitty.options.utils import shell_integration
@ -65,10 +67,25 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
self.ae(for_host('x', 'env a=').env, {'a': ''})
self.ae(for_host('x', 'env a').env, {'a': '_delete_this_env_var_'})
@property
@lru_cache()
def all_possible_sh(self):
return tuple(sh for sh in ('dash', 'zsh', 'bash', 'posh', 'sh') if shutil.which(sh))
def test_ssh_bootstrap_script(self):
# test setting env vars
with tempfile.TemporaryDirectory() as tdir:
pty = self.check_bootstrap(
'dash', tdir, extra_exec='env; exit 0', SHELL_INTEGRATION_VALUE='',
ssh_opts={'env': {
'TSET': 'set-works',
'COLORTERM': DELETE_ENV_VAR,
}}
)
pty.wait_till(lambda: 'TSET=set-works' in pty.screen_contents())
self.assertNotIn('COLORTERM', pty.screen_contents())
# test handling of data in tty before tarfile is sent
all_possible_sh = tuple(sh for sh in ('dash', 'zsh', 'bash', 'posh', 'sh') if shutil.which(sh))
for sh in all_possible_sh:
for sh in self.all_possible_sh:
with self.subTest(sh=sh), tempfile.TemporaryDirectory() as tdir:
pty = self.check_bootstrap(
sh, tdir, extra_exec='echo "ld:$leading_data"; exit 0',
@ -90,15 +107,15 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
import pwd
expected_login_shell = pwd.getpwuid(os.geteuid()).pw_shell
for m in methods:
for sh in all_possible_sh:
for sh in self.all_possible_sh:
with self.subTest(sh=sh, method=m), tempfile.TemporaryDirectory() as tdir:
pty = self.check_bootstrap(sh, tdir, extra_exec=f'{m}; echo "$login_shell"; exit 0', SHELL_INTEGRATION_VALUE='')
self.assertIn(expected_login_shell, pty.screen_contents())
# check that shell integration works
ok_login_shell = ''
for sh in all_possible_sh:
for login_shell in {'fish', 'zsh', 'bash'} & set(all_possible_sh):
for sh in self.all_possible_sh:
for login_shell in {'fish', 'zsh', 'bash'} & set(self.all_possible_sh):
if login_shell == 'bash' and not bash_ok():
continue
ok_login_shell = login_shell
@ -110,7 +127,7 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
with tempfile.TemporaryDirectory() as tdir:
self.check_bootstrap('sh', tdir, ok_login_shell, val)
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', extra_exec='', pre_data=''):
def check_bootstrap(self, sh, home_dir, login_shell='', SHELL_INTEGRATION_VALUE='enabled', extra_exec='', pre_data='', ssh_opts=None):
script = bootstrap_script(
EXEC_CMD=f'echo "UNTAR_DONE"; {extra_exec}',
OVERRIDE_LOGIN_SHELL=login_shell,
@ -121,7 +138,7 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
# prevent newuser-install from running
open(os.path.join(home_dir, '.zshrc'), 'w').close()
options = {'shell_integration': shell_integration(SHELL_INTEGRATION_VALUE or 'disabled')}
pty = self.create_pty(f'{sh} -c {shlex.quote(script)}', cwd=home_dir, env=env, options=options)
pty = self.create_pty(f'{sh} -c {shlex.quote(script)}', cwd=home_dir, env=env, options=options, ssh_opts=ssh_opts)
if pre_data:
pty.write_buf = pre_data.encode('utf-8')
del script

View File

@ -70,14 +70,7 @@ get_data() {
die "$size"
;;
esac
data_dir=$(read_record)
case "$data_dir" in
("/"*)
;;
(*)
data_dir="$HOME/$data_dir"
;;
esac
data_dir="$HOME/$(read_record)"
# using dd with bs=1 is very slow on Linux, so use head
command head -c "$size" < /dev/tty | untar
rc="$?";
@ -95,7 +88,8 @@ fi
if [ "$rc" != "0" ]; then die "Failed to extract data transmitted by ssh kitten over the TTY device"; fi
[ -f "$HOME/.terminfo/kitty.terminfo" ] || die "Incomplete extraction of ssh data, no kitty.terminfo found";
shell_integration_dir="$data_dir/shell-integration"
shell_integration_settings_file="$data_dir/settings/ksi_env_var"
settings_dir="$data_dir/settings"
env_var_file="$settings_dir/env-vars.sh"
# export TERMINFO
tname=".terminfo"
@ -105,6 +99,9 @@ if [ -e "/usr/share/misc/terminfo.cdb" ]; then
fi
export TERMINFO="$HOME/$tname"
# setup env vars
. "$env_var_file"
# compile terminfo for this system
if [ -x "$(command -v tic)" ]; then
tic_out=$(command tic -x -o "$HOME/$tname" "$HOME/.terminfo/kitty.terminfo" 2>&1)
@ -198,13 +195,6 @@ else
fi
shell_name=$(basename $login_shell)
# read the variable and remove all leading and trailing spaces and collapse multiple spaces using xargs
if [ -f "$shell_integration_settings_file" ]; then
export KITTY_SHELL_INTEGRATION="$(cat $shell_integration_settings_file | xargs echo)"
else
unset KITTY_SHELL_INTEGRATION
fi
exec_bash_with_integration() {
export ENV="$shell_integration_dir/bash/kitty.bash"
export KITTY_BASH_INJECT="1"