diff --git a/kittens/transfer/receive.py b/kittens/transfer/receive.py
index a5c059c20..18f5fce56 100644
--- a/kittens/transfer/receive.py
+++ b/kittens/transfer/receive.py
@@ -4,12 +4,14 @@
 import os
 import posixpath
 from enum import auto
+from itertools import count
 from typing import Dict, Iterator, List, Optional
 
 from kitty.cli_stub import TransferCLIOptions
 from kitty.fast_data_types import FILE_TRANSFER_CODE
 from kitty.file_transmission import (
-    Action, FileTransmissionCommand, FileType, NameReprEnum, encode_bypass
+    Action, Compression, FileTransmissionCommand, FileType, NameReprEnum,
+    encode_bypass
 )
 from kitty.typing import KeyEventType
 from kitty.utils import sanitize_control_codes
@@ -18,9 +20,10 @@ from ..tui.handler import Handler
 from ..tui.loop import Loop, debug
 from ..tui.operations import styled, without_line_wrap
 from ..tui.utils import human_size
-from .utils import expand_home, random_id
+from .utils import expand_home, random_id, should_be_compressed
 
 debug
+file_counter = count(1)
 
 
 class State(NameReprEnum):
@@ -44,6 +47,8 @@ class File:
         self.remote_target = ftc.data.decode('utf-8')
         self.parent = ftc.parent
         self.expanded_local_path = ''
+        self.file_id = str(next(file_counter))
+        self.compression_capable = self.ftype is FileType.regular and self.expected_size > 4096 and should_be_compressed(self.expanded_local_path)
 
     def __repr__(self) -> str:
         return f'File(rpath={self.remote_path!r}, lpath={self.expanded_local_path!r})'
@@ -148,9 +153,12 @@ class Manager:
 
     def request_files(self) -> Iterator[str]:
         for f in self.files:
-            if f.ftype is FileType.directory:
+            if f.ftype is FileType.directory or (f.ftype is FileType.link and f.remote_target):
                 continue
-            yield FileTransmissionCommand(action=Action.file, name=f.remote_path).serialize()
+            yield FileTransmissionCommand(
+                action=Action.file, name=f.remote_path, file_id=f.file_id,
+                compression=Compression.zlib if f.compression_capable else Compression.none
+            ).serialize()
 
     def collect_files(self, cli_opts: TransferCLIOptions) -> None:
         self.files = list(files_for_receive(cli_opts, self.dest, self.files, self.remote_home, self.spec))
diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py
index 4f106f30a..916eb438d 100644
--- a/kitty/file_transmission.py
+++ b/kitty/file_transmission.py
@@ -20,7 +20,9 @@ from typing import (
 )
 
 from kittens.transfer.librsync import PatchFile, signature_of_file
-from kittens.transfer.utils import abspath, expand_home, home_path
+from kittens.transfer.utils import (
+    IdentityCompressor, ZlibCompressor, abspath, expand_home, home_path
+)
 from kitty.fast_data_types import (
     FILE_TRANSFER_CODE, OSC, add_timer, get_boss, get_options
 )
@@ -529,6 +531,63 @@ class ActiveReceive:
             df.apply_metadata()
 
 
+class SourceFile:
+    """A local file (or symlink) queued for transmission to the remote end."""
+
+    def __init__(self, ftc: FileTransmissionCommand):
+        self.file_id = ftc.file_id
+        self.path = ftc.name
+        self.ttype = ftc.ttype
+        self.waiting_for_signature = self.ttype is TransmissionType.rsync
+        self.transmitted = False
+        # lstat(), not stat(): stat() follows symlinks, so S_ISLNK() below
+        # could never be true and the symlink branch would be dead code
+        self.stat = os.lstat(self.path)
+        if stat.S_ISDIR(self.stat.st_mode):
+            raise TransmissionError(ErrorCode.EINVAL, msg='Cannot send a directory', file_id=self.file_id)
+        self.compressor: Union[ZlibCompressor, IdentityCompressor] = IdentityCompressor()
+        if stat.S_ISLNK(self.stat.st_mode):
+            self.target = os.readlink(self.path)
+        else:
+            self.open_file = open(self.path, 'rb')
+            if ftc.compression is Compression.zlib:
+                self.compressor = ZlibCompressor()
+
+    @property
+    def ready_to_transmit(self) -> bool:
+        return not self.transmitted and not self.waiting_for_signature
+
+    def close(self) -> None:
+        """Close the underlying file, if any. Safe to call more than once."""
+        if hasattr(self, 'open_file'):
+            self.open_file.close()
+            del self.open_file
+
+    def next_chunk(self, sz: int = 1024 * 1024) -> Tuple[bytes, int]:
+        """Return ``(data_to_send, uncompressed_size_in_bytes)`` for the next piece of the file.
+
+        For a symlink the data is the UTF-8 encoded link target. For a
+        regular file at most ``sz`` bytes are read and run through the
+        compressor. ``self.transmitted`` is set once the last piece has
+        been produced and the file is then closed.
+        """
+        if hasattr(self, 'target'):
+            self.transmitted = True
+            target = self.target.encode('utf-8')
+            # report the size in bytes, not in characters
+            return target, len(target)
+        data = self.open_file.read(sz)
+        if not data or self.open_file.tell() >= self.stat.st_size:
+            self.transmitted = True
+        uncompressed_sz = len(data)
+        cchunk = self.compressor.compress(data)
+        if self.transmitted and not isinstance(self.compressor, IdentityCompressor):
+            cchunk += self.compressor.flush()
+        if self.transmitted:
+            self.close()
+        return cchunk, uncompressed_sz
+
+
 class ActiveSend:
 
     def __init__(self, request_id: str, quiet: int, bypass: str, num_of_args: int) -> None:
