More work on file transfer

This commit is contained in:
Kovid Goyal 2021-09-09 12:59:31 +05:30
parent 2178c8e4af
commit f9c99a61d4
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
6 changed files with 257 additions and 17 deletions

View File

@ -6,16 +6,23 @@ import os
import stat import stat
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from enum import auto
from itertools import count from itertools import count
from typing import ( from typing import (
Dict, Generator, Iterable, Iterator, List, Sequence, Tuple, Union Callable, Dict, Generator, Iterable, Iterator, List, Sequence, Tuple,
Union, cast
) )
from kitty.cli import parse_args from kitty.cli import parse_args
from kitty.cli_stub import TransferCLIOptions from kitty.cli_stub import TransferCLIOptions
from kitty.fast_data_types import FILE_TRANSFER_CODE
from kitty.file_transmission import ( from kitty.file_transmission import (
Action, Compression, FileTransmissionCommand, FileType Action, Compression, FileTransmissionCommand, FileType, NameReprEnum,
TransmissionType
) )
from kitty.types import run_once
from ..tui.handler import Handler
_cwd = _home = '' _cwd = _home = ''
@ -34,6 +41,17 @@ def expand_home(path: str) -> str:
return path return path
@run_once
def short_uuid_func() -> Callable[[], str]:
from kitty.short_uuid import ShortUUID, escape_code_safe_alphabet
return ShortUUID(alphabet=''.join(set(escape_code_safe_alphabet) - {';'})).uuid4
def random_id() -> str:
f = short_uuid_func()
return cast(str, f())
@contextmanager @contextmanager
def set_paths(cwd: str = '', home: str = '') -> Generator[None, None, None]: def set_paths(cwd: str = '', home: str = '') -> Generator[None, None, None]:
global _cwd, _home global _cwd, _home
@ -93,11 +111,21 @@ def get_remote_path(local_path: str, remote_base: str) -> str:
return remote_base return remote_base
class FileState(NameReprEnum):
waiting_for_start = auto()
waiting_for_data = auto()
transmitting = auto()
finished = auto()
class File: class File:
def __init__( def __init__(
self, local_path: str, expanded_local_path: str, file_id: int, stat_result: os.stat_result, remote_base: str, file_type: FileType self, local_path: str, expanded_local_path: str, file_id: int, stat_result: os.stat_result,
remote_base: str, file_type: FileType, ttype: TransmissionType = TransmissionType.simple
) -> None: ) -> None:
self.ttype = ttype
self.state = FileState.waiting_for_start
self.local_path = local_path self.local_path = local_path
self.expanded_local_path = expanded_local_path self.expanded_local_path = expanded_local_path
self.permissions = stat.S_IMODE(stat_result.st_mode) self.permissions = stat.S_IMODE(stat_result.st_mode)
@ -112,6 +140,9 @@ class File:
self.stat_result = stat_result self.stat_result = stat_result
self.file_type = file_type self.file_type = file_type
self.compression = Compression.zlib if self.file_size > 2048 else Compression.none self.compression = Compression.zlib if self.file_size > 2048 else Compression.none
self.remote_final_path = ''
self.remote_initial_size = -1
self.err_msg = ''
def metadata_command(self) -> FileTransmissionCommand: def metadata_command(self) -> FileTransmissionCommand:
return FileTransmissionCommand( return FileTransmissionCommand(
@ -222,10 +253,101 @@ def files_for_send(cli_opts: TransferCLIOptions, args: List[str]) -> Tuple[File,
return tuple(files) return tuple(files)
class SendState(NameReprEnum):
waiting_for_permission = auto()
permission_granted = auto()
permission_denied = auto()
finished = auto()
class SendManager: class SendManager:
def __init__(self, files: Tuple[File, ...]): def __init__(self, request_id: str, files: Tuple[File, ...]):
self.files = files self.files = files
self.fid_map = {str(f.file_id): f for f in self.files}
self.request_id = request_id
self.state = SendState.waiting_for_permission
self.all_done = False
self.all_started = False
def update_collective_statuses(self) -> None:
found_not_started = found_not_done = False
for f in self.files:
if f.state is not FileState.finished:
found_not_done = True
if f.state is FileState.waiting_for_start:
found_not_started = True
if found_not_started and found_not_done:
break
self.all_done = not found_not_done
self.all_started = not found_not_started
def start_transfer(self) -> str:
return FileTransmissionCommand(action=Action.send).serialize()
def send_file_metadata(self) -> Iterator[str]:
for f in self.files:
yield f.metadata_command().serialize()
def on_file_status_update(self, ftc: FileTransmissionCommand) -> None:
file = self.fid_map.get(ftc.file_id)
if file is None:
return
if ftc.status == 'STARTED':
file.state = FileState.waiting_for_data if file.ttype is TransmissionType.rsync else FileState.transmitting
file.remote_final_path = ftc.name
file.remote_initial_size = ftc.size
else:
if ftc.name and not file.remote_final_path:
file.remote_final_path = ftc.name
file.state = FileState.finished
if ftc.status != 'OK':
file.err_msg = ftc.status
self.update_collective_statuses()
def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> None:
if ftc.action is Action.status:
if ftc.file_id:
self.on_file_status_update(ftc)
else:
self.status = SendState.permission_granted if ftc.status == 'OK' else SendState.permission_denied
class Send(Handler):
use_alternate_screen = False
def __init__(self, manager: SendManager):
Handler.__init__(self)
self.manager = manager
def send_payload(self, payload: str) -> None:
self.write(f'\x1b]{FILE_TRANSFER_CODE};id={self.manager.request_id};')
self.write(payload)
self.write(b'\x1b\\')
def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> None:
if ftc.id != self.manager.request_id:
return
before = self.manager.state
self.manager.on_file_transfer_response(ftc)
if before == SendState.waiting_for_permission:
if self.manager.status == SendState.permission_denied:
self.cmd.styled('Permission denied for this transfer', fg='red')
self.quit_loop(1)
return
def initialize(self) -> None:
self.send_payload(self.manager.start_transfer())
for payload in self.manager.send_file_metadata():
self.send_payload(payload)
def on_interrupt(self) -> None:
self.cmd.styled('Interrupt requested, cancelling transfer, transferred files are in undefined state', fg='red')
self.abort_transfer()
def abort_transfer(self) -> None:
self.send_payload(FileTransmissionCommand(action=Action.cancel).serialize())
self.quit_loop(1)
def send_main(cli_opts: TransferCLIOptions, args: List[str]) -> None: def send_main(cli_opts: TransferCLIOptions, args: List[str]) -> None:

View File

@ -24,6 +24,7 @@ if TYPE_CHECKING:
class Handler: class Handler:
image_manager_class: Optional[Type[ImageManagerType]] = None image_manager_class: Optional[Type[ImageManagerType]] = None
use_alternate_screen = True
def _initialize( def _initialize(
self, self,

View File

@ -68,20 +68,21 @@ debug = Debug()
class TermManager: class TermManager:
def __init__(self, optional_actions: int = termios.TCSANOW) -> None: def __init__(self, optional_actions: int = termios.TCSANOW, use_alternate_screen: bool = True) -> None:
self.extra_finalize: Optional[str] = None self.extra_finalize: Optional[str] = None
self.optional_actions = optional_actions self.optional_actions = optional_actions
self.use_alternate_screen = use_alternate_screen
def set_state_for_loop(self, set_raw: bool = True) -> None: def set_state_for_loop(self, set_raw: bool = True) -> None:
if set_raw: if set_raw:
raw_tty(self.tty_fd, self.original_termios) raw_tty(self.tty_fd, self.original_termios)
write_all(self.tty_fd, init_state()) write_all(self.tty_fd, init_state(self.use_alternate_screen))
def reset_state_to_original(self) -> None: def reset_state_to_original(self) -> None:
normal_tty(self.tty_fd, self.original_termios) normal_tty(self.tty_fd, self.original_termios)
if self.extra_finalize: if self.extra_finalize:
write_all(self.tty_fd, self.extra_finalize) write_all(self.tty_fd, self.extra_finalize)
write_all(self.tty_fd, reset_state()) write_all(self.tty_fd, reset_state(self.use_alternate_screen))
@contextmanager @contextmanager
def suspend(self) -> Generator['TermManager', None, None]: def suspend(self) -> Generator['TermManager', None, None]:
@ -409,7 +410,7 @@ class Loop:
handler.on_resize(handler.screen_size) handler.on_resize(handler.screen_size)
signal_manager = SignalManager(self.asycio_loop, _on_sigwinch, handler.on_interrupt, handler.on_term) signal_manager = SignalManager(self.asycio_loop, _on_sigwinch, handler.on_interrupt, handler.on_term)
with TermManager(self.optional_actions) as term_manager, signal_manager: with TermManager(self.optional_actions, handler.use_alternate_screen) as term_manager, signal_manager:
self._get_screen_size: ScreenSizeGetter = screen_size_function(term_manager.tty_fd) self._get_screen_size: ScreenSizeGetter = screen_size_function(term_manager.tty_fd)
image_manager = None image_manager = None
if handler.image_manager_class is not None: if handler.image_manager_class is not None:

View File

@ -68,7 +68,8 @@ class TransmissionError(Exception):
msg: str = 'Generic error', msg: str = 'Generic error',
transmit: bool = True, transmit: bool = True,
file_id: str = '', file_id: str = '',
name: str = '' name: str = '',
size: int = -1
) -> None: ) -> None:
Exception.__init__(self, msg) Exception.__init__(self, msg)
self.transmit = transmit self.transmit = transmit
@ -76,13 +77,14 @@ class TransmissionError(Exception):
self.human_msg = msg self.human_msg = msg
self.code = code self.code = code
self.name = name self.name = name
self.size = size
def as_escape_code(self, request_id: str = '') -> str: def as_escape_code(self, request_id: str = '') -> str:
name = self.code if isinstance(self.code, str) else self.code.name name = self.code if isinstance(self.code, str) else self.code.name
if self.human_msg: if self.human_msg:
name += ':' + self.human_msg name += ':' + self.human_msg
return FileTransmissionCommand( return FileTransmissionCommand(
action=Action.status, id=request_id, file_id=self.file_id, status=name, name=self.name action=Action.status, id=request_id, file_id=self.file_id, status=name, name=self.name, size=self.size
).serialize() ).serialize()
@ -99,6 +101,7 @@ class FileTransmissionCommand:
quiet: int = 0 quiet: int = 0
mtime: int = -1 mtime: int = -1
permissions: int = -1 permissions: int = -1
size: int = -1
data: bytes = b'' data: bytes = b''
name: str = field(default='', metadata={'base64': True}) name: str = field(default='', metadata={'base64': True})
status: str = field(default='', metadata={'base64': True}) status: str = field(default='', metadata={'base64': True})
@ -199,6 +202,11 @@ class DestFile:
self.name = os.path.expanduser(self.name) self.name = os.path.expanduser(self.name)
if not os.path.isabs(self.name): if not os.path.isabs(self.name):
self.name = os.path.join(tempfile.gettempdir(), self.name) self.name = os.path.join(tempfile.gettempdir(), self.name)
try:
self.existing_stat: Optional[os.stat_result] = os.stat(self.name, follow_symlinks=False)
except OSError:
self.existing_stat = None
self.needs_unlink = self.existing_stat is not None and (self.existing_stat.st_nlink > 1 or stat.S_ISLNK(self.existing_stat.st_mode))
self.mtime = ftc.mtime self.mtime = ftc.mtime
self.file_id = ftc.file_id self.file_id = ftc.file_id
self.permissions = ftc.permissions self.permissions = ftc.permissions
@ -242,6 +250,13 @@ class DestFile:
else: else:
os.utime(self.name, ns=(self.mtime, self.mtime)) os.utime(self.name, ns=(self.mtime, self.mtime))
def unlink_existing_if_needed(self, force: bool = False) -> None:
if force or self.needs_unlink:
with suppress(FileNotFoundError):
os.unlink(self.name)
self.existing_stat = None
self.needs_unlink = False
def write_data(self, all_files: Dict[str, 'DestFile'], data: bytes, is_last: bool) -> None: def write_data(self, all_files: Dict[str, 'DestFile'], data: bytes, is_last: bool) -> None:
if self.ftype is FileType.directory: if self.ftype is FileType.directory:
raise TransmissionError(code=ErrorCode.EISDIR, file_id=self.file_id, msg='Cannot write data to a directory entry') raise TransmissionError(code=ErrorCode.EISDIR, file_id=self.file_id, msg='Cannot write data to a directory entry')
@ -252,6 +267,7 @@ class DestFile:
if is_last: if is_last:
lt = self.link_target.decode('utf-8', 'replace') lt = self.link_target.decode('utf-8', 'replace')
base = self.make_parent_dirs() base = self.make_parent_dirs()
self.unlink_existing_if_needed(force=True)
if lt.startswith('fid:'): if lt.startswith('fid:'):
lt = all_files[lt[4:]].name lt = all_files[lt[4:]].name
if self.ftype is FileType.symlink: if self.ftype is FileType.symlink:
@ -280,7 +296,11 @@ class DestFile:
elif self.ftype is FileType.regular: elif self.ftype is FileType.regular:
if self.actual_file is None: if self.actual_file is None:
self.make_parent_dirs() self.make_parent_dirs()
self.actual_file = open(self.name, 'wb') self.unlink_existing_if_needed()
flags = os.O_RDWR | os.O_CREAT | getattr(os, 'O_CLOEXEC', 0) | getattr(os, 'O_BINARY', 0)
if self.ttype is TransmissionType.simple:
flags |= os.O_TRUNC
self.actual_file = open(os.open(self.name, flags, self.permissions), mode='r+b', closefd=True)
data = self.decompressor(data, is_last=is_last) data = self.decompressor(data, is_last=is_last)
self.actual_file.write(data) self.actual_file.write(data)
if is_last: if is_last:
@ -439,7 +459,8 @@ class FileTransmission:
self.send_status_response(ErrorCode.OK, ar.id, df.file_id, name=df.name) self.send_status_response(ErrorCode.OK, ar.id, df.file_id, name=df.name)
else: else:
if ar.send_acknowledgements: if ar.send_acknowledgements:
self.send_status_response(code=ErrorCode.STARTED, request_id=ar.id, file_id=df.file_id, name=df.name) sz = df.existing_stat.st_size if df.existing_stat is not None else -1
self.send_status_response(code=ErrorCode.STARTED, request_id=ar.id, file_id=df.file_id, name=df.name, size=sz)
elif cmd.action in (Action.data, Action.end_data): elif cmd.action in (Action.data, Action.end_data):
try: try:
df = ar.add_data(cmd) df = ar.add_data(cmd)
@ -470,9 +491,9 @@ class FileTransmission:
def send_status_response( def send_status_response(
self, code: Union[ErrorCode, str] = ErrorCode.EINVAL, self, code: Union[ErrorCode, str] = ErrorCode.EINVAL,
request_id: str = '', file_id: str = '', msg: str = '', request_id: str = '', file_id: str = '', msg: str = '',
name: str = '', name: str = '', size: int = -1
) -> bool: ) -> bool:
err = TransmissionError(code=code, msg=msg, file_id=file_id, name=name) err = TransmissionError(code=code, msg=msg, file_id=file_id, name=name, size=size)
data = err.as_escape_code(request_id) data = err.as_escape_code(request_id)
return self.write_osc_to_child(request_id, data) return self.write_osc_to_child(request_id, data)

58
kitty/short_uuid.py Normal file
View File

@ -0,0 +1,58 @@
#!/usr/bin/env python
# vim:fileencoding=utf-8
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
import math
import string
import uuid as _uuid
from typing import Dict, Optional, Sequence
def num_to_string(number: int, alphabet: Sequence[str], alphabet_len: int, pad_to_length: Optional[int] = None) -> str:
ans = []
number = max(0, number)
while number:
number, digit = divmod(number, alphabet_len)
ans.append(alphabet[digit])
if pad_to_length is not None and pad_to_length > len(ans):
ans.append(alphabet[0] * (pad_to_length - len(ans)))
return ''.join(ans)
def string_to_num(string: str, alphabet_map: Dict[str, int], alphabet_len: int) -> int:
ans = 0
for char in reversed(string):
ans = ans * alphabet_len + alphabet_map[char]
return ans
escape_code_safe_alphabet = string.ascii_letters + string.digits + string.punctuation + ' '
human_alphabet = (string.digits + string.ascii_letters)[2:]
class ShortUUID:
def __init__(self, alphabet: str = human_alphabet):
self.alphabet = tuple(sorted(alphabet))
self.alphabet_len = len(self.alphabet)
self.alphabet_map = {c: i for i, c in enumerate(self.alphabet)}
self.uuid_pad_len = int(math.ceil(math.log(1 << 128, self.alphabet_len)))
def uuid4(self, pad_to_length: Optional[int] = None) -> str:
if pad_to_length is None:
pad_to_length = self.uuid_pad_len
return num_to_string(_uuid.uuid4().int, self.alphabet, self.alphabet_len, pad_to_length)
def uuid5(self, namespace: _uuid.UUID, name: str, pad_to_length: Optional[int] = None) -> str:
if pad_to_length is None:
pad_to_length = self.uuid_pad_len
return num_to_string(_uuid.uuid5(namespace, name).int, self.alphabet, self.alphabet_len, pad_to_length)
def decode(self, encoded: str) -> _uuid.UUID:
return _uuid.UUID(int=string_to_num(encoded, self.alphabet_map, self.alphabet_len))
_global_instance = ShortUUID()
uuid4 = _global_instance.uuid4
uuid5 = _global_instance.uuid5
decode = _global_instance.decode

View File

@ -21,7 +21,7 @@ from kitty.file_transmission import (
from . import BaseTest from . import BaseTest
def response(id='', msg='', file_id='', name='', action='status', status=''): def response(id='', msg='', file_id='', name='', action='status', status='', size=-1):
ans = {'action': 'status'} ans = {'action': 'status'}
if id: if id:
ans['id'] = id ans['id'] = id
@ -31,6 +31,8 @@ def response(id='', msg='', file_id='', name='', action='status', status=''):
ans['name'] = name ans['name'] = name
if status: if status:
ans['status'] = status ans['status'] = status
if size > -1:
ans['size'] = size
return ans return ans
@ -93,12 +95,12 @@ class TestFileTransmission(BaseTest):
ft.handle_serialized_command(serialized_cmd(action='send', quiet=quiet)) ft.handle_serialized_command(serialized_cmd(action='send', quiet=quiet))
self.assertIn('', ft.active_receives) self.assertIn('', ft.active_receives)
self.ae(ft.test_responses, [] if quiet else [response(status='OK')]) self.ae(ft.test_responses, [] if quiet else [response(status='OK')])
ft.handle_serialized_command(serialized_cmd(action='file', name=dest, quiet=quiet)) ft.handle_serialized_command(serialized_cmd(action='file', name=dest))
self.assertPathEqual(ft.active_file().name, dest) self.assertPathEqual(ft.active_file().name, dest)
self.assertIsNone(ft.active_file().actual_file) self.assertIsNone(ft.active_file().actual_file)
self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='STARTED', name=dest)]) self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='STARTED', name=dest)])
ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) ft.handle_serialized_command(serialized_cmd(action='data', data='abcd'))
self.assertPathEqual(ft.active_file().actual_file.name, dest) self.assertPathEqual(ft.active_file().name, dest)
ft.handle_serialized_command(serialized_cmd(action='end_data', data='123')) ft.handle_serialized_command(serialized_cmd(action='end_data', data='123'))
self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='STARTED', name=dest), response(status='OK', name=dest)]) self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='STARTED', name=dest), response(status='OK', name=dest)])
self.assertTrue(ft.active_receives) self.assertTrue(ft.active_receives)
@ -137,6 +139,41 @@ class TestFileTransmission(BaseTest):
del odata del odata
del data del data
# overwriting
self.clean_tdir()
ft = FileTransmission()
one = os.path.join(self.tdir, '1')
two = os.path.join(self.tdir, '2')
three = os.path.join(self.tdir, '3')
open(two, 'w').close()
os.symlink(two, one)
ft.handle_serialized_command(serialized_cmd(action='send'))
ft.handle_serialized_command(serialized_cmd(action='file', name=one))
ft.handle_serialized_command(serialized_cmd(action='end_data', data='abcd'))
ft.handle_serialized_command(serialized_cmd(action='finish'))
self.assertFalse(os.path.islink(one))
with open(one) as f:
self.ae(f.read(), 'abcd')
self.assertTrue(os.path.isfile(two))
ft = FileTransmission()
ft.handle_serialized_command(serialized_cmd(action='send'))
ft.handle_serialized_command(serialized_cmd(action='file', name=two, ftype='symlink'))
ft.handle_serialized_command(serialized_cmd(action='end_data', data='path:/abcd'))
ft.handle_serialized_command(serialized_cmd(action='finish'))
self.ae(os.readlink(two), '/abcd')
with open(three, 'w') as f:
f.write('abcd')
self.responses = []
ft = FileTransmission()
ft.handle_serialized_command(serialized_cmd(action='send'))
self.assertResponses(ft, status='OK')
ft.handle_serialized_command(serialized_cmd(action='file', name=three))
self.assertResponses(ft, status='STARTED', name=three, size=4)
ft.handle_serialized_command(serialized_cmd(action='end_data', data='11'))
ft.handle_serialized_command(serialized_cmd(action='finish'))
with open(three) as f:
self.ae(f.read(), '11')
# multi file send # multi file send
self.clean_tdir() self.clean_tdir()
ft = FileTransmission() ft = FileTransmission()