Change transmission protocol to not need container formats

This commit is contained in:
Kovid Goyal 2021-08-29 06:27:07 +05:30
parent 495981bade
commit 1d9425ecdc
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 159 additions and 365 deletions

View File

@ -1,193 +0,0 @@
#!/usr/bin/env python
# vim:fileencoding=utf-8
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
import os
import sys
from contextlib import suppress
from typing import List, Optional, Tuple, Iterator
from kitty.constants import cache_dir
from kitty.types import run_once
from kitty.typing import TypedDict
from kitty.config import atomic_save
from ..tui.operations import clear_screen, set_window_title, styled, set_cursor_shape
from ..tui.utils import get_key_press
def history_path() -> str:
return os.path.join(cache_dir(), 'transfer-ask.history')
class Response(TypedDict):
dest: str
allowed: bool
def sort_key(item: str) -> Tuple[int, str]:
return len(item), item.lower()
def get_filesystem_matches(prefix: str) -> Iterator[str]:
fp = os.path.abspath(os.path.expanduser(prefix))
base = os.path.dirname(fp)
with suppress(OSError):
for x in os.listdir(base):
q = os.path.join(base, x)
if q.startswith(fp):
yield prefix + q[len(fp):]
class ReadPath:
def __enter__(self) -> 'ReadPath':
self.matches: List[str] = []
import readline
from kitty.shell import init_readline
init_readline()
readline.set_completer(self.complete)
self.delims = readline.get_completer_delims()
readline.set_completer_delims('\n\t')
readline.clear_history()
for x in history_list():
readline.add_history(x)
return self
def input(self, add_to_history: str = '', prompt: str = '> ') -> str:
import readline
print(end=set_cursor_shape('bar'))
if add_to_history:
readline.add_history(add_to_history)
try:
return input(prompt)
finally:
print(end=set_cursor_shape())
def complete(self, text: str, state: int) -> Optional[str]:
if state == 0:
self.matches = sorted(get_filesystem_matches(text), key=sort_key)
with suppress(IndexError):
return self.matches[state]
def add_history(self, x: str) -> None:
hl = history_list()
with suppress(ValueError):
hl.remove(x)
hl.append(x)
del hl[:-50]
atomic_save('\n'.join(hl).encode('utf-8'), history_path())
def __exit__(self, *a: object) -> None:
import readline
readline.set_completer()
readline.set_completer_delims(self.delims)
@run_once
def history_list() -> List[str]:
with suppress(FileNotFoundError), open(history_path()) as f:
return f.read().splitlines()
return []
def guess_destination(requested: str) -> str:
if os.path.isabs(requested):
return requested
if history_list():
return os.path.join(history_list()[-1], requested)
for q in ('~/Downloads', '~/downloads', '/tmp'):
q = os.path.expanduser(q)
if os.path.isdir(q) and os.access(q, os.X_OK):
return os.path.join(q, requested)
return os.path.join(os.path.expanduser('~'), requested)
def a(x: str) -> str:
return styled(x.upper(), fg='red', fg_intense=True)
def draw_put_main_screen(is_multiple: bool, dest: str) -> None:
print(end=clear_screen())
sd = styled(dest, fg='green', fg_intense=True, bold=True)
if is_multiple:
print('The remote machine wants to send you multiple files')
print('They will be placed in the', sd, 'directory')
else:
print('The remote machine wants to send you a single file')
print('It will be saved as', sd)
if os.path.exists(dest):
print()
print(styled(f'{dest} already exists and will be replaced', fg='magenta', fg_intense=True, bold=True))
print()
print()
print(f'{a("A")}llow the download')
print(f'{a("R")}efuse the download')
print(f'{a("C")}hange the download location')
def change_destination(is_multiple: bool, dest: str) -> str:
print(end=clear_screen())
print('Choose a destination')
print('Current: ', styled(dest, italic=True))
print()
with ReadPath() as r:
new_dest = r.input(dest)
if new_dest:
r.add_history(os.path.dirname(new_dest))
new_dest = os.path.abspath(os.path.expanduser(new_dest))
return new_dest or dest
def put_main(args: List[str]) -> Response:
print(end=set_window_title('Receive a file?'))
is_multiple = args[1] == 'multiple'
dest = guess_destination(args[2])
while True:
draw_put_main_screen(is_multiple, dest)
res = get_key_press('arc', 'r')
if res == 'r':
return {'dest': '', 'allowed': False}
if res == 'a':
return {'dest': dest, 'allowed': True}
if res == 'c':
dest = change_destination(is_multiple, dest)
def get_main(args: List[str]) -> Response:
dest = os.path.abspath(os.path.expanduser(args[1]))
if not os.path.exists(dest) or not os.access(dest, os.R_OK):
return {'dest': dest, 'allowed': False}
is_dir = os.path.isdir(dest)
q = 'directory' if is_dir else 'file'
print(end=set_window_title(f'Send a {q}?'))
sd = styled(dest, fg='green', fg_intense=True, bold=True)
while True:
print(end=clear_screen())
print(f'The remote machine is asking for the {q}: {sd}')
print()
print(f'{a("A")}llow the download')
print(f'{a("R")}efuse the download')
res = get_key_press('ar', 'r')
if res == 'r':
return {'dest': '', 'allowed': False}
if res == 'a':
return {'dest': dest, 'allowed': True}
def main(args: List[str]) -> Response:
q = args[1]
del args[1]
if q == 'put':
return put_main(args)
return get_main(args)
if __name__ == '__main__':
ans = main(sys.argv)
if ans:
print(ans)

