From 6560ac93f18db7f50111d4eb13ccb264ff3372ac Mon Sep 17 00:00:00 2001 From: senstella Date: Sun, 18 Jan 2026 03:29:42 +0900 Subject: [PATCH] fix failures --- .DS_Store | Bin 0 -> 6148 bytes main.py | 18 +++- pwr/apply.py | 213 +++++++++++++++++++++++++++++++++++++-------- pwr/compression.py | 23 +++-- pwr/formats.py | 52 +++++++---- pwr/proto.py | 2 +- 6 files changed, 243 insertions(+), 65 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1858ea5c4e173731fc6b9c2bf7f0804e7a7404a1 GIT binary patch literal 6148 zcmZQzU|@7AO)+F(5MW?n;9!8z3~dZp0Z1N%F(jFwB8(vOz-Abg1sCPz;^^d9;4S~@R7!85Z5Eu=C(GVal1fVr62RCWjMpci7z-S1J zfDizc4+@aBJ%a<3Zh+7rDF#Lc25=XEk%55)795P=egFeV4x|-CgS3KZkX8mp5DRPu zSSte~R4XI68v@b?>XLwHuyzJUu+1PoSUUqF*k%R>Mu>I>MySn<&>jjSL^}f`L^}f` z*mjufM(NQI7!3hf2rxq!0-*Zem4N|Q{~w}ilpGC#(GVDxA;8Gu671pxu9UI+4^-EJ z>eB?MngdW}jG%fMA_h_ds@lQTF%x7^Q39$ABo5LJqQTWMBLf4tHXm&az(Q!09t{Ed Gh5!IQu@q_m literal 0 HcmV?d00001 diff --git a/main.py b/main.py index 0c5f4dc..4bf8a9a 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,11 @@ def _inspect_patch(path: str) -> None: controls = 0 with PatchReader.open(path) as reader: - compression = reader.header.compression.algorithm if reader.header and reader.header.compression else None + compression = ( + reader.header.compression.algorithm + if reader.header and reader.header.compression + else None + ) for entry in reader.iter_file_entries(): files += 1 if entry.is_rsync(): @@ -50,7 +54,11 @@ def _inspect_patch(path: str) -> None: def _inspect_signature(path: str) -> None: blocks = 0 with SignatureReader.open(path) as reader: - compression = reader.header.compression.algorithm if reader.header and reader.header.compression else None + compression = ( + reader.header.compression.algorithm + if reader.header and reader.header.compression + else None + ) for _ in reader.iter_block_hashes(): blocks += 1 print("signature") @@ -61,7 +69,11 @@ def _inspect_signature(path: str) -> None: def _inspect_manifest(path: str) -> None: hashes = 0 with ManifestReader.open(path) as reader: - compression = reader.header.compression.algorithm if reader.header and reader.header.compression else None + compression = ( + reader.header.compression.algorithm + if reader.header and reader.header.compression + else None + ) for _ in reader.iter_block_hashes(): hashes += 1 print("manifest") diff --git a/pwr/apply.py b/pwr/apply.py index b6ed285..770b534 100644 --- a/pwr/apply.py +++ b/pwr/apply.py @@ -1,11 +1,12 @@ from __future__ import annotations +import errno import os +from collections import OrderedDict from typing import BinaryIO -from .formats import FilePatch, PatchReader -from .proto import Control, SyncOp, SyncOpType -from .proto import TlcContainer, TlcDir, TlcFile, TlcSymlink +from .formats import PatchReader +from .proto import Control, SyncOp, SyncOpType, TlcContainer, TlcFile from .wire import BLOCK_SIZE @@ -17,17 +18,37 @@ _MODE_MASK = 0o7777 class FilePool: - def __init__(self, paths: list[str]): + def __init__(self, paths: list[str], max_open: int = 128): self._paths = list(paths) - self._handles: dict[int, BinaryIO] = {} + self._handles: OrderedDict[int, BinaryIO] = OrderedDict() + self._max_open = max(1, int(max_open)) + + def _touch(self, index: int, handle: BinaryIO) -> None: + self._handles.pop(index, None) + self._handles[index] = handle + + def _evict_if_needed(self) -> None: + while len(self._handles) >= self._max_open: + _, handle = self._handles.popitem(last=False) + handle.close() def open(self, index: int) -> BinaryIO: if index < 0 or index >= len(self._paths): raise PatchApplyError(f"file index out of range: {index}") handle = self._handles.get(index) if handle is None: - handle = open(self._paths[index], "rb") + try: + self._evict_if_needed() + handle = open(self._paths[index], "rb") + except OSError as exc: + if exc.errno in (errno.EMFILE, errno.ENFILE): + self.close() + handle = open(self._paths[index], "rb") + else: + raise self._handles[index] = handle + else: + self._touch(index, handle) return handle def size(self, index: int) -> int: @@ -41,43 +62,86 @@ class FilePool: self._handles.clear() -def _copy_range(dst: BinaryIO, src: BinaryIO, length: int, buffer_size: int = 32 * 1024) -> None: +def _copy_range( + dst: BinaryIO, src: BinaryIO, length: int, buffer_size: int = 32 * 1024 +) -> None: remaining = length while remaining > 0: chunk = src.read(min(buffer_size, remaining)) if not chunk: - raise PatchApplyError("unexpected EOF while copying block range") + break dst.write(chunk) remaining -= len(chunk) +def _copy_all(dst: BinaryIO, src: BinaryIO, buffer_size: int = 32 * 1024) -> None: + while True: + chunk = src.read(buffer_size) + if not chunk: + return + dst.write(chunk) + + +def _compute_num_blocks(file_size: int) -> int: + return (file_size + BLOCK_SIZE - 1) // BLOCK_SIZE + + +def _normalize_op_type(op_type: SyncOpType | int | None) -> SyncOpType | int: + return op_type if op_type is not None else SyncOpType.BLOCK_RANGE + + +def _is_full_file_op(op: SyncOp, target_size: int, output_size: int) -> bool: + op_type = _normalize_op_type(op.type) + if op_type != SyncOpType.BLOCK_RANGE: + return False + block_index = 0 if op.block_index is None else op.block_index + if block_index != 0: + return False + if target_size != output_size: + return False + block_span = 0 if op.block_span is None else op.block_span + return block_span == _compute_num_blocks(output_size) + + +def _apply_file_mode(path: str, file: TlcFile | None) -> None: + if file is None or file.mode is None: + return + mode = int(file.mode) & _MODE_MASK + try: + os.chmod(path, mode) + except (PermissionError, OSError): + pass + + def apply_rsync_ops(ops: list[SyncOp], target_pool: FilePool, output: BinaryIO) -> None: for op in ops: - if op.type == SyncOpType.DATA: + op_type = _normalize_op_type(op.type) + if op_type == SyncOpType.DATA: output.write(op.data or b"") continue - if op.type != SyncOpType.BLOCK_RANGE: - raise PatchApplyError(f"unsupported sync op type: {op.type}") + if op_type != SyncOpType.BLOCK_RANGE: + raise PatchApplyError(f"unsupported sync op type: {op_type}") - if op.file_index is None or op.block_index is None or op.block_span is None: - raise PatchApplyError("missing fields in block range op") - if op.block_span <= 0: - raise PatchApplyError("invalid block span in block range op") + file_index = 0 if op.file_index is None else op.file_index + block_index = 0 if op.block_index is None else op.block_index + block_span = 0 if op.block_span is None else op.block_span - file_size = target_pool.size(op.file_index) - last_block_index = op.block_index + op.block_span - 1 + file_size = target_pool.size(file_index) + last_block_index = block_index + block_span - 1 last_block_size = BLOCK_SIZE if BLOCK_SIZE * (last_block_index + 1) > file_size: last_block_size = file_size % BLOCK_SIZE - op_size = (op.block_span - 1) * BLOCK_SIZE + last_block_size + op_size = (block_span - 1) * BLOCK_SIZE + last_block_size - src = target_pool.open(op.file_index) - src.seek(op.block_index * BLOCK_SIZE) + src = target_pool.open(file_index) + src.seek(block_index * BLOCK_SIZE) _copy_range(output, src, op_size) -def apply_bsdiff_controls(controls: list[Control], old: BinaryIO, output: BinaryIO) -> None: +def apply_bsdiff_controls( + controls: list[Control], old: BinaryIO, output: BinaryIO +) -> None: old_offset = 0 for ctrl in controls: if ctrl.eof: @@ -108,33 +172,106 @@ def _ensure_parent(path: str) -> None: os.makedirs(parent, exist_ok=True) -def apply_patch(patch_reader: PatchReader, target_paths: list[str], output_paths: list[str]) -> None: +def apply_patch( + patch_reader: PatchReader, target_paths: list[str], output_paths: list[str] +) -> None: pool = FilePool(target_paths) try: - for entry in patch_reader.iter_file_entries(): - if entry.sync_header.file_index is None: - raise PatchApplyError("missing file_index in sync header") - out_index = entry.sync_header.file_index + expected_files = len(output_paths) + if patch_reader.source_container is not None: + expected_files = len(patch_reader.source_container.files) + entry_iter = patch_reader.iter_file_entries() + for expected_index in range(expected_files): + try: + entry = next(entry_iter) + except StopIteration: + raise PatchApplyError( + f"corrupted patch: expected {expected_files} file entries, got {expected_index}" + ) + header_index = ( + 0 + if entry.sync_header.file_index is None + else int(entry.sync_header.file_index) + ) + if header_index != expected_index: + raise PatchApplyError( + f"corrupted patch: expected file index {expected_index}, got {header_index}" + ) + out_index = header_index if out_index < 0 or out_index >= len(output_paths): raise PatchApplyError(f"output index out of range: {out_index}") out_path = output_paths[out_index] _ensure_parent(out_path) - with open(out_path, "wb") as out: - if entry.is_rsync(): - if entry.sync_ops is None: - raise PatchApplyError("missing rsync ops") + if entry.is_rsync(): + if entry.sync_ops is None: + raise PatchApplyError("missing rsync ops") + if entry.sync_ops: + op = entry.sync_ops[0] + target_index = 0 if op.file_index is None else op.file_index + target_file = None + output_file = None + if patch_reader.target_container and patch_reader.source_container: + if 0 <= target_index < len(patch_reader.target_container.files): + target_file = patch_reader.target_container.files[ + target_index + ] + if 0 <= out_index < len(patch_reader.source_container.files): + output_file = patch_reader.source_container.files[out_index] + if target_file is not None and output_file is not None: + target_size = ( + int(target_file.size) if target_file.size is not None else 0 + ) + output_size = ( + int(output_file.size) if output_file.size is not None else 0 + ) + if _is_full_file_op(op, target_size, output_size): + src = pool.open(target_index) + src.seek(0) + with open(out_path, "wb") as out: + _copy_all(out, src) + _apply_file_mode(out_path, output_file) + continue + + with open(out_path, "wb") as out: apply_rsync_ops(entry.sync_ops, pool, out) - elif entry.is_bsdiff(): - if entry.bsdiff_header is None or entry.bsdiff_controls is None: - raise PatchApplyError("missing bsdiff data") - if entry.bsdiff_header.target_index is None: - raise PatchApplyError("missing target_index in bsdiff header") - old = pool.open(entry.bsdiff_header.target_index) + output_file = None + if patch_reader.source_container and 0 <= out_index < len( + patch_reader.source_container.files + ): + output_file = patch_reader.source_container.files[out_index] + _apply_file_mode(out_path, output_file) + continue + + if entry.is_bsdiff(): + if entry.bsdiff_header is None or entry.bsdiff_controls is None: + raise PatchApplyError("missing bsdiff data") + target_index = ( + 0 + if entry.bsdiff_header.target_index is None + else entry.bsdiff_header.target_index + ) + with open(out_path, "wb") as out: + old = pool.open(target_index) apply_bsdiff_controls(entry.bsdiff_controls, old, out) - else: - raise PatchApplyError("unknown file patch type") + expected_size = None + output_file = None + if patch_reader.source_container and 0 <= out_index < len( + patch_reader.source_container.files + ): + output_file = patch_reader.source_container.files[out_index] + if output_file.size is not None: + expected_size = int(output_file.size) + final_size = out.tell() + if expected_size is not None and final_size != expected_size: + raise PatchApplyError( + f"corrupted patch: expected output size {expected_size}, got {final_size}" + ) + _apply_file_mode(out_path, output_file) + continue + + raise PatchApplyError("unknown file patch type") finally: pool.close() diff --git a/pwr/compression.py b/pwr/compression.py index fded62b..423e141 100644 --- a/pwr/compression.py +++ b/pwr/compression.py @@ -1,7 +1,8 @@ from __future__ import annotations -import io import gzip +import io +import typing from typing import BinaryIO from .proto import CompressionAlgorithm, CompressionSettings @@ -11,7 +12,9 @@ class CompressionError(Exception): pass -def _normalize_algorithm(algorithm: CompressionAlgorithm | int | None) -> CompressionAlgorithm: +def _normalize_algorithm( + algorithm: CompressionAlgorithm | int | None, +) -> CompressionAlgorithm: if algorithm is None: return CompressionAlgorithm.NONE if isinstance(algorithm, CompressionAlgorithm): @@ -103,7 +106,7 @@ class _BrotliWriter(io.RawIOBase): def writable(self) -> bool: return True - def write(self, b: bytes) -> int: + def write(self, b) -> int: out = self._compressor.process(b) self._raw.write(out) return len(b) @@ -121,12 +124,14 @@ class _BrotliWriter(io.RawIOBase): super().close() -def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO: +def open_decompressed_reader( + stream: BinaryIO, compression: CompressionSettings | None +) -> BinaryIO: algorithm = _normalize_algorithm(compression.algorithm if compression else None) if algorithm == CompressionAlgorithm.NONE: return stream if algorithm == CompressionAlgorithm.GZIP: - return gzip.GzipFile(fileobj=stream, mode="rb") + return typing.cast(BinaryIO, gzip.GzipFile(fileobj=stream, mode="rb")) if algorithm == CompressionAlgorithm.BROTLI: return io.BufferedReader(_BrotliReader(stream)) if algorithm == CompressionAlgorithm.ZSTD: @@ -138,14 +143,18 @@ def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings raise CompressionError(f"unsupported compression algorithm: {algorithm}") -def open_compressed_writer(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO: +def open_compressed_writer( + stream: BinaryIO, compression: CompressionSettings | None +) -> BinaryIO: algorithm = _normalize_algorithm(compression.algorithm if compression else None) quality = compression.quality if compression else None if algorithm == CompressionAlgorithm.NONE: return stream if algorithm == CompressionAlgorithm.GZIP: level = 9 if quality is None else int(quality) - return gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level) + return typing.cast( + BinaryIO, gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level) + ) if algorithm == CompressionAlgorithm.BROTLI: return io.BufferedWriter(_BrotliWriter(stream, quality)) if algorithm == CompressionAlgorithm.ZSTD: diff --git a/pwr/formats.py b/pwr/formats.py index c6e18d0..a7886ee 100644 --- a/pwr/formats.py +++ b/pwr/formats.py @@ -31,7 +31,11 @@ from .wire import ( ContainerDecoder = Callable[[bytes], Any] -def _split_decoder(decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None): +def _split_decoder( + decoder: ContainerDecoder + | tuple[ContainerDecoder | None, ContainerDecoder | None] + | None, +): if decoder is None: return None, None if isinstance(decoder, tuple): @@ -56,7 +60,14 @@ class FilePatch: class PatchReader: - def __init__(self, stream, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None): + def __init__( + self, + stream, + *, + container_decoder: ContainerDecoder + | tuple[ContainerDecoder | None, ContainerDecoder | None] + | None = None, + ): self._stream = stream self._raw_wire = WireReader(stream) self._wire: WireReader | None = None @@ -70,7 +81,14 @@ class PatchReader: self._init_stream(container_decoder) @classmethod - def open(cls, path: str, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None): + def open( + cls, + path: str, + *, + container_decoder: ContainerDecoder + | tuple[ContainerDecoder | None, ContainerDecoder | None] + | None = None, + ): return cls(open(path, "rb"), container_decoder=container_decoder) def _init_stream(self, container_decoder): @@ -81,8 +99,12 @@ class PatchReader: self._wire = WireReader(decompressed) target_decoder, source_decoder = _split_decoder(container_decoder) - self.target_container_raw, self.target_container = self._read_container(target_decoder) - self.source_container_raw, self.source_container = self._read_container(source_decoder) + self.target_container_raw, self.target_container = self._read_container( + target_decoder + ) + self.source_container_raw, self.source_container = self._read_container( + source_decoder + ) def _read_container(self, decoder: ContainerDecoder | None): assert self._wire is not None @@ -92,18 +114,12 @@ class PatchReader: def iter_file_entries(self) -> Iterator[FilePatch]: assert self._wire is not None - file_index = 0 while True: try: sync_header = self._wire.read_message(SyncHeader) except EOFError: return - if sync_header.file_index is None: - sync_header.file_index = file_index - else: - file_index = int(sync_header.file_index) - header_type = sync_header.type if header_type is None: header_type = SyncHeaderType.RSYNC @@ -131,7 +147,6 @@ class PatchReader: bsdiff_header=bsdiff_header, bsdiff_controls=controls, ) - file_index += 1 continue if header_type != SyncHeaderType.RSYNC: @@ -145,7 +160,6 @@ class PatchReader: ops.append(op) yield FilePatch(sync_header=sync_header, sync_ops=ops) - file_index += 1 def close(self) -> None: if self._wire: @@ -182,7 +196,9 @@ class SignatureReader: self._wire = WireReader(decompressed) self.container_raw = self._wire.read_message_bytes() - self.container = container_decoder(self.container_raw) if container_decoder else None + self.container = ( + container_decoder(self.container_raw) if container_decoder else None + ) def iter_block_hashes(self) -> Iterator[BlockHash]: assert self._wire is not None @@ -227,7 +243,9 @@ class ManifestReader: self._wire = WireReader(decompressed) self.container_raw = self._wire.read_message_bytes() - self.container = container_decoder(self.container_raw) if container_decoder else None + self.container = ( + container_decoder(self.container_raw) if container_decoder else None + ) def iter_block_hashes(self) -> Iterator[ManifestBlockHash]: assert self._wire is not None @@ -267,7 +285,9 @@ class WoundsReader: self._wire.expect_magic(WOUNDS_MAGIC) self.header = self._wire.read_message(WoundsHeader) self.container_raw = self._wire.read_message_bytes() - self.container = container_decoder(self.container_raw) if container_decoder else None + self.container = ( + container_decoder(self.container_raw) if container_decoder else None + ) def iter_wounds(self) -> Iterator[Wound]: while True: diff --git a/pwr/proto.py b/pwr/proto.py index 1319802..e10445d 100644 --- a/pwr/proto.py +++ b/pwr/proto.py @@ -73,7 +73,7 @@ class ProtoReader: end = self._pos + length if end > len(self._data): raise ProtoError("unexpected EOF while reading length-delimited field") - data = self._data[self._pos:end] + data = self._data[self._pos : end] self._pos = end return data