from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Iterator

from .compression import open_decompressed_reader
from .proto import (
    BlockHash,
    BsdiffHeader,
    Control,
    ManifestBlockHash,
    ManifestHeader,
    PatchHeader,
    SignatureHeader,
    SyncHeader,
    SyncHeaderType,
    SyncOp,
    SyncOpType,
    Wound,
    WoundsHeader,
)
from .wire import (
    MANIFEST_MAGIC,
    PATCH_MAGIC,
    SIGNATURE_MAGIC,
    WOUNDS_MAGIC,
    WireError,
    WireReader,
)

# A container decoder receives the raw bytes of a container message and
# returns whatever parsed representation the caller wants.
ContainerDecoder = Callable[[bytes], Any]


def _split_decoder(
    decoder: ContainerDecoder
    | tuple[ContainerDecoder | None, ContainerDecoder | None]
    | None,
):
    """Normalize a decoder argument into a ``(target, source)`` pair.

    Args:
        decoder: ``None``, a single callable (used for both containers), or
            a 2-tuple of per-container decoders (either may be ``None``).

    Returns:
        A two-element sequence ``(target_decoder, source_decoder)``. A tuple
        argument is returned as-is.

    Raises:
        ValueError: If a tuple is passed whose length is not 2.
    """
    if decoder is None:
        return None, None
    if not isinstance(decoder, tuple):
        # One callable decodes both the target and the source container.
        return decoder, decoder
    if len(decoder) != 2:
        raise ValueError("container decoder tuple must have two elements")
    return decoder


def _open_reader(reader_cls, path: str, container_decoder):
    """Open *path* in binary mode and build ``reader_cls`` over it.

    The reader constructors parse the stream header eagerly, so they can
    raise (bad magic, truncated data, failing decoder, ...). When that
    happens the freshly opened file is closed before the exception
    propagates, so the handle is not leaked.
    """
    stream = open(path, "rb")
    try:
        return reader_cls(stream, container_decoder=container_decoder)
    except BaseException:
        stream.close()
        raise


@dataclass
class FilePatch:
    """One per-file entry yielded by :meth:`PatchReader.iter_file_entries`.

    An rsync-style entry carries ``sync_ops``; a bsdiff-style entry carries
    ``bsdiff_header`` and ``bsdiff_controls``.
    """

    sync_header: SyncHeader
    sync_ops: list[SyncOp] | None = None
    bsdiff_header: BsdiffHeader | None = None
    bsdiff_controls: list[Control] | None = None

    def is_rsync(self) -> bool:
        """Return ``True`` when this entry carries a list of sync ops."""
        return self.sync_ops is not None

    def is_bsdiff(self) -> bool:
        """Return ``True`` when this entry carries a bsdiff header."""
        return self.bsdiff_header is not None


