fix failures

This commit is contained in:
senstella
2026-01-18 03:29:42 +09:00
parent ca85a52839
commit 6560ac93f1
6 changed files with 243 additions and 65 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

18
main.py
View File

@@ -25,7 +25,11 @@ def _inspect_patch(path: str) -> None:
controls = 0 controls = 0
with PatchReader.open(path) as reader: with PatchReader.open(path) as reader:
compression = reader.header.compression.algorithm if reader.header and reader.header.compression else None compression = (
reader.header.compression.algorithm
if reader.header and reader.header.compression
else None
)
for entry in reader.iter_file_entries(): for entry in reader.iter_file_entries():
files += 1 files += 1
if entry.is_rsync(): if entry.is_rsync():
@@ -50,7 +54,11 @@ def _inspect_patch(path: str) -> None:
def _inspect_signature(path: str) -> None: def _inspect_signature(path: str) -> None:
blocks = 0 blocks = 0
with SignatureReader.open(path) as reader: with SignatureReader.open(path) as reader:
compression = reader.header.compression.algorithm if reader.header and reader.header.compression else None compression = (
reader.header.compression.algorithm
if reader.header and reader.header.compression
else None
)
for _ in reader.iter_block_hashes(): for _ in reader.iter_block_hashes():
blocks += 1 blocks += 1
print("signature") print("signature")
@@ -61,7 +69,11 @@ def _inspect_signature(path: str) -> None:
def _inspect_manifest(path: str) -> None: def _inspect_manifest(path: str) -> None:
hashes = 0 hashes = 0
with ManifestReader.open(path) as reader: with ManifestReader.open(path) as reader:
compression = reader.header.compression.algorithm if reader.header and reader.header.compression else None compression = (
reader.header.compression.algorithm
if reader.header and reader.header.compression
else None
)
for _ in reader.iter_block_hashes(): for _ in reader.iter_block_hashes():
hashes += 1 hashes += 1
print("manifest") print("manifest")

View File

