librsync actually blocks on output buffer space as well as on input availability, so handle that

Kovid Goyal 2021-10-02 09:26:35 +05:30
parent f85f39e662
commit 5729e33412
4 changed files with 80 additions and 67 deletions
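
The gist of the change: rs_job_iter() reports RS_BLOCKED both when it is starved of input and when its output buffer is full, so a caller that feeds input exactly once per call can stall even though the job still has output to hand over. The fix is to loop on the Python side, draining output, until the job neither consumes input nor produces output. Below is a minimal sketch of that pattern (not code from this commit), assuming an iter_job(job, input_data, output_buf) wrapper like the one in this commit's C extension, returning (finished, unused_input, output_size):

    from typing import Any, Callable, Iterator, Tuple

    def drive(
        iter_job: Callable[[Any, bytes, bytearray], Tuple[bool, int, int]],
        job: Any, input_data: bytes, output_buf: bytearray,
    ) -> Iterator[memoryview]:
        while True:
            finished, unused_input, output_size = iter_job(job, input_data, output_buf)
            if output_size:
                # Hand the filled part of the buffer to the caller before
                # the next iteration overwrites it.
                yield memoryview(output_buf)[:output_size]
            if finished:
                break
            if unused_input >= len(input_data) and not output_size:
                break  # stalled: the job needs more input than we have
            # Keep only the unconsumed tail of the input for the next call.
            input_data = bytes(input_data[len(input_data) - unused_input:])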


@@ -4,7 +4,7 @@
 import os
 import tempfile
-from typing import IO, TYPE_CHECKING, Iterator
+from typing import IO, TYPE_CHECKING, Iterator, Union

 from .rsync import (
     IO_BUFFER_SIZE, RsyncError, begin_create_delta, begin_create_signature,
@@ -22,34 +22,41 @@ class StreamingJob:
     def __init__(self, job: 'JobCapsule', output_buf_size: int = IO_BUFFER_SIZE):
         self.job = job
         self.finished = False
-        self.prev_unused_input = b''
         self.calls_with_no_data = 0
         self.output_buf = bytearray(output_buf_size)
+        self.uncomsumed_data = b''

-    def __call__(self, input_data: bytes = b'') -> memoryview:
+    def __call__(self, input_data: Union[memoryview, bytes] = b'') -> Iterator[memoryview]:
         if self.finished:
             if input_data:
                 raise RsyncError('There was too much input data')
             return memoryview(self.output_buf)[:0]
-        no_more_data = not input_data
-        if self.prev_unused_input:
-            input_data = self.prev_unused_input + input_data
-            self.prev_unused_input = b''
-        self.finished, sz_of_unused_input, output_size = iter_job(self.job, input_data, self.output_buf)
-        if sz_of_unused_input > 0 and not self.finished:
-            if no_more_data:
-                raise RsyncError(f"{sz_of_unused_input} bytes of input data were not used")
-            self.prev_unused_input = bytes(input_data[-sz_of_unused_input:])
-        if self.finished:
-            self.commit()
-        if no_more_data and not output_size:
-            self.calls_with_no_data += 1
-            if self.calls_with_no_data > 3:  # prevent infinite loop
-                raise RsyncError('There was not enough input data')
-        return memoryview(self.output_buf)[:output_size]
+        if self.uncomsumed_data:
+            input_data = self.uncomsumed_data + bytes(input_data)
+            self.uncomsumed_data = b''
+        while True:
+            self.finished, sz_of_unused_input, output_size = iter_job(self.job, input_data, self.output_buf)
+            if output_size:
+                yield memoryview(self.output_buf)[:output_size]
+            if self.finished:
+                break
+            if not sz_of_unused_input and len(input_data):
+                break
+            consumed_some_input = sz_of_unused_input < len(input_data)
+            produced_some_output = output_size > 0
+            if not consumed_some_input and not produced_some_output:
+                break
+            input_data = memoryview(input_data)[-sz_of_unused_input:]
+        if sz_of_unused_input:
+            self.uncomsumed_data = bytes(input_data[-sz_of_unused_input:])

-    def commit(self) -> None:
-        pass
+    def get_remaining_output(self) -> Iterator[memoryview]:
+        if not self.finished:
+            yield from self()
+        if not self.finished:
+            raise RsyncError('Insufficient input data')
+        if self.uncomsumed_data:
+            raise RsyncError(f'{len(self.uncomsumed_data)} bytes of unconsumed input data')


 def drive_job_on_file(f: IO[bytes], job: 'JobCapsule', input_buf_size: int = IO_BUFFER_SIZE, output_buf_size: int = IO_BUFFER_SIZE) -> Iterator[memoryview]:
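
With __call__ now a generator, callers iterate over output views instead of inspecting a single return value. A sketch of the intended consumption pattern; pump() and the chunks iterable are illustrative, not part of this commit:

    def pump(sj: 'StreamingJob', chunks) -> bytes:
        out = bytearray()
        for chunk in chunks:
            # Each call may yield zero or more views into sj.output_buf,
            # so copy them out before the next call reuses the buffer.
            for view in sj(chunk):
                out += view
        # Once input is exhausted, drain whatever the job still buffers;
        # this also validates that the job actually finished.
        for view in sj.get_remaining_output():
            out += view
        return bytes(out)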
@@ -57,9 +64,11 @@ def drive_job_on_file(f: IO[bytes], job: 'JobCapsule', input_buf_size: int = IO_BUFFER_SIZE, output_buf_size: int = IO_BUFFER_SIZE) -> Iterator[memoryview]:
     input_buf = bytearray(input_buf_size)
     while not sj.finished:
         sz = f.readinto(input_buf)  # type: ignore
-        result = sj(memoryview(input_buf)[:sz])
-        if len(result) > 0:
-            yield result
+        if not sz:
+            del input_buf
+            yield from sj.get_remaining_output()
+            break
+        yield from sj(memoryview(input_buf)[:sz])


 def signature_of_file(path: str) -> Iterator[memoryview]:
@@ -83,7 +92,13 @@ class LoadSignature(StreamingJob):
         job, self.signature = begin_load_signature()
         super().__init__(job, output_buf_size=0)

+    def add_chunk(self, chunk: bytes) -> None:
+        for ignored in self(chunk):
+            pass
+
     def commit(self) -> None:
+        for ignored in self.get_remaining_output():
+            pass
         build_hash_table(self.signature)
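
The add_chunk()/commit() pair makes end-of-data explicit for signature loading. Assumed usage, with read_signature_chunks() standing in for any source of signature bytes:

    sig_loader = LoadSignature()
    for chunk in read_signature_chunks():  # hypothetical chunk source
        sig_loader.add_chunk(chunk)        # output is empty for signature loads
    sig_loader.commit()                    # drains the job and builds the hash table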
@@ -115,6 +130,7 @@ class PatchFile(StreamingJob):
     def close(self) -> None:
         if not self.src_file.closed:
+            self.get_remaining_output()
             self.src_file.close()
             count = 100
             while not self.finished:
@@ -127,8 +143,7 @@ class PatchFile(StreamingJob):
             os.replace(self.dest_file.name, self.src_file.name)

     def write(self, data: bytes) -> None:
-        output = self(data)
-        if output:
+        for output in self(data):
             self.dest_file.write(output)

     def __enter__(self) -> 'PatchFile':
@@ -144,12 +159,10 @@ def develop() -> None:
     sig_loader = LoadSignature()
     with open(src + '.sig', 'wb') as f:
         for chunk in signature_of_file(src):
-            sig_loader(chunk)
+            sig_loader.add_chunk(chunk)
             f.write(chunk)
-    sig_loader()
+    sig_loader.commit()
     with open(src + '.delta', 'wb') as f, PatchFile(src, src + '.output') as patcher:
         for chunk in delta_for_file(src, sig_loader.signature):
             f.write(chunk)
             patcher.write(chunk)
-        if not patcher.finished:
-            patcher.write(b'')
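
Note that the trailing patcher.write(b'') flush is gone from develop(): with the generator API, output still buffered inside the job is drained through get_remaining_output() when the patcher is closed, rather than by a sentinel empty write.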


@@ -89,27 +89,16 @@ iter_job(PyObject *self UNUSED, PyObject *args) {
         .avail_in=input_buf.len, .next_in=input_buf.buf, .eof_in=eof,
         .avail_out=output_buf.len, .next_out=output_buf.buf
     };
-    Py_ssize_t output_size = 0;
-    rs_result result = RS_DONE;
-    while (true) {
-        size_t before = buffer.avail_out;
-        result = rs_job_iter(job, &buffer);
-        output_size += before - buffer.avail_out;
-        if (result == RS_DONE || result == RS_BLOCKED) break;
-        if (buffer.avail_in) {
-            PyBuffer_Release(&output_buf);
-            if (PyByteArray_Resize(output_array, MAX(IO_BUFFER_SIZE, (size_t)PyByteArray_GET_SIZE(output_array) * 2)) != 0) return NULL;
-            if (PyObject_GetBuffer(output_array, &output_buf, PyBUF_WRITE) != 0) return NULL;
-            buffer.avail_out = output_buf.len - output_size;
-            buffer.next_out = (char*)output_buf.buf + output_size;
-            continue;
-        }
-        PyErr_SetString(RsyncError, rs_strerror(result));
-        return NULL;
-    }
-    Py_ssize_t unused_input = buffer.avail_in;
-    return Py_BuildValue("Onn", result == RS_DONE ? Py_True : Py_False, unused_input, output_size);
+    size_t before = buffer.avail_out;
+    rs_result result = rs_job_iter(job, &buffer);
+    Py_ssize_t output_size = before - buffer.avail_out;
+    if (result == RS_DONE || result == RS_BLOCKED) {
+        Py_ssize_t unused_input = buffer.avail_in;
+        return Py_BuildValue("Onn", result == RS_DONE ? Py_True : Py_False, unused_input, output_size);
+    }
+    PyErr_SetString(RsyncError, rs_strerror(result));
+    return NULL;
 }

 static PyObject*
 begin_load_signature(PyObject *self UNUSED, PyObject *args UNUSED) {
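
iter_job() now performs exactly one rs_job_iter() call and returns as soon as the result is RS_DONE or RS_BLOCKED, instead of looping and growing the output bytearray inside C. The looping (and all buffer management) moves to Python, which keeps the output buffer a fixed size. A sketch of the return contract, as the Python code above assumes it:

    finished, unused_input, output_size = iter_job(job, input_data, output_buf)
    # finished:     True iff rs_job_iter() returned RS_DONE
    # unused_input: count of trailing bytes of input_data left unconsumed
    # output_size:  count of bytes written at the start of output_buf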


@@ -392,8 +392,9 @@ class SendManager:
             return
         sl = file.signature_loader
         assert sl is not None
-        sl(ftc.data)
+        sl.add_chunk(ftc.data)
         if ftc.action is Action.end_data:
+            sl.commit()
             file.start_delta_calculation()
             self.update_collective_statuses()


@@ -91,28 +91,38 @@ class TestFileTransmission(BaseTest):
         a_path = os.path.join(self.tdir, 'a')
         b_path = os.path.join(self.tdir, 'b')
         c_path = os.path.join(self.tdir, 'c')
+
+        def files_equal(a_path, c_path):
+            self.ae(os.path.getsize(a_path), os.path.getsize(c_path))
+            with open(a_path, 'rb') as a, open(c_path, 'rb') as c:
+                self.ae(a.read(), c.read())
+
+        def patch(old_path, new_path, output_path, max_delta_len=0):
+            sig_loader = LoadSignature()
+            for chunk in signature_of_file(old_path):
+                sig_loader.add_chunk(chunk)
+            sig_loader.commit()
+            self.assertTrue(sig_loader.finished)
+            delta_len = 0
+            with PatchFile(old_path, output_path) as patcher:
+                for chunk in delta_for_file(new_path, sig_loader.signature):
+                    self.assertFalse(patcher.finished)
+                    patcher.write(chunk)
+                    delta_len += len(chunk)
+            self.assertTrue(patcher.finished)
+            if max_delta_len:
+                self.assertLessEqual(delta_len, max_delta_len)
+            files_equal(output_path, new_path)
+
         sz = 1024 * 1024 + 37
         with open(a_path, 'wb') as f:
             f.write(os.urandom(sz))
         with open(b_path, 'wb') as f:
             f.write(os.urandom(sz))
-        sig_loader = LoadSignature()
-        for chunk in signature_of_file(a_path):
-            sig_loader(chunk)
-        sig_loader()
-        self.assertTrue(sig_loader.finished)
-        with PatchFile(a_path, c_path) as patcher:
-            for chunk in delta_for_file(b_path, sig_loader.signature):
-                self.assertFalse(patcher.finished)
-                patcher.write(chunk)
-        self.assertTrue(patcher.finished)
-        with open(b_path, 'rb') as b, open(c_path, 'rb') as c:
-            while True:
-                bc = b.read(4096)
-                cc = c.read(4096)
-                self.ae(bc, cc)
-                if not bc and not cc:
-                    break
+        patch(a_path, b_path, c_path)
+        # test size of delta
+        patch(a_path, a_path, c_path, max_delta_len=256)

     def test_file_put(self):
         # send refusal