From c9cc8328750d784c1ae66f0642621c3eee1acb5f Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 21 Feb 2022 17:57:25 +0530 Subject: [PATCH] Add basic tests for zsh shell integration --- kitty/shell_integration.py | 3 +- kitty/utils.py | 22 +++++++----- kitty_tests/__init__.py | 43 +++++++++++++++++------ kitty_tests/shell_integration.py | 59 ++++++++++++++++++++++++++++++++ kitty_tests/ssh.py | 4 +-- 5 files changed, 109 insertions(+), 22 deletions(-) create mode 100644 kitty_tests/shell_integration.py diff --git a/kitty/shell_integration.py b/kitty/shell_integration.py index b3de5186f..dc4c6c903 100644 --- a/kitty/shell_integration.py +++ b/kitty/shell_integration.py @@ -74,7 +74,8 @@ def is_new_zsh_install(env: Dict[str, str]) -> bool: # the latter will bail if there are rc files in $HOME zdotdir = env.get('ZDOTDIR') if not zdotdir: - zdotdir = os.path.expanduser('~') + zdotdir = env.get('HOME', os.path.expanduser('~')) + assert isinstance(zdotdir, str) if zdotdir == '~': return True for q in ('.zshrc', '.zshenv', '.zprofile', '.zlogin'): diff --git a/kitty/utils.py b/kitty/utils.py index 78649162b..9d67fb82c 100644 --- a/kitty/utils.py +++ b/kitty/utils.py @@ -148,6 +148,19 @@ class ScreenSize(NamedTuple): cell_height: int +def read_screen_size(fd: int = -1) -> ScreenSize: + import array + import fcntl + import termios + buf = array.array('H', [0, 0, 0, 0]) + if fd < 0: + fd = sys.stdout.fileno() + fcntl.ioctl(fd, termios.TIOCGWINSZ, cast(bytearray, buf)) + rows, cols, width, height = tuple(buf) + cell_width, cell_height = width // (cols or 1), height // (rows or 1) + return ScreenSize(rows, cols, width, height, cell_width, cell_height) + + class ScreenSizeGetter: changed = True Size = ScreenSize @@ -160,14 +173,7 @@ class ScreenSizeGetter: def __call__(self) -> ScreenSize: if self.changed: - import array - import fcntl - import termios - buf = array.array('H', [0, 0, 0, 0]) - fcntl.ioctl(self.fd, termios.TIOCGWINSZ, cast(bytearray, buf)) - rows, cols, width, height = tuple(buf) - cell_width, cell_height = width // (cols or 1), height // (rows or 1) - self.ans = ScreenSize(rows, cols, width, height, cell_width, cell_height) + self.ans = read_screen_size() self.changed = False return cast(ScreenSize, self.ans) diff --git a/kitty_tests/__init__.py b/kitty_tests/__init__.py index b64148377..8d42561a0 100644 --- a/kitty_tests/__init__.py +++ b/kitty_tests/__init__.py @@ -7,8 +7,8 @@ import os import select import shlex import struct -import sys import termios +import time from pty import CHILD, fork from unittest import TestCase @@ -19,7 +19,7 @@ from kitty.fast_data_types import ( from kitty.options.parse import merge_result_dicts from kitty.options.types import Options, defaults from kitty.types import MouseEvent -from kitty.utils import no_echo, write_all +from kitty.utils import read_screen_size, write_all class Callbacks: @@ -139,9 +139,9 @@ class BaseTest(TestCase): s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c) return s - def create_pty(self, argv, cols=80, lines=25, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None): + def create_pty(self, argv, cols=80, lines=25, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None): self.set_options(options) - return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd) + return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env) def assertEqualAttributes(self, c1, c2): x1, y1, c1.x, c1.y = c1.x, c1.y, 0, 0 @@ -154,23 +154,27 @@ class BaseTest(TestCase): class PTY: - def __init__(self, argv, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None): + def __init__(self, argv, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None, env=None): pid, self.master_fd = fork() self.is_child = pid == CHILD if self.is_child: + while read_screen_size().width != columns * cell_width: + time.sleep(0.01) if cwd: os.chdir(cwd) + if env: + os.environ.clear() + os.environ.update(env) if isinstance(argv, str): argv = shlex.split(argv) - with no_echo(): - sys.stdin.readline() os.execlp(argv[0], *argv) os.set_blocking(self.master_fd, False) + self.cell_width = cell_width + self.cell_height = cell_height self.set_window_size(rows=rows, columns=columns) new = termios.tcgetattr(self.master_fd) new[3] = new[3] & ~termios.ECHO termios.tcsetattr(self.master_fd, termios.TCSADRAIN, new) - self.write_to_child('ready\r\n') self.callbacks = Callbacks() self.screen = Screen(self.callbacks, rows, columns, scrollback, cell_width, cell_height, 0, self.callbacks) @@ -186,7 +190,11 @@ class PTY: rd = select.select([self.master_fd], [], [], timeout)[0] return bool(rd) - def process_input_from_child(self): + def send_cmd_to_child(self, cmd): + self.write_to_child(cmd + '\r') + + def process_input_from_child(self, timeout=10): + self.wait_for_input_from_child(timeout=10) bytes_read = 0 while True: try: @@ -199,7 +207,16 @@ class PTY: parse_bytes(self.screen, data) return bytes_read - def set_window_size(self, rows=25, columns=80, x_pixels=0, y_pixels=0): + def wait_till(self, q, timeout=10): + st = time.monotonic() + while not q() and time.monotonic() - st < timeout: + self.process_input_from_child(timeout=timeout - (time.monotonic() - st)) + if not q(): + raise TimeoutError('The condition was not met') + + def set_window_size(self, rows=25, columns=80): + x_pixels = columns * self.cell_width + y_pixels = rows * self.cell_height s = struct.pack('HHHH', rows, columns, x_pixels, y_pixels) fcntl.ioctl(self.master_fd, termios.TIOCSWINSZ, s) @@ -210,3 +227,9 @@ class PTY: if x: lines.append(x) return '\n'.join(lines) + + def last_cmd_output(self, as_ansi=False, add_wrap_markers=False): + lines = [] + from kitty.window import CommandOutput + self.screen.cmd_output(CommandOutput.last_run, lines.append, as_ansi, add_wrap_markers) + return ''.join(lines) diff --git a/kitty_tests/shell_integration.py b/kitty_tests/shell_integration.py new file mode 100644 index 000000000..5d5346536 --- /dev/null +++ b/kitty_tests/shell_integration.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# License: GPLv3 Copyright: 2022, Kovid Goyal + + +import os +from contextlib import contextmanager +from tempfile import TemporaryDirectory + +from kitty.constants import terminfo_dir +from kitty.fast_data_types import CURSOR_BEAM +from kitty.shell_integration import setup_zsh_env + +from . import BaseTest + + +def safe_env_for_running_shell(home_dir, rc='', shell='zsh'): + ans = { + 'PATH': os.environ['PATH'], + 'HOME': home_dir, + 'TERM': 'xterm-kitty', + 'TERMINFO': terminfo_dir, + 'KITTY_SHELL_INTEGRATION': 'enabled', + } + if shell == 'zsh': + ans['ZLE_RPROMPT_INDENT'] = '0' + with open(os.path.join(home_dir, '.zshenv'), 'w') as f: + print('unset GLOBAL_RCS', file=f) + with open(os.path.join(home_dir, '.zshrc'), 'w') as f: + print(rc, file=f) + setup_zsh_env(ans) + return ans + + +class ShellIntegration(BaseTest): + + @contextmanager + def run_shell(self, shell='zsh', rc=''): + with TemporaryDirectory() as home_dir: + pty = self.create_pty(f'{shell} -il', cwd=home_dir, env=safe_env_for_running_shell(home_dir, rc)) + i = 10 + while i > 0 and not pty.screen_contents().strip(): + pty.process_input_from_child() + i -= 1 + yield pty + + def test_zsh_integration(self): + ps1, rps1 = 'left>', '') == 2) + self.ae(pty.last_cmd_output(), str(pty.screen.line(1))) diff --git a/kitty_tests/ssh.py b/kitty_tests/ssh.py index c5730116a..b5aa0cef4 100644 --- a/kitty_tests/ssh.py +++ b/kitty_tests/ssh.py @@ -14,7 +14,6 @@ class SSHTest(BaseTest): def test_basic_pty_operations(self): pty = self.create_pty('echo hello') - self.assertTrue(pty.wait_for_input_from_child()) pty.process_input_from_child() self.ae(pty.screen_contents(), 'hello') pty = self.create_pty(self.cmd_to_run_python_code('''\ @@ -22,9 +21,8 @@ import array, fcntl, sys, termios buf = array.array('H', [0, 0, 0, 0]) fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, buf) print(' '.join(map(str, buf)))'''), lines=13, cols=77) - self.assertTrue(pty.wait_for_input_from_child()) pty.process_input_from_child() - self.ae(pty.screen_contents(), '13 77 0 0') + self.ae(pty.screen_contents(), '13 77 770 260') def test_ssh_connection_data(self): def t(cmdline, binary='ssh', host='main', port=None, identity_file=''):