View File

@ -7,21 +7,22 @@ import errno
import os
import tempfile
from base64 import standard_b64decode, standard_b64encode
from contextlib import suppress
from time import monotonic
from enum import Enum, auto
from functools import partial
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import IO, Any, Dict, List, Optional, Union
from gettext import gettext as _
from kitty.fast_data_types import OSC, get_boss
from .utils import log_error, sanitize_control_codes
if TYPE_CHECKING:
from kittens.transfer_ask.main import Response
EXPIRE_TIME = 10 # minutes
class Action(Enum):
send = auto()
file = auto()
data = auto()
end_data = auto()
receive = auto()
@ -29,63 +30,67 @@ class Action(Enum):
cancel = auto()
class Container(Enum):
zip = auto()
tar = auto()
tgz = auto()
tbz2 = auto()
txz = auto()
none = auto()
@classmethod
def extractor_for_container_fmt(cls, fobj: IO[bytes], container_fmt: 'Container') -> Union['ZipExtractor', 'TarExtractor']:
if container_fmt is Container.tar:
return TarExtractor(fobj, 'r|')
if container_fmt is Container.tgz:
return TarExtractor(fobj, 'r|gz')
if container_fmt is Container.tbz2:
return TarExtractor(fobj, 'r|bz2')
if container_fmt is Container.txz:
return TarExtractor(fobj, 'r|xz')
if container_fmt is Container.zip:
return ZipExtractor(fobj)
raise KeyError(f'Unknown container format: {container_fmt}')
class Compression(Enum):
zlib = auto()
none = auto()
class FileType(Enum):
regular = auto()
directory = auto()
symlink = auto()
link = auto()
class TransmisstionType(Enum):
simple = auto()
resume = auto()
rsync = auto()
class FileTransmissionCommand:
action = Action.invalid
container_fmt = Container.none
compression = Compression.none
ftype = FileType.regular
ttype = TransmisstionType.simple
id: str = ''
file_id: str = ''
secret: str = ''
mime: str = ''
quiet: int = 0
dest: str = ''
name: str = ''
mtime: int = -1
permissions: int = -1
data: bytes = b''
def serialize(self) -> str:
ans = [f'action={self.action.name}']
if self.container_fmt is not Container.none:
ans.append(f'container_fmt={self.container_fmt.name}')
if self.compression is not Compression.none:
ans.append(f'compression={self.compression.name}')
for x in ('id', 'secret', 'mime', 'quiet'):
if self.ftype is not FileType.regular:
ans.append(f'ftype={self.ftype.name}')
if self.ttype is not TransmisstionType.simple:
ans.append(f'ttype={self.ttype.name}')
for x in ('id', 'file_id', 'secret', 'mime', 'quiet'):
val = getattr(self, x)
if val:
ans.append(f'{x}={val}')
if self.dest:
val = standard_b64encode(self.dest.encode('utf-8')).decode('ascii')
ans.append(f'dest={val}')
for k in ('mtime', 'permissions'):
val = getattr(self, k)
if val >= 0:
ans.append(f'{k}={val}')
if self.name:
val = standard_b64encode(self.name.encode('utf-8')).decode('ascii')
ans.append(f'name={val}')
if self.data:
val = standard_b64encode(self.data).decode('ascii')
ans.append(f'data={val}')
return ';'.join(ans)
def escape_semicolons(x: str) -> str:
return x.replace(';', ';;')
return ';'.join(map(escape_semicolons, ans))
def parse_command(data: str) -> FileTransmissionCommand:
@ -93,22 +98,27 @@ def parse_command(data: str) -> FileTransmissionCommand:
parts = data.replace(';;', '\0').split(';')
for i, x in enumerate(parts):
k, v = x.partition('=')[::2]
v = v.replace('\0', ';')
k, v = x.replace('\0', ';').partition('=')[::2]
if k == 'action':
ans.action = Action[v]
elif k == 'container_fmt':
ans.container_fmt = Container[v]
elif k == 'compression':
ans.compression = Compression[v]
elif k in ('secret', 'mime', 'id'):
setattr(ans, k, v)
elif k == 'ftype':
ans.ftype = FileType[v]
elif k == 'ttype':
ans.ttype = TransmisstionType[v]
elif k in ('secret', 'mime', 'id', 'file_id'):
setattr(ans, k, sanitize_control_codes(v))
elif k in ('quiet',):
setattr(ans, k, int(v))
elif k in ('dest', 'data'):
elif k in ('mtime', 'permissions'):
mt = int(v)
if mt >= 0:
setattr(ans, k, mt)
elif k in ('name', 'data'):
val = standard_b64decode(v)
if k == 'dest':
ans.dest = sanitize_control_codes(val.decode('utf-8'))
if k == 'name':
ans.name = sanitize_control_codes(val.decode('utf-8'))
else:
ans.data = val
@ -137,172 +147,149 @@ class ZlibDecompressor:
return ans
def resolve_name(name: str, base: str) -> Optional[str]:
if name.startswith('/') or os.path.isabs(name):
return None
base = os.path.abspath(base)
q = os.path.abspath(os.path.join(base, name))
return q if q.startswith(base) else None
class TarExtractor:
def __init__(self, fobj: IO[bytes], mode: str):
import tarfile
self.tf = tarfile.open(mode=mode, fileobj=fobj)
def __call__(self, dest: str) -> None:
directories = []
for tinfo in self.tf:
targetpath = resolve_name(tinfo.name, dest)
if targetpath is None:
continue
upperdirs = os.path.dirname(targetpath)
os.makedirs(upperdirs, exist_ok=True)
if tinfo.isdir():
self.tf.makedir(tinfo, targetpath)
directories.append((targetpath, copy.copy(tinfo)))
continue
if tinfo.islnk():
tinfo._link_target = os.path.join(upperdirs, tinfo.linkname) # type: ignore
if tinfo.isfile():
self.tf.makefile(tinfo, targetpath)
elif tinfo.isfifo():
self.tf.makefifo(tinfo, targetpath)
elif tinfo.ischr() or tinfo.isblk():
self.tf.makedev(tinfo, targetpath)
elif tinfo.islnk() or tinfo.issym():
self.tf.makelink(tinfo, targetpath)
else:
continue
if not tinfo.issym():
self.tf.chmod(tinfo, targetpath)
self.tf.utime(tinfo, targetpath)
directories.sort(reverse=True, key=lambda x: x[0])
for targetpath, tinfo in directories:
self.tf.chmod(tinfo, targetpath)
self.tf.utime(tinfo, targetpath)
class ZipExtractor:
def __init__(self, fobj: IO[bytes]):
import zipfile
self.zf = zipfile.ZipFile(fobj)
def __call__(self, dest: str) -> None:
for zinfo in self.zf.infolist():
targetpath = resolve_name(zinfo.filename, dest)
if targetpath is not None:
self.zf.extract(zinfo, dest)
class ActiveCommand:
ftc: FileTransmissionCommand
file: Optional[IO[bytes]] = None
dest: str = ''
decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
class DestFile:
def __init__(self, ftc: FileTransmissionCommand) -> None:
self.ftc = ftc
self.name = ftc.name
self.mtime = ftc.mtime
self.permissions = ftc.permissions
self.ftype = ftc.ftype
self.ttype = ftc.ttype
self.needs_data_sent = self.ttype is not TransmisstionType.simple
self.decompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
def close(self) -> None:
if self.file is not None:
self.file.close()
self.file = None
pass
class ActiveReceive:
id: str
files: Dict[str, DestFile]
accepted: bool = False
def __init__(self, id: str) -> None:
self.id = id
self.files = {}
self.last_activity_at = monotonic()
@property
def is_expired(self) -> bool:
return monotonic() - self.last_activity_at > (60 * EXPIRE_TIME)
def close(self) -> None:
for x in self.files.values():
x.close()
self.files = {}
def cancel(self) -> None:
needs_delete = self.file is not None and self.file.name and self.ftc.container_fmt is Container.none
if needs_delete:
fname = getattr(self.file, 'name')
self.close()
if needs_delete:
with suppress(FileNotFoundError):
os.unlink(fname)
def start_file(self, ftc: FileTransmissionCommand) -> DestFile:
if ftc.file_id in self.files:
raise KeyError(f'The file_id {ftc.file_id} already exists')
self.files[ftc.file_id] = result = DestFile(ftc)
return result
class FileTransmission:
active_cmds: Dict[str, ActiveCommand]
active_receives: Dict[str, ActiveReceive]
def __init__(self, window_id: int):
self.window_id = window_id
self.active_cmds = {}
self.active_receives = {}
def __del__(self) -> None:
for cmd in self.active_cmds.values():
cmd.close()
self.active_cmds = {}
for ar in self.active_receives.values():
ar.close()
self.active_receives = {}
def drop_receive(self, receive_id: str) -> None:
ar = self.active_receives.pop(receive_id, None)
if ar is not None:
ar.close()
def prune_expired(self) -> None:
for k in tuple(self.active_receives):
if self.active_receives[k].is_expired:
self.drop_receive(k)
def handle_serialized_command(self, data: str) -> None:
self.prune_expired()
try:
cmd = parse_command(data)
except Exception as e:
log_error(f'Failed to parse file transmission command with error: {e}')
return
if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data, Action.cancel):
log_error('File transmission command received while another is in flight, aborting')
self.active_cmds[cmd.id].close()
del self.active_cmds[cmd.id]
if cmd.id in self.active_receives or cmd.action is Action.send:
self.handle_receive_cmd(cmd)
if cmd.action is Action.send:
self.active_cmds[cmd.id] = ActiveCommand(cmd)
self.start_send(cmd)
elif cmd.action is Action.cancel:
ac = self.active_cmds.pop(cmd.id, None)
if ac is not None:
ac.cancel()
elif cmd.action in (Action.data, Action.end_data):
if cmd.id not in self.active_cmds:
log_error('File transmission data command received with unknown id')
def handle_receive_cmd(self, cmd: FileTransmissionCommand) -> None:
if cmd.id in self.active_receives:
if cmd.action is Action.send:
log_error('File transmission send received for already active id, aborting')
self.drop_receive(cmd.id)
return
ar = self.active_receives[cmd.id]
if not ar.accepted:
log_error(f'File transmission command received for rejected id: {cmd.id}, aborting')
self.drop_receive(cmd.id)
return
ar.last_activity_at = monotonic()
else:
if cmd.action is not Action.send:
log_error(f'File transmission command received for unknown or rejected id: {cmd.id}, ignoring')
return
ar = ActiveReceive(cmd.id)
self.start_receive(ar.id)
return
if cmd.action is Action.cancel:
self.drop_receive(ar.id)
elif cmd.action is Action.file:
ar.start_file(cmd)
elif cmd.action in (Action.data, Action.end_data):
try:
self.add_data(cmd)
self.add_data(ar, cmd)
except Exception:
self.abort_in_flight(cmd.id)
self.drop_receive(ar.id)
raise
if cmd.action is Action.end_data and cmd.id in self.active_cmds:
try:
self.commit(cmd.id)
except Exception:
self.abort_in_flight(cmd.id)
self.drop_receive(cmd.id)
def send_response(self, ac: Optional[FileTransmissionCommand], **fields: str) -> None:
if ac is None:
return
if 'id' not in fields and ac.id:
fields['id'] = ac.id
self.write_response_to_child(fields)
def send_response(self, id: str = '', **fields: str) -> bool:
if 'id' not in fields and id:
fields['id'] = id
return self.write_response_to_child(fields)
def write_response_to_child(self, fields: Dict[str, str]) -> None:
def write_response_to_child(self, fields: Dict[str, str]) -> bool:
boss = get_boss()
window = boss.window_id_map.get(self.window_id)
if window is not None:
window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
return window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
return False
def start_send(self, cmd: FileTransmissionCommand) -> None:
def start_receive(self, ar_id: str) -> None:
boss = get_boss()
window = boss.window_id_map.get(self.window_id)
if window is not None:
boss._run_kitten(
'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest],
window=window, custom_callback=partial(self.handle_send_confirmation, cmd.id),
boss._run_kitten('ask', ['--type=yesno', '--message', _(
'The remote machine wants to send some files to this computer. Do you want to allow the transfer?'
)],
window=window, custom_callback=partial(self.handle_send_confirmation, ar_id),
)
def handle_send_confirmation(self, cmd_id: str, data: 'Response', *a: Any) -> None:
cmd = self.active_cmds.get(cmd_id)
if cmd is None:
def handle_send_confirmation(self, cmd_id: str, data: Dict[str, str], *a: Any) -> None:
ar = self.active_receives.get(cmd_id)
if ar is None:
return
if data['allowed']:
cmd.dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
cmd.decompressor = ZlibDecompressor() if cmd.ftc.compression is Compression.zlib else IdentityDecompressor()
if cmd.ftc.quiet:
return
if data['response'] == 'y':
ar.accepted = True
else:
cmd.close()
del self.active_cmds[cmd_id]
if cmd.ftc.quiet > 1:
return
self.drop_receive(ar.id)
self.send_response(cmd.ftc, status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
def send_fail_on_os_error(self, ac: Optional[FileTransmissionCommand], err: OSError, msg: str) -> None:
@ -372,14 +359,14 @@ class FileTransmission:
class TestFileTransmission(FileTransmission):
def __init__(self, dest: str = '') -> None:
def __init__(self, allow: bool = True) -> None:
super().__init__(0)
self.test_responses: List[Dict[str, str]] = []
self.test_dest = dest
self.allow = allow
def write_response_to_child(self, fields: Dict[str, str]) -> None:
def write_response_to_child(self, fields: Dict[str, str]) -> bool:
self.test_responses.append(fields)
return True
def start_send(self, cmd: FileTransmissionCommand) -> None:
dest = cmd.dest or self.test_dest
self.handle_send_confirmation(cmd.id, {'dest': dest, 'allowed': bool(dest)})
def start_receive(self, aid: str) -> None:
self.handle_send_confirmation(aid, {'response': 'y' if self.allow else 'm'})