More work on file transmission
This commit is contained in:
parent
aa0b344b55
commit
42dcecde14
@ -2,11 +2,20 @@
|
||||
# vim:fileencoding=utf-8
|
||||
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
import copy
|
||||
import errno
|
||||
import os
|
||||
import tempfile
|
||||
from base64 import standard_b64decode
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from .utils import log_error
|
||||
from kitty.fast_data_types import OSC, get_boss
|
||||
|
||||
from .utils import log_error, sanitize_control_codes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kittens.transfer_ask.main import Response
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
@ -25,50 +34,62 @@ class Container(Enum):
|
||||
txz = auto()
|
||||
none = auto()
|
||||
|
||||
@classmethod
|
||||
def extractor_for_container_fmt(cls, fobj: IO[bytes], container_fmt: 'Container') -> Union['ZipExtractor', 'TarExtractor']:
|
||||
if container_fmt is Container.tar:
|
||||
return TarExtractor(fobj, 'r|')
|
||||
if container_fmt is Container.tgz:
|
||||
return TarExtractor(fobj, 'r|gz')
|
||||
if container_fmt is Container.tbz2:
|
||||
return TarExtractor(fobj, 'r|bz2')
|
||||
if container_fmt is Container.txz:
|
||||
return TarExtractor(fobj, 'r|xz')
|
||||
if container_fmt is Container.zip:
|
||||
return ZipExtractor(fobj)
|
||||
raise KeyError(f'Unknown container format: {container_fmt}')
|
||||
|
||||
|
||||
class Compression(Enum):
|
||||
zlib = auto()
|
||||
none = auto()
|
||||
|
||||
|
||||
class Encoding(Enum):
|
||||
base64 = auto()
|
||||
|
||||
|
||||
class FileTransmissionCommand:
|
||||
|
||||
action = Action.invalid
|
||||
container_fmt = Container.none
|
||||
compression = Compression.none
|
||||
encoding = Encoding.base64
|
||||
id: str = ''
|
||||
secret: str = ''
|
||||
mime: str = ''
|
||||
|
||||
payload = b''
|
||||
quiet: int = 0
|
||||
dest: str = ''
|
||||
data: bytes = b''
|
||||
|
||||
|
||||
def parse_command(data: str) -> FileTransmissionCommand:
|
||||
parts = data.split(':', 1)
|
||||
ans = FileTransmissionCommand()
|
||||
if len(parts) == 1:
|
||||
control, payload = parts[0], ''
|
||||
else:
|
||||
control, payload = parts
|
||||
ans.payload = standard_b64decode(payload)
|
||||
parts = data.replace(';;', '\0').split(';')
|
||||
|
||||
for x in control.split(','):
|
||||
for i, x in enumerate(parts):
|
||||
k, v = x.partition('=')[::2]
|
||||
v = v.replace('\0', ';')
|
||||
if k == 'action':
|
||||
ans.action = Action[v]
|
||||
elif k == 'container_fmt':
|
||||
ans.container_fmt = Container[v]
|
||||
elif k == 'compression':
|
||||
ans.compression = Compression[v]
|
||||
elif k == 'encoding':
|
||||
ans.encoding = Encoding[v]
|
||||
elif k in ('secret', 'mime', 'id'):
|
||||
setattr(ans, k, v)
|
||||
elif k in ('quiet',):
|
||||
setattr(ans, k, int(v))
|
||||
elif k in ('dest', 'data'):
|
||||
val = standard_b64decode(v)
|
||||
if k == 'dest':
|
||||
ans.dest = sanitize_control_codes(val.decode('utf-8'))
|
||||
else:
|
||||
ans.data = val
|
||||
|
||||
if ans.action is Action.invalid:
|
||||
raise ValueError('No valid action specified in file transmission command')
|
||||
@ -76,9 +97,88 @@ def parse_command(data: str) -> FileTransmissionCommand:
|
||||
return ans
|
||||
|
||||
|
||||
class IdentityDecompressor:
|
||||
|
||||
def __call__(self, data: bytes, is_last: bool = False) -> bytes:
|
||||
return data
|
||||
|
||||
|
||||
class ZlibDecompressor:
|
||||
|
||||
def __init__(self) -> None:
|
||||
import zlib
|
||||
self.d = zlib.decompressobj(wbits=0)
|
||||
|
||||
def __call__(self, data: bytes, is_last: bool = False) -> bytes:
|
||||
ans = self.d.decompress(data)
|
||||
if is_last:
|
||||
ans += self.d.flush()
|
||||
return ans
|
||||
|
||||
|
||||
def resolve_name(name: str, base: str) -> Optional[str]:
|
||||
if name.startswith('/'):
|
||||
return None
|
||||
base = os.path.abspath(base)
|
||||
q = os.path.abspath(os.path.join(base, name))
|
||||
return q if q.startswith(base) else None
|
||||
|
||||
|
||||
class TarExtractor:
|
||||
|
||||
def __init__(self, fobj: IO[bytes], mode: str):
|
||||
import tarfile
|
||||
self.tf = tarfile.open(mode=mode, fileobj=fobj)
|
||||
|
||||
def __call__(self, dest: str) -> None:
|
||||
directories = []
|
||||
for tinfo in self.tf:
|
||||
targetpath = resolve_name(tinfo.name, dest)
|
||||
if targetpath is None:
|
||||
continue
|
||||
if tinfo.isdir():
|
||||
self.tf.makedir(tinfo, targetpath)
|
||||
directories.append((targetpath, copy.copy(tinfo)))
|
||||
continue
|
||||
if tinfo.isfile():
|
||||
self.tf.makefile(tinfo, targetpath)
|
||||
elif tinfo.isfifo():
|
||||
self.tf.makefifo(tinfo, targetpath)
|
||||
elif tinfo.ischr() or tinfo.isblk():
|
||||
self.tf.makedev(tinfo, targetpath)
|
||||
elif tinfo.islnk() or tinfo.issym():
|
||||
self.tf.makelink(tinfo, targetpath)
|
||||
else:
|
||||
continue
|
||||
if not tinfo.issym():
|
||||
self.tf.chmod(tinfo, targetpath)
|
||||
self.tf.utime(tinfo, targetpath)
|
||||
directories.sort(reverse=True, key=lambda x: x[0])
|
||||
for targetpath, tinfo in directories:
|
||||
self.tf.chmod(tinfo, targetpath)
|
||||
self.tf.utime(tinfo, targetpath)
|
||||
|
||||
|
||||
class ZipExtractor:
|
||||
|
||||
def __init__(self, fobj: IO[bytes]):
|
||||
import zipfile
|
||||
self.zf = zipfile.ZipFile(fobj)
|
||||
|
||||
def __call__(self, dest: str) -> None:
|
||||
for zinfo in self.zf.infolist():
|
||||
targetpath = resolve_name(zinfo.filename, dest)
|
||||
if targetpath is None:
|
||||
continue
|
||||
self.zf.extract(zinfo, targetpath)
|
||||
|
||||
|
||||
class FileTransmission:
|
||||
|
||||
active_cmd: Optional[FileTransmissionCommand] = None
|
||||
active_file: Optional[IO[bytes]] = None
|
||||
active_dest: str = ''
|
||||
active_decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
|
||||
|
||||
def __init__(self, window_id: int):
|
||||
self.window_id = window_id
|
||||
@ -95,9 +195,120 @@ class FileTransmission:
|
||||
self.abort_in_flight()
|
||||
if cmd.action is Action.send:
|
||||
self.start_send(cmd)
|
||||
elif cmd.action in (Action.data, Action.end_data):
|
||||
self.add_data(cmd)
|
||||
if cmd.action is Action.end_data and self.active_cmd is not None:
|
||||
self.commit()
|
||||
|
||||
def send_response(self, **fields: str) -> None:
|
||||
ac = self.active_cmd
|
||||
if ac is None:
|
||||
return
|
||||
if 'id' not in fields and ac.id:
|
||||
fields['id'] = ac.id
|
||||
self.write_response_to_child(fields)
|
||||
|
||||
def write_response_to_child(self, fields: Dict[str, str]) -> None:
|
||||
boss = get_boss()
|
||||
window = boss.window_id_map.get(self.window_id)
|
||||
if window is not None:
|
||||
window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
|
||||
|
||||
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
||||
self.active_cmd = cmd
|
||||
boss = get_boss()
|
||||
window = boss.window_id_map.get(self.window_id)
|
||||
if window is not None:
|
||||
boss._run_kitten(
|
||||
'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest],
|
||||
window=window, custom_callback=self.handle_send_confirmation
|
||||
)
|
||||
|
||||
def handle_send_confirmation(self, data: 'Response', *a: Any) -> None:
|
||||
cmd = self.active_cmd
|
||||
if cmd is None:
|
||||
return
|
||||
if data['allowed']:
|
||||
self.active_dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
|
||||
self.active_decompressor = ZlibDecompressor() if cmd.compression is Compression.zlib else IdentityDecompressor()
|
||||
if cmd.quiet:
|
||||
return
|
||||
else:
|
||||
self.active_cmd = None
|
||||
self.active_dest = ''
|
||||
if cmd.quiet > 1:
|
||||
return
|
||||
self.send_response(status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
|
||||
|
||||
def send_fail_on_os_error(self, err: OSError, msg: str) -> None:
|
||||
ac = self.active_cmd
|
||||
if ac is None or ac.quiet < 2:
|
||||
return
|
||||
errname = errno.errorcode.get(err.errno, 'EFAIL')
|
||||
self.send_response(status=f'{errname}:{msg}')
|
||||
|
||||
def add_data(self, cmd: FileTransmissionCommand) -> None:
|
||||
ac = self.active_cmd
|
||||
if ac is None or not self.active_dest:
|
||||
return
|
||||
if self.active_file is None:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.active_dest), exist_ok=True)
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Creating destination directory failed')
|
||||
return self.abort_in_flight()
|
||||
if ac.container_fmt is Container.none:
|
||||
try:
|
||||
self.active_file = open(self.active_dest, 'wb')
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Creating destination file failed')
|
||||
return self.abort_in_flight()
|
||||
else:
|
||||
try:
|
||||
self.active_file = tempfile.TemporaryFile(dir=os.path.dirname(self.active_dest))
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Creating destination temp file failed')
|
||||
return self.abort_in_flight()
|
||||
data = self.active_decompressor(cmd.data, cmd.action is Action.end_data)
|
||||
try:
|
||||
self.active_file.write(data)
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Writing to destination file failed')
|
||||
return self.abort_in_flight()
|
||||
|
||||
def commit(self) -> None:
|
||||
cmd = self.active_cmd
|
||||
if cmd is None:
|
||||
return
|
||||
try:
|
||||
if cmd.container_fmt and self.active_file is not None:
|
||||
self.active_file.seek(0, os.SEEK_SET)
|
||||
Container.extractor_for_container_fmt(self.active_file, cmd.container_fmt)(self.active_dest)
|
||||
finally:
|
||||
self.active_cmd = None
|
||||
self.active_dest = ''
|
||||
if self.active_file is not None:
|
||||
self.active_file.close()
|
||||
self.active_file = None
|
||||
|
||||
def abort_in_flight(self) -> None:
|
||||
self.active_cmd = None
|
||||
self.active_dest = ''
|
||||
if self.active_file is not None:
|
||||
self.active_file.close()
|
||||
self.active_file = None
|
||||
|
||||
|
||||
class TestFileTransmission(FileTransmission):
|
||||
|
||||
def __init__(self, dest: str = '') -> None:
|
||||
super().__init__(0)
|
||||
self.test_responses: List[Dict[str, str]] = []
|
||||
self.test_dest = dest
|
||||
|
||||
def write_response_to_child(self, fields: Dict[str, str]) -> None:
|
||||
self.test_responses.append(fields)
|
||||
|
||||
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
||||
self.active_cmd = cmd
|
||||
self.handle_send_confirmation({'dest': self.test_dest, 'allowed': bool(self.test_dest)})
|
||||
|
||||
@ -15,7 +15,7 @@ from functools import lru_cache
|
||||
from time import monotonic
|
||||
from typing import (
|
||||
TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Mapping,
|
||||
Match, NamedTuple, Optional, Tuple, Union, cast
|
||||
Match, NamedTuple, Optional, Pattern, Tuple, Union, cast
|
||||
)
|
||||
|
||||
from .constants import (
|
||||
@ -27,8 +27,8 @@ from .types import run_once
|
||||
from .typing import AddressFamily, PopenType, Socket, StartupCtx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .options.types import Options
|
||||
from .fast_data_types import OSWindowSize
|
||||
from .options.types import Options
|
||||
else:
|
||||
Options = object
|
||||
|
||||
@ -748,6 +748,7 @@ def is_kitty_gui_cmdline(*cmd: str) -> bool:
|
||||
|
||||
def reload_conf_in_all_kitties() -> None:
|
||||
import signal
|
||||
|
||||
from kitty.child import cmdline_of_process # type: ignore
|
||||
for pid in get_all_processes():
|
||||
try:
|
||||
@ -756,3 +757,12 @@ def reload_conf_in_all_kitties() -> None:
|
||||
continue
|
||||
if cmd and is_kitty_gui_cmdline(*cmd):
|
||||
os.kill(pid, signal.SIGUSR1)
|
||||
|
||||
|
||||
@run_once
|
||||
def control_codes_pat() -> Pattern:
|
||||
return re.compile('[\x00-\x09\x0b-\x1f\x7f\x80-\x9f]')
|
||||
|
||||
|
||||
def sanitize_control_codes(text: str, replace_with: str = '') -> str:
|
||||
return cast(str, control_codes_pat().sub(replace_with, text))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user