@@ -544,22 +603,65 @@ class ActiveSend:
         self.send_errors = quiet < 2
         self.last_activity_at = monotonic()
         self.file_specs: List[Tuple[str, str]] = []
+        self.queued_files: List[SourceFile] = []
+        self.active_file: Optional[SourceFile] = None
+        self.pending_chunks: Deque[FileTransmissionCommand] = deque()
+        self.metadata_sent = False
 
     @property
     def spec_complete(self) -> bool:
         return self.expected_num_of_args <= len(self.file_specs)
 
     def add_file_spec(self, cmd: FileTransmissionCommand) -> None:
+        self.last_activity_at = monotonic()
         if len(self.file_specs) > 8192 or self.spec_complete:
             raise TransmissionError(ErrorCode.EINVAL, 'Too many file specs')
         self.file_specs.append((cmd.file_id, cmd.name))
 
+    def add_send_file(self, cmd: FileTransmissionCommand) -> None:
+        """Queue the file named by cmd for sending. Raises OSError if it cannot be stat'ed or opened."""
+        self.last_activity_at = monotonic()
+        if len(self.queued_files) > 32768:
+            raise TransmissionError(ErrorCode.EINVAL, 'Too many queued files')
+        self.queued_files.append(SourceFile(cmd))
+
     @property
     def is_expired(self) -> bool:
         return monotonic() - self.last_activity_at > (60 * EXPIRE_TIME)
 
     def close(self) -> None:
-        pass  # TODO: Implement this
+        if self.active_file is not None:
+            self.active_file.close()
+            self.active_file = None
+        # Queued files are opened as soon as they are queued, so they must be
+        # closed here as well or their file descriptors leak
+        for f in self.queued_files:
+            f.close()
+        self.queued_files.clear()
+
+    def next_chunk(self) -> Optional[FileTransmissionCommand]:
+        """Return the next command to send, or None when no queued file is ready to transmit."""
+        self.last_activity_at = monotonic()
+        if self.pending_chunks:
+            return self.pending_chunks.popleft()
+        af = self.active_file
+        if af is None:
+            for f in self.queued_files:
+                if f.ready_to_transmit:
+                    self.active_file = af = f
+                    break
+            if af is None:
+                return None
+            self.queued_files.remove(af)
+        chunk, uncompressed_sz = af.next_chunk()
+        if af.transmitted:
+            self.active_file = None
+        self.pending_chunks.extend(split_for_transfer(chunk, file_id=af.file_id, mark_last=af.transmitted))
+        return self.pending_chunks.popleft()
+
+    def return_chunk(self, ftc: FileTransmissionCommand) -> None:
+        """Put back a chunk that could not be written so that it is the next one returned."""
+        self.pending_chunks.appendleft(ftc)
 
 
 class FileTransmission:
@@ -648,14 +750,24 @@ class FileTransmission:
                 return
             if cmd.action is Action.file:
                 try:
-                    asd.add_file_spec(cmd)
+                    asd.add_send_file(cmd) if asd.metadata_sent else asd.add_file_spec(cmd)
+                except OSError as err:
+                    self.send_fail_on_os_error(err, 'Failed to add send file', asd, cmd.file_id)
+                    self.drop_send(asd.id)
+                    return
                 except TransmissionError as err:
                     self.drop_send(asd.id)
                     if asd.send_errors:
                         self.send_transmission_error(asd.id, err)
                     return