@@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
import errno
import os import os
from collections import OrderedDict
from typing import BinaryIO from typing import BinaryIO
from .formats import FilePatch, PatchReader from .formats import PatchReader
from .proto import Control, SyncOp, SyncOpType from .proto import Control, SyncOp, SyncOpType, TlcContainer, TlcFile
from .proto import TlcContainer, TlcDir, TlcFile, TlcSymlink
from .wire import BLOCK_SIZE from .wire import BLOCK_SIZE
@@ -17,17 +18,37 @@ _MODE_MASK = 0o7777
class FilePool: class FilePool:
def __init__(self, paths: list[str]): def __init__(self, paths: list[str], max_open: int = 128):
self._paths = list(paths) self._paths = list(paths)
self._handles: dict[int, BinaryIO] = {} self._handles: OrderedDict[int, BinaryIO] = OrderedDict()
self._max_open = max(1, int(max_open))
def _touch(self, index: int, handle: BinaryIO) -> None:
self._handles.pop(index, None)
self._handles[index] = handle
def _evict_if_needed(self) -> None:
while len(self._handles) >= self._max_open:
_, handle = self._handles.popitem(last=False)
handle.close()
def open(self, index: int) -> BinaryIO: def open(self, index: int) -> BinaryIO:
if index < 0 or index >= len(self._paths): if index < 0 or index >= len(self._paths):
raise PatchApplyError(f"file index out of range: {index}") raise PatchApplyError(f"file index out of range: {index}")
handle = self._handles.get(index) handle = self._handles.get(index)
if handle is None: if handle is None:
handle = open(self._paths[index], "rb") try:
self._evict_if_needed()
handle = open(self._paths[index], "rb")
except OSError as exc:
if exc.errno in (errno.EMFILE, errno.ENFILE):
self.close()
handle = open(self._paths[index], "rb")
else:
raise
self._handles[index] = handle self._handles[index] = handle
else:
self._touch(index, handle)
return handle return handle
def size(self, index: int) -> int: def size(self, index: int) -> int:
@@ -41,43 +62,86 @@ class FilePool:
self._handles.clear() self._handles.clear()
def _copy_range(dst: BinaryIO, src: BinaryIO, length: int, buffer_size: int = 32 * 1024) -> None: def _copy_range(
dst: BinaryIO, src: BinaryIO, length: int, buffer_size: int = 32 * 1024
) -> None:
remaining = length remaining = length
while remaining > 0: while remaining > 0:
chunk = src.read(min(buffer_size, remaining)) chunk = src.read(min(buffer_size, remaining))
if not chunk: if not chunk:
raise PatchApplyError("unexpected EOF while copying block range") break
dst.write(chunk) dst.write(chunk)
remaining -= len(chunk) remaining -= len(chunk)
def _copy_all(dst: BinaryIO, src: BinaryIO, buffer_size: int = 32 * 1024) -> None:
while True:
chunk = src.read(buffer_size)
if not chunk:
return
dst.write(chunk)
def _compute_num_blocks(file_size: int) -> int:
return (file_size + BLOCK_SIZE - 1) // BLOCK_SIZE
def _normalize_op_type(op_type: SyncOpType | int | None) -> SyncOpType | int:
return op_type if op_type is not None else SyncOpType.BLOCK_RANGE
def _is_full_file_op(op: SyncOp, target_size: int, output_size: int) -> bool:
op_type = _normalize_op_type(op.type)
if op_type != SyncOpType.BLOCK_RANGE:
return False
block_index = 0 if op.block_index is None else op.block_index
if block_index != 0:
return False
if target_size != output_size:
return False
block_span = 0 if op.block_span is None else op.block_span
return block_span == _compute_num_blocks(output_size)
def _apply_file_mode(path: str, file: TlcFile | None) -> None:
if file is None or file.mode is None:
return
mode = int(file.mode) & _MODE_MASK
try:
os.chmod(path, mode)
except (PermissionError, OSError):
pass
def apply_rsync_ops(ops: list[SyncOp], target_pool: FilePool, output: BinaryIO) -> None: def apply_rsync_ops(ops: list[SyncOp], target_pool: FilePool, output: BinaryIO) -> None:
for op in ops: for op in ops:
if op.type == SyncOpType.DATA: op_type = _normalize_op_type(op.type)
if op_type == SyncOpType.DATA:
output.write(op.data or b"") output.write(op.data or b"")
continue continue
if op.type != SyncOpType.BLOCK_RANGE: if op_type != SyncOpType.BLOCK_RANGE:
raise PatchApplyError(f"unsupported sync op type: {op.type}") raise PatchApplyError(f"unsupported sync op type: {op_type}")
if op.file_index is None or op.block_index is None or op.block_span is None: file_index = 0 if op.file_index is None else op.file_index
raise PatchApplyError("missing fields in block range op") block_index = 0 if op.block_index is None else op.block_index
if op.block_span <= 0: block_span = 0 if op.block_span is None else op.block_span
raise PatchApplyError("invalid block span in block range op")
file_size = target_pool.size(op.file_index) file_size = target_pool.size(file_index)
last_block_index = op.block_index + op.block_span - 1 last_block_index = block_index + block_span - 1
last_block_size = BLOCK_SIZE last_block_size = BLOCK_SIZE
if BLOCK_SIZE * (last_block_index + 1) > file_size: if BLOCK_SIZE * (last_block_index + 1) > file_size:
last_block_size = file_size % BLOCK_SIZE last_block_size = file_size % BLOCK_SIZE
op_size = (op.block_span - 1) * BLOCK_SIZE + last_block_size op_size = (block_span - 1) * BLOCK_SIZE + last_block_size
src = target_pool.open(op.file_index) src = target_pool.open(file_index)
src.seek(op.block_index * BLOCK_SIZE) src.seek(block_index * BLOCK_SIZE)
_copy_range(output, src, op_size) _copy_range(output, src, op_size)
def apply_bsdiff_controls(controls: list[Control], old: BinaryIO, output: BinaryIO) -> None: def apply_bsdiff_controls(
controls: list[Control], old: BinaryIO, output: BinaryIO
) -> None:
old_offset = 0 old_offset = 0
for ctrl in controls: for ctrl in controls:
if ctrl.eof: if ctrl.eof:
@@ -108,33 +172,106 @@ def _ensure_parent(path: str) -> None:
os.makedirs(parent, exist_ok=True) os.makedirs(parent, exist_ok=True)
def apply_patch(patch_reader: PatchReader, target_paths: list[str], output_paths: list[str]) -> None: def apply_patch(
patch_reader: PatchReader, target_paths: list[str], output_paths: list[str]
) -> None:
pool = FilePool(target_paths) pool = FilePool(target_paths)
try: try:
for entry in patch_reader.iter_file_entries(): expected_files = len(output_paths)
if entry.sync_header.file_index is None: if patch_reader.source_container is not None:
raise PatchApplyError("missing file_index in sync header") expected_files = len(patch_reader.source_container.files)
out_index = entry.sync_header.file_index entry_iter = patch_reader.iter_file_entries()
for expected_index in range(expected_files):
try:
entry = next(entry_iter)
except StopIteration:
raise PatchApplyError(
f"corrupted patch: expected {expected_files} file entries, got {expected_index}"
)
header_index = (
0
if entry.sync_header.file_index is None
else int(entry.sync_header.file_index)
)
if header_index != expected_index:
raise PatchApplyError(
f"corrupted patch: expected file index {expected_index}, got {header_index}"
)
out_index = header_index
if out_index < 0 or out_index >= len(output_paths): if out_index < 0 or out_index >= len(output_paths):
raise PatchApplyError(f"output index out of range: {out_index}") raise PatchApplyError(f"output index out of range: {out_index}")
out_path = output_paths[out_index] out_path = output_paths[out_index]
_ensure_parent(out_path) _ensure_parent(out_path)
with open(out_path, "wb") as out: if entry.is_rsync():
if entry.is_rsync(): if entry.sync_ops is None:
if entry.sync_ops is None: raise PatchApplyError("missing rsync ops")
raise PatchApplyError("missing rsync ops") if entry.sync_ops:
op = entry.sync_ops[0]
target_index = 0 if op.file_index is None else op.file_index
target_file = None
output_file = None
if patch_reader.target_container and patch_reader.source_container:
if 0 <= target_index < len(patch_reader.target_container.files):
target_file = patch_reader.target_container.files[
target_index
]
if 0 <= out_index < len(patch_reader.source_container.files):
output_file = patch_reader.source_container.files[out_index]
if target_file is not None and output_file is not None:
target_size = (
int(target_file.size) if target_file.size is not None else 0
)
output_size = (
int(output_file.size) if output_file.size is not None else 0
)
if _is_full_file_op(op, target_size, output_size):
src = pool.open(target_index)
src.seek(0)
with open(out_path, "wb") as out:
_copy_all(out, src)
_apply_file_mode(out_path, output_file)
continue
with open(out_path, "wb") as out:
apply_rsync_ops(entry.sync_ops, pool, out) apply_rsync_ops(entry.sync_ops, pool, out)
elif entry.is_bsdiff(): output_file = None
if entry.bsdiff_header is None or entry.bsdiff_controls is None: if patch_reader.source_container and 0 <= out_index < len(
raise PatchApplyError("missing bsdiff data") patch_reader.source_container.files
if entry.bsdiff_header.target_index is None: ):
raise PatchApplyError("missing target_index in bsdiff header") output_file = patch_reader.source_container.files[out_index]
old = pool.open(entry.bsdiff_header.target_index) _apply_file_mode(out_path, output_file)
continue
if entry.is_bsdiff():
if entry.bsdiff_header is None or entry.bsdiff_controls is None:
raise PatchApplyError("missing bsdiff data")
target_index = (
0
if entry.bsdiff_header.target_index is None
else entry.bsdiff_header.target_index
)
with open(out_path, "wb") as out:
old = pool.open(target_index)
apply_bsdiff_controls(entry.bsdiff_controls, old, out) apply_bsdiff_controls(entry.bsdiff_controls, old, out)
else: expected_size = None
raise PatchApplyError("unknown file patch type") output_file = None
if patch_reader.source_container and 0 <= out_index < len(
patch_reader.source_container.files
):
output_file = patch_reader.source_container.files[out_index]
if output_file.size is not None:
expected_size = int(output_file.size)
final_size = out.tell()
if expected_size is not None and final_size != expected_size:
raise PatchApplyError(
f"corrupted patch: expected output size {expected_size}, got {final_size}"
)
_apply_file_mode(out_path, output_file)
continue
raise PatchApplyError("unknown file patch type")
finally: finally:
pool.close() pool.close()

