Add basic tests for zsh shell integration

This commit is contained in:
Kovid Goyal 2022-02-21 17:57:25 +05:30
parent 595698d8e9
commit c9cc832875
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 109 additions and 22 deletions

View File

@ -74,7 +74,8 @@ def is_new_zsh_install(env: Dict[str, str]) -> bool:
# the latter will bail if there are rc files in $HOME # the latter will bail if there are rc files in $HOME
zdotdir = env.get('ZDOTDIR') zdotdir = env.get('ZDOTDIR')
if not zdotdir: if not zdotdir:
zdotdir = os.path.expanduser('~') zdotdir = env.get('HOME', os.path.expanduser('~'))
assert isinstance(zdotdir, str)
if zdotdir == '~': if zdotdir == '~':
return True return True
for q in ('.zshrc', '.zshenv', '.zprofile', '.zlogin'): for q in ('.zshrc', '.zshenv', '.zprofile', '.zlogin'):

View File

@ -148,6 +148,19 @@ class ScreenSize(NamedTuple):
cell_height: int cell_height: int
def read_screen_size(fd: int = -1) -> ScreenSize:
import array
import fcntl
import termios
buf = array.array('H', [0, 0, 0, 0])
if fd < 0:
fd = sys.stdout.fileno()
fcntl.ioctl(fd, termios.TIOCGWINSZ, cast(bytearray, buf))
rows, cols, width, height = tuple(buf)
cell_width, cell_height = width // (cols or 1), height // (rows or 1)
return ScreenSize(rows, cols, width, height, cell_width, cell_height)
class ScreenSizeGetter: class ScreenSizeGetter:
changed = True changed = True
Size = ScreenSize Size = ScreenSize
@ -160,14 +173,7 @@ class ScreenSizeGetter:
def __call__(self) -> ScreenSize: def __call__(self) -> ScreenSize:
if self.changed: if self.changed:
import array self.ans = read_screen_size()
import fcntl
import termios
buf = array.array('H', [0, 0, 0, 0])
fcntl.ioctl(self.fd, termios.TIOCGWINSZ, cast(bytearray, buf))
rows, cols, width, height = tuple(buf)
cell_width, cell_height = width // (cols or 1), height // (rows or 1)
self.ans = ScreenSize(rows, cols, width, height, cell_width, cell_height)
self.changed = False self.changed = False
return cast(ScreenSize, self.ans) return cast(ScreenSize, self.ans)

View File

@ -7,8 +7,8 @@ import os
import select import select
import shlex import shlex
import struct import struct
import sys
import termios import termios
import time
from pty import CHILD, fork from pty import CHILD, fork
from unittest import TestCase from unittest import TestCase
@ -19,7 +19,7 @@ from kitty.fast_data_types import (
from kitty.options.parse import merge_result_dicts from kitty.options.parse import merge_result_dicts
from kitty.options.types import Options, defaults from kitty.options.types import Options, defaults
from kitty.types import MouseEvent from kitty.types import MouseEvent
from kitty.utils import no_echo, write_all from kitty.utils import read_screen_size, write_all
class Callbacks: class Callbacks:
@ -139,9 +139,9 @@ class BaseTest(TestCase):
s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c) s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c)
return s return s
def create_pty(self, argv, cols=80, lines=25, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None): def create_pty(self, argv, cols=80, lines=25, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None):
self.set_options(options) self.set_options(options)
return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd) return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env)
def assertEqualAttributes(self, c1, c2): def assertEqualAttributes(self, c1, c2):
x1, y1, c1.x, c1.y = c1.x, c1.y, 0, 0 x1, y1, c1.x, c1.y = c1.x, c1.y, 0, 0
@ -154,23 +154,27 @@ class BaseTest(TestCase):
class PTY: class PTY:
def __init__(self, argv, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None): def __init__(self, argv, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None, env=None):
pid, self.master_fd = fork() pid, self.master_fd = fork()
self.is_child = pid == CHILD self.is_child = pid == CHILD
if self.is_child: if self.is_child:
while read_screen_size().width != columns * cell_width:
time.sleep(0.01)
if cwd: if cwd:
os.chdir(cwd) os.chdir(cwd)
if env:
os.environ.clear()
os.environ.update(env)
if isinstance(argv, str): if isinstance(argv, str):
argv = shlex.split(argv) argv = shlex.split(argv)
with no_echo():
sys.stdin.readline()
os.execlp(argv[0], *argv) os.execlp(argv[0], *argv)
os.set_blocking(self.master_fd, False) os.set_blocking(self.master_fd, False)
self.cell_width = cell_width
self.cell_height = cell_height
self.set_window_size(rows=rows, columns=columns) self.set_window_size(rows=rows, columns=columns)
new = termios.tcgetattr(self.master_fd) new = termios.tcgetattr(self.master_fd)
new[3] = new[3] & ~termios.ECHO new[3] = new[3] & ~termios.ECHO
termios.tcsetattr(self.master_fd, termios.TCSADRAIN, new) termios.tcsetattr(self.master_fd, termios.TCSADRAIN, new)
self.write_to_child('ready\r\n')
self.callbacks = Callbacks() self.callbacks = Callbacks()
self.screen = Screen(self.callbacks, rows, columns, scrollback, cell_width, cell_height, 0, self.callbacks) self.screen = Screen(self.callbacks, rows, columns, scrollback, cell_width, cell_height, 0, self.callbacks)
@ -186,7 +190,11 @@ class PTY:
rd = select.select([self.master_fd], [], [], timeout)[0] rd = select.select([self.master_fd], [], [], timeout)[0]
return bool(rd) return bool(rd)
def process_input_from_child(self): def send_cmd_to_child(self, cmd):
self.write_to_child(cmd + '\r')
def process_input_from_child(self, timeout=10):
self.wait_for_input_from_child(timeout=10)
bytes_read = 0 bytes_read = 0
while True: while True:
try: try:
@ -199,7 +207,16 @@ class PTY:
parse_bytes(self.screen, data) parse_bytes(self.screen, data)
return bytes_read return bytes_read
def set_window_size(self, rows=25, columns=80, x_pixels=0, y_pixels=0): def wait_till(self, q, timeout=10):
st = time.monotonic()
while not q() and time.monotonic() - st < timeout:
self.process_input_from_child(timeout=timeout - (time.monotonic() - st))
if not q():
raise TimeoutError('The condition was not met')
def set_window_size(self, rows=25, columns=80):
x_pixels = columns * self.cell_width
y_pixels = rows * self.cell_height
s = struct.pack('HHHH', rows, columns, x_pixels, y_pixels) s = struct.pack('HHHH', rows, columns, x_pixels, y_pixels)
fcntl.ioctl(self.master_fd, termios.TIOCSWINSZ, s) fcntl.ioctl(self.master_fd, termios.TIOCSWINSZ, s)
@ -210,3 +227,9 @@ class PTY:
if x: if x:
lines.append(x) lines.append(x)
return '\n'.join(lines) return '\n'.join(lines)
def last_cmd_output(self, as_ansi=False, add_wrap_markers=False):
lines = []
from kitty.window import CommandOutput
self.screen.cmd_output(CommandOutput.last_run, lines.append, as_ansi, add_wrap_markers)
return ''.join(lines)

