create fid_map after collecting files not before

This commit is contained in:
Kovid Goyal 2021-11-06 12:54:20 +05:30
parent 46f88494e3
commit 8099ae44d7
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -173,7 +173,6 @@ class Manager:
self.transfer_done = False self.transfer_done = False
def start_transfer(self) -> Iterator[str]: def start_transfer(self) -> Iterator[str]:
self.fid_map = {f.file_id: f for f in self.files}
yield FileTransmissionCommand(action=Action.receive, bypass=self.bypass, size=len(self.spec)).serialize() yield FileTransmissionCommand(action=Action.receive, bypass=self.bypass, size=len(self.spec)).serialize()
for i, x in enumerate(self.spec): for i, x in enumerate(self.spec):
yield FileTransmissionCommand(action=Action.file, file_id=str(i), name=x).serialize() yield FileTransmissionCommand(action=Action.file, file_id=str(i), name=x).serialize()
@ -223,6 +222,7 @@ class Manager:
def collect_files(self, cli_opts: TransferCLIOptions) -> None: def collect_files(self, cli_opts: TransferCLIOptions) -> None:
self.files = list(files_for_receive(cli_opts, self.dest, self.files, self.remote_home, self.spec)) self.files = list(files_for_receive(cli_opts, self.dest, self.files, self.remote_home, self.spec))
self.total_transfer_size = sum(max(0, f.expected_size) for f in self.files) self.total_transfer_size = sum(max(0, f.expected_size) for f in self.files)
self.fid_map = {f.file_id: f for f in self.files}
def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> str: def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> str:
if self.state is State.waiting_for_permission: if self.state is State.waiting_for_permission:
@ -332,6 +332,11 @@ class Receive(Handler):
self.confirm_paths() self.confirm_paths()
else: else:
self.start_transfer() self.start_transfer()
if self.manager.transfer_done:
self.exit_after_completion()
def exit_after_completion(self) -> None:
self.quit_loop(0)
def confirm_paths(self) -> None: def confirm_paths(self) -> None:
self.print_check_paths() self.print_check_paths()