View File

@@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import io
import gzip import gzip
import io
import typing
from typing import BinaryIO from typing import BinaryIO
from .proto import CompressionAlgorithm, CompressionSettings from .proto import CompressionAlgorithm, CompressionSettings
@@ -11,7 +12,9 @@ class CompressionError(Exception):
pass pass
def _normalize_algorithm(algorithm: CompressionAlgorithm | int | None) -> CompressionAlgorithm: def _normalize_algorithm(
algorithm: CompressionAlgorithm | int | None,
) -> CompressionAlgorithm:
if algorithm is None: if algorithm is None:
return CompressionAlgorithm.NONE return CompressionAlgorithm.NONE
if isinstance(algorithm, CompressionAlgorithm): if isinstance(algorithm, CompressionAlgorithm):
@@ -103,7 +106,7 @@ class _BrotliWriter(io.RawIOBase):
def writable(self) -> bool: def writable(self) -> bool:
return True return True
def write(self, b: bytes) -> int: def write(self, b) -> int:
out = self._compressor.process(b) out = self._compressor.process(b)
self._raw.write(out) self._raw.write(out)
return len(b) return len(b)
@@ -121,12 +124,14 @@ class _BrotliWriter(io.RawIOBase):
super().close() super().close()
def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO: def open_decompressed_reader(
stream: BinaryIO, compression: CompressionSettings | None
) -> BinaryIO:
algorithm = _normalize_algorithm(compression.algorithm if compression else None) algorithm = _normalize_algorithm(compression.algorithm if compression else None)
if algorithm == CompressionAlgorithm.NONE: if algorithm == CompressionAlgorithm.NONE:
return stream return stream
if algorithm == CompressionAlgorithm.GZIP: if algorithm == CompressionAlgorithm.GZIP:
return gzip.GzipFile(fileobj=stream, mode="rb") return typing.cast(BinaryIO, gzip.GzipFile(fileobj=stream, mode="rb"))
if algorithm == CompressionAlgorithm.BROTLI: if algorithm == CompressionAlgorithm.BROTLI:
return io.BufferedReader(_BrotliReader(stream)) return io.BufferedReader(_BrotliReader(stream))
if algorithm == CompressionAlgorithm.ZSTD: if algorithm == CompressionAlgorithm.ZSTD:
@@ -138,14 +143,18 @@ def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings
raise CompressionError(f"unsupported compression algorithm: {algorithm}") raise CompressionError(f"unsupported compression algorithm: {algorithm}")
def open_compressed_writer(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO: def open_compressed_writer(
stream: BinaryIO, compression: CompressionSettings | None
) -> BinaryIO:
algorithm = _normalize_algorithm(compression.algorithm if compression else None) algorithm = _normalize_algorithm(compression.algorithm if compression else None)
quality = compression.quality if compression else None quality = compression.quality if compression else None
if algorithm == CompressionAlgorithm.NONE: if algorithm == CompressionAlgorithm.NONE:
return stream return stream
if algorithm == CompressionAlgorithm.GZIP: if algorithm == CompressionAlgorithm.GZIP:
level = 9 if quality is None else int(quality) level = 9 if quality is None else int(quality)
return gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level) return typing.cast(
BinaryIO, gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level)
)
if algorithm == CompressionAlgorithm.BROTLI: if algorithm == CompressionAlgorithm.BROTLI:
return io.BufferedWriter(_BrotliWriter(stream, quality)) return io.BufferedWriter(_BrotliWriter(stream, quality))
if algorithm == CompressionAlgorithm.ZSTD: if algorithm == CompressionAlgorithm.ZSTD:

View File

@@ -31,7 +31,11 @@ from .wire import (
ContainerDecoder = Callable[[bytes], Any] ContainerDecoder = Callable[[bytes], Any]
def _split_decoder(decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None): def _split_decoder(
decoder: ContainerDecoder
| tuple[ContainerDecoder | None, ContainerDecoder | None]
| None,
):
if decoder is None: if decoder is None:
return None, None return None, None
if isinstance(decoder, tuple): if isinstance(decoder, tuple):
@@ -56,7 +60,14 @@ class FilePatch:
class PatchReader: class PatchReader:
def __init__(self, stream, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None): def __init__(
self,
stream,
*,
container_decoder: ContainerDecoder
| tuple[ContainerDecoder | None, ContainerDecoder | None]
| None = None,
):
self._stream = stream self._stream = stream
self._raw_wire = WireReader(stream) self._raw_wire = WireReader(stream)
self._wire: WireReader | None = None self._wire: WireReader | None = None
@@ -70,7 +81,14 @@ class PatchReader:
self._init_stream(container_decoder) self._init_stream(container_decoder)
@classmethod @classmethod
def open(cls, path: str, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None): def open(
cls,
path: str,
*,
container_decoder: ContainerDecoder
| tuple[ContainerDecoder | None, ContainerDecoder | None]
| None = None,
):
return cls(open(path, "rb"), container_decoder=container_decoder) return cls(open(path, "rb"), container_decoder=container_decoder)
def _init_stream(self, container_decoder): def _init_stream(self, container_decoder):
@@ -81,8 +99,12 @@ class PatchReader:
self._wire = WireReader(decompressed) self._wire = WireReader(decompressed)
target_decoder, source_decoder = _split_decoder(container_decoder) target_decoder, source_decoder = _split_decoder(container_decoder)
self.target_container_raw, self.target_container = self._read_container(target_decoder) self.target_container_raw, self.target_container = self._read_container(
self.source_container_raw, self.source_container = self._read_container(source_decoder) target_decoder
)
self.source_container_raw, self.source_container = self._read_container(
source_decoder
)
def _read_container(self, decoder: ContainerDecoder | None): def _read_container(self, decoder: ContainerDecoder | None):
assert self._wire is not None assert self._wire is not None
@@ -92,18 +114,12 @@ class PatchReader:
def iter_file_entries(self) -> Iterator[FilePatch]: def iter_file_entries(self) -> Iterator[FilePatch]:
assert self._wire is not None assert self._wire is not None
file_index = 0
while True: while True:
try: try:
sync_header = self._wire.read_message(SyncHeader) sync_header = self._wire.read_message(SyncHeader)
except EOFError: except EOFError:
return return
if sync_header.file_index is None:
sync_header.file_index = file_index
else:
file_index = int(sync_header.file_index)
header_type = sync_header.type header_type = sync_header.type
if header_type is None: if header_type is None:
header_type = SyncHeaderType.RSYNC header_type = SyncHeaderType.RSYNC
@@ -131,7 +147,6 @@ class PatchReader:
bsdiff_header=bsdiff_header, bsdiff_header=bsdiff_header,
bsdiff_controls=controls, bsdiff_controls=controls,
) )
file_index += 1
continue continue
if header_type != SyncHeaderType.RSYNC: if header_type != SyncHeaderType.RSYNC:
@@ -145,7 +160,6 @@ class PatchReader:
ops.append(op) ops.append(op)
yield FilePatch(sync_header=sync_header, sync_ops=ops) yield FilePatch(sync_header=sync_header, sync_ops=ops)
file_index += 1
def close(self) -> None: def close(self) -> None:
if self._wire: if self._wire:
@@ -182,7 +196,9 @@ class SignatureReader:
self._wire = WireReader(decompressed) self._wire = WireReader(decompressed)
self.container_raw = self._wire.read_message_bytes() self.container_raw = self._wire.read_message_bytes()
self.container = container_decoder(self.container_raw) if container_decoder else None self.container = (
container_decoder(self.container_raw) if container_decoder else None
)
def iter_block_hashes(self) -> Iterator[BlockHash]: def iter_block_hashes(self) -> Iterator[BlockHash]:
assert self._wire is not None assert self._wire is not None
@@ -227,7 +243,9 @@ class ManifestReader:
self._wire = WireReader(decompressed) self._wire = WireReader(decompressed)
self.container_raw = self._wire.read_message_bytes() self.container_raw = self._wire.read_message_bytes()
self.container = container_decoder(self.container_raw) if container_decoder else None self.container = (
container_decoder(self.container_raw) if container_decoder else None
)
def iter_block_hashes(self) -> Iterator[ManifestBlockHash]: def iter_block_hashes(self) -> Iterator[ManifestBlockHash]:
assert self._wire is not None assert self._wire is not None
@@ -267,7 +285,9 @@ class WoundsReader:
self._wire.expect_magic(WOUNDS_MAGIC) self._wire.expect_magic(WOUNDS_MAGIC)
self.header = self._wire.read_message(WoundsHeader) self.header = self._wire.read_message(WoundsHeader)
self.container_raw = self._wire.read_message_bytes() self.container_raw = self._wire.read_message_bytes()
self.container = container_decoder(self.container_raw) if container_decoder else None self.container = (
container_decoder(self.container_raw) if container_decoder else None
)
def iter_wounds(self) -> Iterator[Wound]: def iter_wounds(self) -> Iterator[Wound]:
while True: while True:

View File

@@ -73,7 +73,7 @@ class ProtoReader:
end = self._pos + length end = self._pos + length
if end > len(self._data): if end > len(self._data):
raise ProtoError("unexpected EOF while reading length-delimited field") raise ProtoError("unexpected EOF while reading length-delimited field")
data = self._data[self._pos:end] data = self._data[self._pos : end]
self._pos = end self._pos = end
return data return data