Allow copying the same file to multiple locations

This commit is contained in:
Kovid Goyal 2022-02-28 13:09:19 +05:30
parent b4cc38a1d9
commit 7d653cb7bf
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 16 additions and 9 deletions

View File

@ -5,13 +5,15 @@
import glob import glob
import os import os
import shlex import shlex
import uuid
from typing import ( from typing import (
Iterable, Iterator, List, NamedTuple, Optional, Sequence, Tuple Dict, Iterable, Iterator, List, NamedTuple, Optional, Sequence, Tuple
) )
from kitty.cli import parse_args from kitty.cli import parse_args
from kitty.cli_stub import CopyCLIOptions from kitty.cli_stub import CopyCLIOptions
from kitty.types import run_once from kitty.types import run_once
from ..transfer.utils import expand_home, home_path from ..transfer.utils import expand_home, home_path
@ -81,11 +83,12 @@ def get_arcname(loc: str, dest: Optional[str], home: str) -> str:
class CopyInstruction(NamedTuple): class CopyInstruction(NamedTuple):
local_path: str
arcname: str arcname: str
exclude_patterns: Tuple[str, ...] exclude_patterns: Tuple[str, ...]
def parse_copy_instructions(val: str) -> Iterable[Tuple[str, CopyInstruction]]: def parse_copy_instructions(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]:
opts, args = parse_copy_args(shlex.split(val)) opts, args = parse_copy_args(shlex.split(val))
locations: List[str] = [] locations: List[str] = []
for a in args: for a in args:
@ -97,4 +100,4 @@ def parse_copy_instructions(val: str) -> Iterable[Tuple[str, CopyInstruction]]:
home = home_path() home = home_path()
for loc in locations: for loc in locations:
arcname = get_arcname(loc, opts.dest, home) arcname = get_arcname(loc, opts.dest, home)
yield loc, CopyInstruction(arcname, tuple(opts.exclude)) yield str(uuid.uuid4()), CopyInstruction(loc, arcname, tuple(opts.exclude))

View File

@ -98,8 +98,8 @@ def make_tarfile(ssh_opts: SSHOptions, base_env: Dict[str, str]) -> bytes:
buf = io.BytesIO() buf = io.BytesIO()
with tarfile.open(mode='w:bz2', fileobj=buf, encoding='utf-8') as tf: with tarfile.open(mode='w:bz2', fileobj=buf, encoding='utf-8') as tf:
rd = ssh_opts.remote_dir.rstrip('/') rd = ssh_opts.remote_dir.rstrip('/')
for location, ci in ssh_opts.copy.items(): for ci in ssh_opts.copy.values():
tf.add(location, arcname=ci.arcname, filter=filter_from_globs(*ci.exclude_patterns)) tf.add(ci.local_path, arcname=ci.arcname, filter=filter_from_globs(*ci.exclude_patterns))
add_data_as_file(tf, 'data.sh', env_script) add_data_as_file(tf, 'data.sh', env_script)
if ksi: if ksi:
arcname = 'home/' + rd + '/shell-integration' arcname = 'home/' + rd + '/shell-integration'

View File

@ -32,7 +32,7 @@ def env(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, str]]:
def copy(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]: def copy(val: str, current_val: Dict[str, str]) -> Iterable[Tuple[str, CopyInstruction]]:
yield from parse_copy_instructions(val) yield from parse_copy_instructions(val, current_val)
def init_results_dict(ans: Dict[str, Any]) -> Dict[str, Any]: def init_results_dict(ans: Dict[str, Any]) -> Dict[str, Any]:

View File

@ -88,12 +88,14 @@ print(' '.join(map(str, buf)))'''), lines=13, cols=77)
tuple(map(touch, 'simple-file g.1 g.2'.split())) tuple(map(touch, 'simple-file g.1 g.2'.split()))
os.makedirs(f'{local_home}/d1/d2/d3') os.makedirs(f'{local_home}/d1/d2/d3')
touch('d1/d2/x') touch('d1/d2/x')
touch('d1/d2/w.exclude')
os.symlink('d2/x', f'{local_home}/d1/y') os.symlink('d2/x', f'{local_home}/d1/y')
conf = '''\ conf = '''\
copy simple-file copy simple-file
copy --dest=a/sfa simple-file
copy --glob g.* copy --glob g.*
copy d1 copy --exclude */w.* d1
''' '''
copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy copy = load_config(overrides=filter(None, conf.splitlines()))['*'].copy
self.check_bootstrap( self.check_bootstrap(
@ -103,8 +105,9 @@ copy d1
self.assertTrue(os.path.lexists(f'{remote_home}/.terminfo/78')) self.assertTrue(os.path.lexists(f'{remote_home}/.terminfo/78'))
self.assertTrue(os.path.exists(f'{remote_home}/.terminfo/78/xterm-kitty')) self.assertTrue(os.path.exists(f'{remote_home}/.terminfo/78/xterm-kitty'))
self.assertTrue(os.path.exists(f'{remote_home}/.terminfo/x/xterm-kitty')) self.assertTrue(os.path.exists(f'{remote_home}/.terminfo/x/xterm-kitty'))
with open(os.path.join(remote_home, 'simple-file'), 'r') as f: for w in ('simple-file', 'a/sfa'):
self.ae(f.read(), simple_data) with open(os.path.join(remote_home, w), 'r') as f:
self.ae(f.read(), simple_data)
self.assertTrue(os.path.lexists(f'{remote_home}/d1/y')) self.assertTrue(os.path.lexists(f'{remote_home}/d1/y'))
self.assertTrue(os.path.exists(f'{remote_home}/d1/y')) self.assertTrue(os.path.exists(f'{remote_home}/d1/y'))
self.ae(os.readlink(f'{remote_home}/d1/y'), 'd2/x') self.ae(os.readlink(f'{remote_home}/d1/y'), 'd2/x')
@ -112,6 +115,7 @@ copy d1
contents.discard('.zshrc') # added by check_bootstrap() contents.discard('.zshrc') # added by check_bootstrap()
self.ae(contents, { self.ae(contents, {
'g.1', 'g.2', '.terminfo/kitty.terminfo', 'simple-file', '.terminfo/x/xterm-kitty', 'd1/d2/x', 'd1/y', 'g.1', 'g.2', '.terminfo/kitty.terminfo', 'simple-file', '.terminfo/x/xterm-kitty', 'd1/d2/x', 'd1/y',
'a/sfa'
}) })
def test_ssh_env_vars(self): def test_ssh_env_vars(self):