-                if asd.spec_complete and asd.accepted:
-                    self.send_metadata_for_send_transfer(asd)
+                if asd.metadata_sent:
+                    self.pump_send_chunks(asd)
+                else:
+                    if asd.spec_complete and asd.accepted:
+                        self.send_metadata_for_send_transfer(asd)
+                return
+            if cmd.action is Action.status:
+                self.drop_send(asd.id)
                 return
             if not asd.accepted:
                 log_error(f'File transmission command {cmd.action} received for pending id: {cmd.id}, aborting')
@@ -690,10 +802,37 @@ class FileTransmission:
             sent = True
         if sent:
             self.send_status_response(code=ErrorCode.OK, request_id=asd.id, name=home_path())
+            asd.metadata_sent = True
         else:
             self.send_status_response(code=ErrorCode.ENOENT, request_id=asd.id, msg='No files found')
             self.drop_send(asd.id)
 
+    def pump_send_chunks(self, asd: ActiveSend) -> None:
+        """Write chunks for asd to the child until none are left; on a refused write put the chunk back and schedule pump_sends."""
+        while True:
+            try:
+                ftc = asd.next_chunk()
+            except OSError as err:
+                fid = asd.active_file.file_id if asd.active_file else ''
+                self.send_fail_on_os_error(err, 'Failed to read data from file', asd, file_id=fid)
+                self.drop_send(asd.id)
+                break
+            if ftc is None:
+                break
+            ftc.id = asd.id
+            if not self.write_ftc_to_child(ftc, use_pending=False):
+                asd.return_chunk(ftc)
+                self.callback_after(self.pump_sends, 0.05)
+                break
+
+    def pump_sends(self, timer_id: Optional[int]) -> None:
+        """Resume sending for every active send whose metadata has already been sent."""
+        # Iterate over a snapshot: pump_send_chunks() calls drop_send() on
+        # failure, which presumably removes the entry from active_sends
+        for asd in tuple(self.active_sends.values()):
+            if asd.metadata_sent:
+                self.pump_send_chunks(asd)
+
     def handle_receive_cmd(self, cmd: FileTransmissionCommand) -> None:
         if cmd.id in self.active_receives:
             ar = self.active_receives[cmd.id]
@@ -922,7 +1061,7 @@ class FileTransmission:
             if ar.send_errors:
                 self.send_status_response(code=ErrorCode.EPERM, request_id=ar.id, msg='User refused the transfer')
 
-    def send_fail_on_os_error(self, err: OSError, msg: str, ar: ActiveReceive, file_id: str = '') -> None:
+    def send_fail_on_os_error(self, err: OSError, msg: str, ar: Union[ActiveSend, ActiveReceive], file_id: str = '') -> None:
         if not ar.send_errors:
             return
         errname = errno.errorcode.get(err.errno, 'EFAIL')
diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py
index 9d42cf640..6ebec2c53 100644
--- a/kitty_tests/file_transmission.py
+++ b/kitty_tests/file_transmission.py
@@ -16,11 +16,11 @@ from kittens.transfer.main import parse_transfer_args
 from kittens.transfer.receive import File, files_for_receive
 from kittens.transfer.rsync import decode_utf8_buffer, parse_ftc
 from kittens.transfer.send import files_for_send
-from kittens.transfer.utils import expand_home, home_path, set_paths, cwd_path
+from kittens.transfer.utils import cwd_path, expand_home, home_path, set_paths
 from kitty.file_transmission import (
     Action, Compression, FileTransmissionCommand, FileType,
     TestFileTransmission as FileTransmission, TransmissionType,
-    iter_file_metadata
+    ZlibDecompressor, iter_file_metadata
 )
 
 from . import BaseTest
@@ -184,6 +184,25 @@ class TestFileTransmission(BaseTest):
         q = files[f.name + 'd/q']
         self.ae(q['ftype'], 'symlink')
         self.assertNotIn('data', q)
+        base = os.path.join(self.tdir, 'base')
+        os.mkdir(base)
+        src = os.path.join(base, 'src.bin')
+        data = os.urandom(16 * 1024)
+        with open(src, 'wb') as f:
+            f.write(data)
+        for compress in ('none', 'zlib'):
+            ft = FileTransmission()
+            self.responses = []
+            ft.handle_serialized_command(serialized_cmd(action='receive', size=1))
+            self.assertResponses(ft, status='OK')
+            ft.handle_serialized_command(serialized_cmd(action='file', file_id='src', name=src))
+            ft.active_sends['test'].metadata_sent = True
+            ft.test_responses = []
+            ft.handle_serialized_command(serialized_cmd(action='file', file_id='src', name=src, compression=compress))
+            received = b''.join(x['data'] for x in ft.test_responses)
+            if compress == 'zlib':
+                received = ZlibDecompressor()(received, True)
+            self.ae(data, received)
 
     def test_file_put(self):
         # send refusal