From e6cff61f99eaaad3f1cb225d69430da901405abe Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 1 Oct 2021 14:29:14 +0530 Subject: [PATCH] Move management of destination file completely into PatchFile --- kittens/transfer/librsync.py | 26 +++++++++++++++++++++++--- kitty_tests/file_transmission.py | 11 +++-------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/kittens/transfer/librsync.py b/kittens/transfer/librsync.py index fdf2a1979..6777b285d 100644 --- a/kittens/transfer/librsync.py +++ b/kittens/transfer/librsync.py @@ -3,11 +3,12 @@ # License: GPLv3 Copyright: 2021, Kovid Goyal import os +import tempfile from typing import IO, TYPE_CHECKING, Iterator from .rsync import ( IO_BUFFER_SIZE, RsyncError, begin_create_delta, begin_create_signature, - begin_load_signature, begin_patch, iter_job, build_hash_table + begin_load_signature, begin_patch, build_hash_table, iter_job ) if TYPE_CHECKING: @@ -76,6 +77,7 @@ class LoadSignature(StreamingJob): # see whole.c in librsync source for size calculations expected_input_size = 16 * 1024 + autocommit = True def __init__(self) -> None: job, self.signature = begin_load_signature() @@ -97,8 +99,13 @@ class PatchFile(StreamingJob): # see whole.c in librsync source for size calculations expected_input_size = IO_BUFFER_SIZE - def __init__(self, src_path: str): + def __init__(self, src_path: str, output_path: str = ''): + self.overwrite_src = not output_path self.src_file = open(src_path, 'rb') + if self.overwrite_src: + self.dest_file = tempfile.NamedTemporaryFile(mode='wb', dir=os.path.dirname(os.path.abspath(os.path.realpath(src_path))), delete=False) + else: + self.dest_file = open(output_path, 'wb') job = begin_patch(self.read_from_src) super().__init__(job, output_buf_size=4 * IO_BUFFER_SIZE) @@ -109,7 +116,20 @@ class PatchFile(StreamingJob): def close(self) -> None: if not self.src_file.closed: self.src_file.close() - commit = close + count = 100 + while not self.finished: + self() + count -= 1 + if count == 0: + raise Exception('Patching file did not receive enough input') + self.dest_file.close() + if self.overwrite_src: + os.replace(self.dest_file.name, self.src_file.name) + + def write(self, data: bytes) -> None: + output = self(data) + if output: + self.dest_file.write(output) def __enter__(self) -> 'PatchFile': return self diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py index 0857711eb..11b1bb6d5 100644 --- a/kitty_tests/file_transmission.py +++ b/kitty_tests/file_transmission.py @@ -101,16 +101,11 @@ class TestFileTransmission(BaseTest): sig_loader(chunk) sig_loader() self.assertTrue(sig_loader.finished) - with open(c_path, 'wb') as dest, PatchFile(a_path) as patcher: + with PatchFile(a_path, c_path) as patcher: for chunk in delta_for_file(b_path, sig_loader.signature): self.assertFalse(patcher.finished) - output = patcher(chunk) - if output: - dest.write(output) - while not patcher.finished: - output = patcher() - if output: - dest.write(output) + patcher.write(chunk) + self.assertTrue(patcher.finished) with open(b_path, 'rb') as b, open(c_path, 'rb') as c: while True: bc = b.read(4096)