diff --git a/kittens/transfer/main.py b/kittens/transfer/main.py index 003720048..e2960df11 100644 --- a/kittens/transfer/main.py +++ b/kittens/transfer/main.py @@ -6,16 +6,23 @@ import os import stat import sys from contextlib import contextmanager +from enum import auto from itertools import count from typing import ( - Dict, Generator, Iterable, Iterator, List, Sequence, Tuple, Union + Callable, Dict, Generator, Iterable, Iterator, List, Sequence, Tuple, + Union, cast ) from kitty.cli import parse_args from kitty.cli_stub import TransferCLIOptions +from kitty.fast_data_types import FILE_TRANSFER_CODE from kitty.file_transmission import ( - Action, Compression, FileTransmissionCommand, FileType + Action, Compression, FileTransmissionCommand, FileType, NameReprEnum, + TransmissionType ) +from kitty.types import run_once + +from ..tui.handler import Handler _cwd = _home = '' @@ -34,6 +41,17 @@ def expand_home(path: str) -> str: return path +@run_once +def short_uuid_func() -> Callable[[], str]: + from kitty.short_uuid import ShortUUID, escape_code_safe_alphabet + return ShortUUID(alphabet=''.join(set(escape_code_safe_alphabet) - {';'})).uuid4 + + +def random_id() -> str: + f = short_uuid_func() + return cast(str, f()) + + @contextmanager def set_paths(cwd: str = '', home: str = '') -> Generator[None, None, None]: global _cwd, _home @@ -93,11 +111,21 @@ def get_remote_path(local_path: str, remote_base: str) -> str: return remote_base +class FileState(NameReprEnum): + waiting_for_start = auto() + waiting_for_data = auto() + transmitting = auto() + finished = auto() + + class File: def __init__( - self, local_path: str, expanded_local_path: str, file_id: int, stat_result: os.stat_result, remote_base: str, file_type: FileType + self, local_path: str, expanded_local_path: str, file_id: int, stat_result: os.stat_result, + remote_base: str, file_type: FileType, ttype: TransmissionType = TransmissionType.simple ) -> None: + self.ttype = ttype + self.state = FileState.waiting_for_start self.local_path = local_path self.expanded_local_path = expanded_local_path self.permissions = stat.S_IMODE(stat_result.st_mode) @@ -112,6 +140,9 @@ class File: self.stat_result = stat_result self.file_type = file_type self.compression = Compression.zlib if self.file_size > 2048 else Compression.none + self.remote_final_path = '' + self.remote_initial_size = -1 + self.err_msg = '' def metadata_command(self) -> FileTransmissionCommand: return FileTransmissionCommand( @@ -222,10 +253,101 @@ def files_for_send(cli_opts: TransferCLIOptions, args: List[str]) -> Tuple[File, return tuple(files) +class SendState(NameReprEnum): + waiting_for_permission = auto() + permission_granted = auto() + permission_denied = auto() + finished = auto() + + class SendManager: - def __init__(self, files: Tuple[File, ...]): + def __init__(self, request_id: str, files: Tuple[File, ...]): self.files = files + self.fid_map = {str(f.file_id): f for f in self.files} + self.request_id = request_id + self.state = SendState.waiting_for_permission + self.all_done = False + self.all_started = False + + def update_collective_statuses(self) -> None: + found_not_started = found_not_done = False + for f in self.files: + if f.state is not FileState.finished: + found_not_done = True + if f.state is FileState.waiting_for_start: + found_not_started = True + if found_not_started and found_not_done: + break + self.all_done = not found_not_done + self.all_started = not found_not_started + + def start_transfer(self) -> str: + return FileTransmissionCommand(action=Action.send).serialize() + + def send_file_metadata(self) -> Iterator[str]: + for f in self.files: + yield f.metadata_command().serialize() + + def on_file_status_update(self, ftc: FileTransmissionCommand) -> None: + file = self.fid_map.get(ftc.file_id) + if file is None: + return + if ftc.status == 'STARTED': + file.state = FileState.waiting_for_data if file.ttype is TransmissionType.rsync else FileState.transmitting + file.remote_final_path = ftc.name + file.remote_initial_size = ftc.size + else: + if ftc.name and not file.remote_final_path: + file.remote_final_path = ftc.name + file.state = FileState.finished + if ftc.status != 'OK': + file.err_msg = ftc.status + self.update_collective_statuses() + + def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> None: + if ftc.action is Action.status: + if ftc.file_id: + self.on_file_status_update(ftc) + else: + self.status = SendState.permission_granted if ftc.status == 'OK' else SendState.permission_denied + + +class Send(Handler): + use_alternate_screen = False + + def __init__(self, manager: SendManager): + Handler.__init__(self) + self.manager = manager + + def send_payload(self, payload: str) -> None: + self.write(f'\x1b]{FILE_TRANSFER_CODE};id={self.manager.request_id};') + self.write(payload) + self.write(b'\x1b\\') + + def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> None: + if ftc.id != self.manager.request_id: + return + before = self.manager.state + self.manager.on_file_transfer_response(ftc) + if before == SendState.waiting_for_permission: + if self.manager.status == SendState.permission_denied: + self.cmd.styled('Permission denied for this transfer', fg='red') + self.quit_loop(1) + return + + def initialize(self) -> None: + self.send_payload(self.manager.start_transfer()) + for payload in self.manager.send_file_metadata(): + self.send_payload(payload) + + def on_interrupt(self) -> None: + self.cmd.styled('Interrupt requested, cancelling transfer, transferred files are in undefined state', fg='red') + self.abort_transfer() + + def abort_transfer(self) -> None: + self.send_payload(FileTransmissionCommand(action=Action.cancel).serialize()) + self.quit_loop(1) def send_main(cli_opts: TransferCLIOptions, args: List[str]) -> None: diff --git a/kittens/tui/handler.py b/kittens/tui/handler.py index 8d3195250..aafe461f8 100644 --- a/kittens/tui/handler.py +++ b/kittens/tui/handler.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: class Handler: image_manager_class: Optional[Type[ImageManagerType]] = None + use_alternate_screen = True def _initialize( self, diff --git a/kittens/tui/loop.py b/kittens/tui/loop.py index 5252db73f..01ab5028c 100644 --- a/kittens/tui/loop.py +++ b/kittens/tui/loop.py @@ -68,20 +68,21 @@ debug = Debug() class TermManager: - def __init__(self, optional_actions: int = termios.TCSANOW) -> None: + def __init__(self, optional_actions: int = termios.TCSANOW, use_alternate_screen: bool = True) -> None: self.extra_finalize: Optional[str] = None self.optional_actions = optional_actions + self.use_alternate_screen = use_alternate_screen def set_state_for_loop(self, set_raw: bool = True) -> None: if set_raw: raw_tty(self.tty_fd, self.original_termios) - write_all(self.tty_fd, init_state()) + write_all(self.tty_fd, init_state(self.use_alternate_screen)) def reset_state_to_original(self) -> None: normal_tty(self.tty_fd, self.original_termios) if self.extra_finalize: write_all(self.tty_fd, self.extra_finalize) - write_all(self.tty_fd, reset_state()) + write_all(self.tty_fd, reset_state(self.use_alternate_screen)) @contextmanager def suspend(self) -> Generator['TermManager', None, None]: @@ -409,7 +410,7 @@ class Loop: handler.on_resize(handler.screen_size) signal_manager = SignalManager(self.asycio_loop, _on_sigwinch, handler.on_interrupt, handler.on_term) - with TermManager(self.optional_actions) as term_manager, signal_manager: + with TermManager(self.optional_actions, handler.use_alternate_screen) as term_manager, signal_manager: self._get_screen_size: ScreenSizeGetter = screen_size_function(term_manager.tty_fd) image_manager = None if handler.image_manager_class is not None: diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index 9fdff48c7..5265394cd 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -68,7 +68,8 @@ class TransmissionError(Exception): msg: str = 'Generic error', transmit: bool = True, file_id: str = '', - name: str = '' + name: str = '', + size: int = -1 ) -> None: Exception.__init__(self, msg) self.transmit = transmit @@ -76,13 +77,14 @@ class TransmissionError(Exception): self.human_msg = msg self.code = code self.name = name + self.size = size def as_escape_code(self, request_id: str = '') -> str: name = self.code if isinstance(self.code, str) else self.code.name if self.human_msg: name += ':' + self.human_msg return FileTransmissionCommand( - action=Action.status, id=request_id, file_id=self.file_id, status=name, name=self.name + action=Action.status, id=request_id, file_id=self.file_id, status=name, name=self.name, size=self.size ).serialize() @@ -99,6 +101,7 @@ class FileTransmissionCommand: quiet: int = 0 mtime: int = -1 permissions: int = -1 + size: int = -1 data: bytes = b'' name: str = field(default='', metadata={'base64': True}) status: str = field(default='', metadata={'base64': True}) @@ -199,6 +202,11 @@ class DestFile: self.name = os.path.expanduser(self.name) if not os.path.isabs(self.name): self.name = os.path.join(tempfile.gettempdir(), self.name) + try: + self.existing_stat: Optional[os.stat_result] = os.stat(self.name, follow_symlinks=False) + except OSError: + self.existing_stat = None + self.needs_unlink = self.existing_stat is not None and (self.existing_stat.st_nlink > 1 or stat.S_ISLNK(self.existing_stat.st_mode)) self.mtime = ftc.mtime self.file_id = ftc.file_id self.permissions = ftc.permissions @@ -242,6 +250,13 @@ class DestFile: else: os.utime(self.name, ns=(self.mtime, self.mtime)) + def unlink_existing_if_needed(self, force: bool = False) -> None: + if force or self.needs_unlink: + with suppress(FileNotFoundError): + os.unlink(self.name) + self.existing_stat = None + self.needs_unlink = False + def write_data(self, all_files: Dict[str, 'DestFile'], data: bytes, is_last: bool) -> None: if self.ftype is FileType.directory: raise TransmissionError(code=ErrorCode.EISDIR, file_id=self.file_id, msg='Cannot write data to a directory entry') @@ -252,6 +267,7 @@ class DestFile: if is_last: lt = self.link_target.decode('utf-8', 'replace') base = self.make_parent_dirs() + self.unlink_existing_if_needed(force=True) if lt.startswith('fid:'): lt = all_files[lt[4:]].name if self.ftype is FileType.symlink: @@ -280,7 +296,11 @@ class DestFile: elif self.ftype is FileType.regular: if self.actual_file is None: self.make_parent_dirs() - self.actual_file = open(self.name, 'wb') + self.unlink_existing_if_needed() + flags = os.O_RDWR | os.O_CREAT | getattr(os, 'O_CLOEXEC', 0) | getattr(os, 'O_BINARY', 0) + if self.ttype is TransmissionType.simple: + flags |= os.O_TRUNC + self.actual_file = open(os.open(self.name, flags, self.permissions), mode='r+b', closefd=True) data = self.decompressor(data, is_last=is_last) self.actual_file.write(data) if is_last: @@ -439,7 +459,8 @@ class FileTransmission: self.send_status_response(ErrorCode.OK, ar.id, df.file_id, name=df.name) else: if ar.send_acknowledgements: - self.send_status_response(code=ErrorCode.STARTED, request_id=ar.id, file_id=df.file_id, name=df.name) + sz = df.existing_stat.st_size if df.existing_stat is not None else -1 + self.send_status_response(code=ErrorCode.STARTED, request_id=ar.id, file_id=df.file_id, name=df.name, size=sz) elif cmd.action in (Action.data, Action.end_data): try: df = ar.add_data(cmd) @@ -470,9 +491,9 @@ class FileTransmission: def send_status_response( self, code: Union[ErrorCode, str] = ErrorCode.EINVAL, request_id: str = '', file_id: str = '', msg: str = '', - name: str = '', + name: str = '', size: int = -1 ) -> bool: - err = TransmissionError(code=code, msg=msg, file_id=file_id, name=name) + err = TransmissionError(code=code, msg=msg, file_id=file_id, name=name, size=size) data = err.as_escape_code(request_id) return self.write_osc_to_child(request_id, data) diff --git a/kitty/short_uuid.py b/kitty/short_uuid.py new file mode 100644 index 000000000..b1e32044e --- /dev/null +++ b/kitty/short_uuid.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# vim:fileencoding=utf-8 +# License: GPLv3 Copyright: 2021, Kovid Goyal + +import math +import string +import uuid as _uuid +from typing import Dict, Optional, Sequence + + +def num_to_string(number: int, alphabet: Sequence[str], alphabet_len: int, pad_to_length: Optional[int] = None) -> str: + ans = [] + number = max(0, number) + while number: + number, digit = divmod(number, alphabet_len) + ans.append(alphabet[digit]) + if pad_to_length is not None and pad_to_length > len(ans): + ans.append(alphabet[0] * (pad_to_length - len(ans))) + return ''.join(ans) + + +def string_to_num(string: str, alphabet_map: Dict[str, int], alphabet_len: int) -> int: + ans = 0 + for char in reversed(string): + ans = ans * alphabet_len + alphabet_map[char] + return ans + + +escape_code_safe_alphabet = string.ascii_letters + string.digits + string.punctuation + ' ' +human_alphabet = (string.digits + string.ascii_letters)[2:] + + +class ShortUUID: + + def __init__(self, alphabet: str = human_alphabet): + self.alphabet = tuple(sorted(alphabet)) + self.alphabet_len = len(self.alphabet) + self.alphabet_map = {c: i for i, c in enumerate(self.alphabet)} + self.uuid_pad_len = int(math.ceil(math.log(1 << 128, self.alphabet_len))) + + def uuid4(self, pad_to_length: Optional[int] = None) -> str: + if pad_to_length is None: + pad_to_length = self.uuid_pad_len + return num_to_string(_uuid.uuid4().int, self.alphabet, self.alphabet_len, pad_to_length) + + def uuid5(self, namespace: _uuid.UUID, name: str, pad_to_length: Optional[int] = None) -> str: + if pad_to_length is None: + pad_to_length = self.uuid_pad_len + return num_to_string(_uuid.uuid5(namespace, name).int, self.alphabet, self.alphabet_len, pad_to_length) + + def decode(self, encoded: str) -> _uuid.UUID: + return _uuid.UUID(int=string_to_num(encoded, self.alphabet_map, self.alphabet_len)) + + +_global_instance = ShortUUID() +uuid4 = _global_instance.uuid4 +uuid5 = _global_instance.uuid5 +decode = _global_instance.decode diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py index 6c0aff11f..4410b40b0 100644 --- a/kitty_tests/file_transmission.py +++ b/kitty_tests/file_transmission.py @@ -21,7 +21,7 @@ from kitty.file_transmission import ( from . import BaseTest -def response(id='', msg='', file_id='', name='', action='status', status=''): +def response(id='', msg='', file_id='', name='', action='status', status='', size=-1): ans = {'action': 'status'} if id: ans['id'] = id @@ -31,6 +31,8 @@ def response(id='', msg='', file_id='', name='', action='status', status=''): ans['name'] = name if status: ans['status'] = status + if size > -1: + ans['size'] = size return ans @@ -93,12 +95,12 @@ class TestFileTransmission(BaseTest): ft.handle_serialized_command(serialized_cmd(action='send', quiet=quiet)) self.assertIn('', ft.active_receives) self.ae(ft.test_responses, [] if quiet else [response(status='OK')]) - ft.handle_serialized_command(serialized_cmd(action='file', name=dest, quiet=quiet)) + ft.handle_serialized_command(serialized_cmd(action='file', name=dest)) self.assertPathEqual(ft.active_file().name, dest) self.assertIsNone(ft.active_file().actual_file) self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='STARTED', name=dest)]) ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) - self.assertPathEqual(ft.active_file().actual_file.name, dest) + self.assertPathEqual(ft.active_file().name, dest) ft.handle_serialized_command(serialized_cmd(action='end_data', data='123')) self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='STARTED', name=dest), response(status='OK', name=dest)]) self.assertTrue(ft.active_receives) @@ -137,6 +139,41 @@ class TestFileTransmission(BaseTest): del odata del data + # overwriting + self.clean_tdir() + ft = FileTransmission() + one = os.path.join(self.tdir, '1') + two = os.path.join(self.tdir, '2') + three = os.path.join(self.tdir, '3') + open(two, 'w').close() + os.symlink(two, one) + ft.handle_serialized_command(serialized_cmd(action='send')) + ft.handle_serialized_command(serialized_cmd(action='file', name=one)) + ft.handle_serialized_command(serialized_cmd(action='end_data', data='abcd')) + ft.handle_serialized_command(serialized_cmd(action='finish')) + self.assertFalse(os.path.islink(one)) + with open(one) as f: + self.ae(f.read(), 'abcd') + self.assertTrue(os.path.isfile(two)) + ft = FileTransmission() + ft.handle_serialized_command(serialized_cmd(action='send')) + ft.handle_serialized_command(serialized_cmd(action='file', name=two, ftype='symlink')) + ft.handle_serialized_command(serialized_cmd(action='end_data', data='path:/abcd')) + ft.handle_serialized_command(serialized_cmd(action='finish')) + self.ae(os.readlink(two), '/abcd') + with open(three, 'w') as f: + f.write('abcd') + self.responses = [] + ft = FileTransmission() + ft.handle_serialized_command(serialized_cmd(action='send')) + self.assertResponses(ft, status='OK') + ft.handle_serialized_command(serialized_cmd(action='file', name=three)) + self.assertResponses(ft, status='STARTED', name=three, size=4) + ft.handle_serialized_command(serialized_cmd(action='end_data', data='11')) + ft.handle_serialized_command(serialized_cmd(action='finish')) + with open(three) as f: + self.ae(f.read(), '11') + # multi file send self.clean_tdir() ft = FileTransmission()