class PatchReader:
    """Reader for a patch stream.

    On construction the reader checks ``PATCH_MAGIC``, reads the
    ``PatchHeader`` from the raw stream, then reads everything else through
    the reader returned by ``open_decompressed_reader`` for the header's
    compression setting. The first two messages of that stream are the
    target and source containers; they are exposed both as raw bytes and,
    when a decoder is supplied, in decoded form.
    """

    def __init__(
        self,
        stream,
        *,
        container_decoder: ContainerDecoder
        | tuple[ContainerDecoder | None, ContainerDecoder | None]
        | None = None,
    ):
        """Parse the patch header and both containers from *stream*.

        Args:
            stream: Binary stream positioned at the start of the patch.
            container_decoder: See :func:`_split_decoder` for accepted forms.
        """
        self._stream = stream
        self._raw_wire = WireReader(stream)
        self._wire: WireReader | None = None
        self.header: PatchHeader | None = None
        self.target_container_raw: bytes | None = None
        self.source_container_raw: bytes | None = None
        self.target_container: Any | None = None
        self.source_container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(
        cls,
        path: str,
        *,
        container_decoder: ContainerDecoder
        | tuple[ContainerDecoder | None, ContainerDecoder | None]
        | None = None,
    ):
        """Open the file at *path* and return a reader over it.

        The file is closed again if parsing the header fails.
        """
        return _open_reader(cls, path, container_decoder)

    def _init_stream(self, container_decoder):
        """Read magic, header and both containers; set up the wire reader."""
        self._raw_wire.expect_magic(PATCH_MAGIC)
        self.header = self._raw_wire.read_message(PatchHeader)
        decompressed = open_decompressed_reader(
            self._stream, self.header.compression
        )
        self._wire = WireReader(decompressed)

        target_decoder, source_decoder = _split_decoder(container_decoder)
        # Order matters: the target container precedes the source container.
        self.target_container_raw, self.target_container = (
            self._read_container(target_decoder)
        )
        self.source_container_raw, self.source_container = (
            self._read_container(source_decoder)
        )

    def _read_container(self, decoder: ContainerDecoder | None):
        """Read one container message; return ``(raw_bytes, decoded_or_None)``."""
        assert self._wire is not None
        raw = self._wire.read_message_bytes()
        parsed = decoder(raw) if decoder else None
        return raw, parsed

    def iter_file_entries(self) -> Iterator[FilePatch]:
        """Yield one :class:`FilePatch` per file entry until the stream ends.

        A ``SyncHeader`` without a ``file_index`` is assigned the running
        index (starting at 0, incremented after every entry); an explicit
        index resets the running counter. A header without a type is
        treated as ``SyncHeaderType.RSYNC``.

        Raises:
            WireError: On an unknown or unsupported header type, or when a
                bsdiff entry is not followed by the expected sentinel op.
            EOFError: Propagated if the stream ends in the middle of an
                entry; only EOF at an entry boundary ends iteration cleanly.
        """
        assert self._wire is not None
        file_index = 0
        while True:
            try:
                sync_header = self._wire.read_message(SyncHeader)
            except EOFError:
                # EOF exactly between entries is the normal end of stream.
                return

            if sync_header.file_index is None:
                sync_header.file_index = file_index
            else:
                file_index = int(sync_header.file_index)

            header_type = self._resolve_header_type(sync_header.type)
            if header_type == SyncHeaderType.BSDIFF:
                yield self._read_bsdiff_entry(sync_header)
            elif header_type == SyncHeaderType.RSYNC:
                yield self._read_rsync_entry(sync_header)
            else:
                raise WireError(
                    f"unsupported sync header type: {header_type}"
                )
            file_index += 1

    @staticmethod
    def _resolve_header_type(raw_type) -> SyncHeaderType:
        """Map a header's raw ``type`` field to a ``SyncHeaderType``.

        ``None`` defaults to ``RSYNC``; a value ``SyncHeaderType`` rejects
        with ``ValueError`` is reported as a ``WireError`` chained to it.
        """
        if raw_type is None:
            return SyncHeaderType.RSYNC
        try:
            return SyncHeaderType(int(raw_type))
        except ValueError as exc:
            raise WireError(f"unknown sync header type: {raw_type}") from exc

    def _read_bsdiff_entry(self, sync_header: SyncHeader) -> FilePatch:
        """Read a bsdiff header, its controls and the trailing sentinel op."""
        assert self._wire is not None
        bsdiff_header = self._wire.read_message(BsdiffHeader)
        controls: list[Control] = []
        while True:
            ctrl = self._wire.read_message(Control)
            controls.append(ctrl)
            # The control flagged ``eof`` is kept in the list and ends it.
            if ctrl.eof:
                break
        sentinel = self._wire.read_message(SyncOp)
        if sentinel.type != SyncOpType.HEY_YOU_DID_IT:
            raise WireError("expected BSDIFF sentinel SyncOp")
        return FilePatch(
            sync_header=sync_header,
            bsdiff_header=bsdiff_header,
            bsdiff_controls=controls,
        )

    def _read_rsync_entry(self, sync_header: SyncHeader) -> FilePatch:
        """Read sync ops up to (and excluding) the ``HEY_YOU_DID_IT`` op."""
        assert self._wire is not None
        ops: list[SyncOp] = []
        while True:
            op = self._wire.read_message(SyncOp)
            if op.type == SyncOpType.HEY_YOU_DID_IT:
                break
            ops.append(op)
        return FilePatch(sync_header=sync_header, sync_ops=ops)

    def close(self) -> None:
        """Close the decompressed wire reader, if one was created."""
        # NOTE(review): whether this also closes the underlying stream
        # depends on WireReader / the decompressed reader - confirm.
        if self._wire:
            self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