View File

@ -0,0 +1,59 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import os
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from kitty.constants import terminfo_dir
from kitty.fast_data_types import CURSOR_BEAM
from kitty.shell_integration import setup_zsh_env
from . import BaseTest
def safe_env_for_running_shell(home_dir, rc='', shell='zsh'):
ans = {
'PATH': os.environ['PATH'],
'HOME': home_dir,
'TERM': 'xterm-kitty',
'TERMINFO': terminfo_dir,
'KITTY_SHELL_INTEGRATION': 'enabled',
}
if shell == 'zsh':
ans['ZLE_RPROMPT_INDENT'] = '0'
with open(os.path.join(home_dir, '.zshenv'), 'w') as f:
print('unset GLOBAL_RCS', file=f)
with open(os.path.join(home_dir, '.zshrc'), 'w') as f:
print(rc, file=f)
setup_zsh_env(ans)
return ans
class ShellIntegration(BaseTest):
@contextmanager
def run_shell(self, shell='zsh', rc=''):
with TemporaryDirectory() as home_dir:
pty = self.create_pty(f'{shell} -il', cwd=home_dir, env=safe_env_for_running_shell(home_dir, rc))
i = 10
while i > 0 and not pty.screen_contents().strip():
pty.process_input_from_child()
i -= 1
yield pty
def test_zsh_integration(self):
ps1, rps1 = 'left>', '<right'
with self.run_shell(
rc=f'''
PS1="{ps1}"
RPS1="{rps1}"
''') as pty:
self.ae(pty.callbacks.titlebuf, '~')
q = ps1 + ' ' * (pty.screen.columns - len(ps1) - len(rps1)) + rps1
self.ae(pty.screen_contents(), q)
pty.wait_till(lambda: pty.screen.cursor.shape == CURSOR_BEAM)
pty.send_cmd_to_child('mkdir test && ls -a')
pty.wait_till(lambda: pty.screen_contents().count('left>') == 2)
self.ae(pty.last_cmd_output(), str(pty.screen.line(1)))

View File

@ -14,7 +14,6 @@ class SSHTest(BaseTest):
def test_basic_pty_operations(self): def test_basic_pty_operations(self):
pty = self.create_pty('echo hello') pty = self.create_pty('echo hello')
self.assertTrue(pty.wait_for_input_from_child())
pty.process_input_from_child() pty.process_input_from_child()
self.ae(pty.screen_contents(), 'hello') self.ae(pty.screen_contents(), 'hello')
pty = self.create_pty(self.cmd_to_run_python_code('''\ pty = self.create_pty(self.cmd_to_run_python_code('''\
@ -22,9 +21,8 @@ import array, fcntl, sys, termios
buf = array.array('H', [0, 0, 0, 0]) buf = array.array('H', [0, 0, 0, 0])
fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, buf) fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, buf)
print(' '.join(map(str, buf)))'''), lines=13, cols=77) print(' '.join(map(str, buf)))'''), lines=13, cols=77)
self.assertTrue(pty.wait_for_input_from_child())
pty.process_input_from_child() pty.process_input_from_child()
self.ae(pty.screen_contents(), '13 77 0 0') self.ae(pty.screen_contents(), '13 77 770 260')
def test_ssh_connection_data(self): def test_ssh_connection_data(self):
def t(cmdline, binary='ssh', host='main', port=None, identity_file=''): def t(cmdline, binary='ssh', host='main', port=None, identity_file=''):