class SignatureReader:
    """Reader for a signature stream.

    Checks ``SIGNATURE_MAGIC``, reads the ``SignatureHeader`` from the raw
    stream, then reads the container message and the block hashes through
    the reader returned by ``open_decompressed_reader``.
    """

    def __init__(
        self, stream, *, container_decoder: ContainerDecoder | None = None
    ):
        """Parse the signature header and container from *stream*.

        Args:
            stream: Binary stream positioned at the start of the signature.
            container_decoder: Optional callable applied to the raw
                container bytes; its result is stored in ``container``.
        """
        self._stream = stream
        self._raw_wire = WireReader(stream)
        self._wire: WireReader | None = None
        self.header: SignatureHeader | None = None
        self.container_raw: bytes | None = None
        self.container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(
        cls, path: str, *, container_decoder: ContainerDecoder | None = None
    ):
        """Open the file at *path* and return a reader over it.

        The file is closed again if parsing the header fails.
        """
        return _open_reader(cls, path, container_decoder)

    def _init_stream(self, container_decoder):
        """Read magic, header and container; set up the wire reader."""
        self._raw_wire.expect_magic(SIGNATURE_MAGIC)
        self.header = self._raw_wire.read_message(SignatureHeader)
        decompressed = open_decompressed_reader(
            self._stream, self.header.compression
        )
        self._wire = WireReader(decompressed)
        self.container_raw = self._wire.read_message_bytes()
        self.container = (
            container_decoder(self.container_raw)
            if container_decoder
            else None
        )

    def iter_block_hashes(self) -> Iterator[BlockHash]:
        """Yield ``BlockHash`` messages until the wire reader raises EOF."""
        assert self._wire is not None
        while True:
            try:
                yield self._wire.read_message(BlockHash)
            except EOFError:
                return

    def close(self) -> None:
        """Close the decompressed wire reader, if one was created."""
        if self._wire:
            self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


class ManifestReader:
    """Reader for a manifest stream.

    Checks ``MANIFEST_MAGIC``, reads the ``ManifestHeader`` from the raw
    stream, then reads the container message and the manifest block hashes
    through the reader returned by ``open_decompressed_reader``.
    """

    def __init__(
        self, stream, *, container_decoder: ContainerDecoder | None = None
    ):
        """Parse the manifest header and container from *stream*.

        Args:
            stream: Binary stream positioned at the start of the manifest.
            container_decoder: Optional callable applied to the raw
                container bytes; its result is stored in ``container``.
        """
        self._stream = stream
        self._raw_wire = WireReader(stream)
        self._wire: WireReader | None = None
        self.header: ManifestHeader | None = None
        self.container_raw: bytes | None = None
        self.container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(
        cls, path: str, *, container_decoder: ContainerDecoder | None = None
    ):
        """Open the file at *path* and return a reader over it.

        The file is closed again if parsing the header fails.
        """
        return _open_reader(cls, path, container_decoder)

    def _init_stream(self, container_decoder):
        """Read magic, header and container; set up the wire reader."""
        self._raw_wire.expect_magic(MANIFEST_MAGIC)
        self.header = self._raw_wire.read_message(ManifestHeader)
        decompressed = open_decompressed_reader(
            self._stream, self.header.compression
        )
        self._wire = WireReader(decompressed)
        self.container_raw = self._wire.read_message_bytes()
        self.container = (
            container_decoder(self.container_raw)
            if container_decoder
            else None
        )

    def iter_block_hashes(self) -> Iterator[ManifestBlockHash]:
        """Yield ``ManifestBlockHash`` messages until the reader raises EOF."""
        assert self._wire is not None
        while True:
            try:
                yield self._wire.read_message(ManifestBlockHash)
            except EOFError:
                return

    def close(self) -> None:
        """Close the decompressed wire reader, if one was created."""
        if self._wire:
            self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


class WoundsReader:
    """Reader for a wounds stream.

    Unlike the other readers in this module, no decompression layer is set
    up: magic, header, container and wounds are all read from a single
    ``WireReader`` over the given stream.
    """

    def __init__(
        self, stream, *, container_decoder: ContainerDecoder | None = None
    ):
        """Parse the wounds header and container from *stream*.

        Args:
            stream: Binary stream positioned at the start of the wounds data.
            container_decoder: Optional callable applied to the raw
                container bytes; its result is stored in ``container``.
        """
        self._stream = stream
        self._wire = WireReader(stream)
        self.header: WoundsHeader | None = None
        self.container_raw: bytes | None = None
        self.container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(
        cls, path: str, *, container_decoder: ContainerDecoder | None = None
    ):
        """Open the file at *path* and return a reader over it.

        The file is closed again if parsing the header fails.
        """
        return _open_reader(cls, path, container_decoder)

    def _init_stream(self, container_decoder):
        """Read magic, header and container from the wire reader."""
        self._wire.expect_magic(WOUNDS_MAGIC)
        self.header = self._wire.read_message(WoundsHeader)
        self.container_raw = self._wire.read_message_bytes()
        self.container = (
            container_decoder(self.container_raw)
            if container_decoder
            else None
        )

    def iter_wounds(self) -> Iterator[Wound]:
        """Yield ``Wound`` messages until the wire reader raises EOF."""
        while True:
            try:
                yield self._wire.read_message(Wound)
            except EOFError:
                return

    def close(self) -> None:
        """Close the wire reader."""
        self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()