"""Wire-format helpers: magic numbers, varints and length-prefixed messages."""

from __future__ import annotations

import io
import struct
from typing import BinaryIO, Type, TypeVar

from .proto import ProtoError, ProtoMessage

# Magic numbers identifying each kind of stream; all are consecutive offsets
# from PATCH_MAGIC.
PATCH_MAGIC = 0xFEF5F00
SIGNATURE_MAGIC = PATCH_MAGIC + 1
MANIFEST_MAGIC = PATCH_MAGIC + 2
WOUNDS_MAGIC = PATCH_MAGIC + 3
ZIP_INDEX_MAGIC = PATCH_MAGIC + 4

BLOCK_SIZE = 64 * 1024


class WireError(Exception):
    """Raised when wire data is malformed or the stream cannot be used."""


def _read_exact(stream: BinaryIO, size: int) -> bytes:
    """Read up to *size* bytes from *stream*, stopping early only at EOF.

    Args:
        stream: Object exposing ``read(n)`` and/or ``readinto(buf)``.
        size: Number of bytes wanted; values ``<= 0`` yield ``b""``.

    Returns:
        The bytes read. The result is shorter than *size* only when the
        stream ran out of data; callers compare the length to detect that.

    Raises:
        WireError: If ``read`` is unusable and the stream has no ``readinto``.
    """
    if size <= 0:
        return b""

    try:
        data = stream.read(size)
    except NotImplementedError:
        data = None

    if data is not None:
        if not data or len(data) >= size:
            return data
        # A single read() may legitimately return fewer bytes than requested
        # (raw/unbuffered streams); keep reading so a short read is not
        # mistaken for EOF by the callers.
        pieces = [data]
        missing = size - len(data)
        while missing > 0:
            more = stream.read(missing)
            if not more:  # b"" means EOF, None means no data available
                break
            pieces.append(more)
            missing -= len(more)
        return b"".join(pieces)

    # read() was unavailable or returned None: fall back to readinto().
    readinto = getattr(stream, "readinto", None)
    if readinto is None:
        raise WireError("stream does not support read or readinto")

    buf = bytearray(size)
    view = memoryview(buf)
    filled = 0
    while filled < size:
        count = readinto(view[filled:])
        if not count:  # None or 0: nothing more to read
            break
        filled += count
    return bytes(view[:filled])


def read_varint(stream: BinaryIO) -> int:
    """Decode one base-128 varint (least-significant group first).

    Raises:
        EOFError: If the stream ends before the varint is complete.
        WireError: If the tenth byte still has its continuation bit set.
    """
    shift = 0
    result = 0
    while True:
        b = _read_exact(stream, 1)
        if not b:
            raise EOFError("unexpected EOF while reading varint")
        byte = b[0]
        result |= (byte & 0x7F) << shift
        if not (byte & 0x80):
            return result
        shift += 7
        if shift > 64:
            raise WireError("varint too long")


def write_varint(stream: BinaryIO, value: int) -> None:
    """Encode *value* as a base-128 varint and write it to *stream*.

    Each output byte carries seven bits of the value, least-significant
    group first; the high bit is set on every byte except the last.

    Raises:
        WireError: If *value* is negative.
    """
    if value < 0:
        raise WireError("varint cannot encode negative values")

    # Assemble the full encoding first so the stream sees a single write().
    encoded = bytearray()
    while True:
        low7 = value & 0x7F
        value >>= 7
        if value:
            encoded.append(low7 | 0x80)
        else:
            encoded.append(low7)
            break
    stream.write(bytes(encoded))


def read_magic(stream: BinaryIO) -> int:
    """Read a 4-byte magic number from *stream*.

    Raises:
        EOFError: If fewer than four bytes are available.
    """
    data = _read_exact(stream, 4)
    if len(data) != 4:
        raise EOFError("unexpected EOF while reading magic")
    # NOTE(review): the format string was damaged in the reviewed copy of this
    # file; "<I" (little-endian unsigned 32-bit) is a reconstruction - confirm
    # the signedness against the original.
    return struct.unpack("<I", data)[0]


def write_magic(stream: BinaryIO, magic: int) -> None:
    """Write *magic* to *stream* as a 4-byte value."""
    # NOTE(review): format string reconstructed to mirror read_magic - confirm.
    stream.write(struct.pack("<I", magic))


# NOTE(review): this declaration was lost in the reviewed copy; the bound is
# inferred from read_message calling ``msg_cls.from_bytes`` - confirm.
T = TypeVar("T", bound=ProtoMessage)


class WireReader:
    """Reads magic numbers and length-prefixed messages from a binary stream."""

    # NOTE(review): the class header, __init__ and the read_magic signature
    # were reconstructed by analogy with WireWriter - confirm.
    def __init__(self, stream: BinaryIO):
        self._stream = stream

    def read_magic(self) -> int:
        """Read the next magic number from the underlying stream."""
        return read_magic(self._stream)

    def expect_magic(self, magic: int) -> None:
        """Read a magic number and raise WireError unless it equals *magic*."""
        found = self.read_magic()
        if found != magic:
            raise WireError(f"wrong magic: expected {magic}, got {found}")

    def read_message_bytes(self) -> bytes:
        """Read one varint-length-prefixed payload and return its raw bytes.

        Raises:
            EOFError: If the stream ends before the announced length is read.
        """
        length = read_varint(self._stream)
        data = _read_exact(self._stream, length)
        if len(data) != length:
            raise EOFError("unexpected EOF while reading message")
        return data

    def read_message(self, msg_cls: Type[T]) -> T:
        """Read one payload and decode it with ``msg_cls.from_bytes``.

        Raises:
            WireError: If decoding raises ProtoError (chained as the cause).
        """
        data = self.read_message_bytes()
        try:
            return msg_cls.from_bytes(data)
        except ProtoError as exc:
            raise WireError(str(exc)) from exc

    def close(self) -> None:
        """Close the underlying stream if it is an ``io.IOBase`` instance."""
        if isinstance(self._stream, io.IOBase):
            self._stream.close()


class WireWriter:
    """Writes magic numbers and length-prefixed messages to a binary stream."""

    def __init__(self, stream: BinaryIO):
        self._stream = stream

    def write_magic(self, magic: int) -> None:
        """Write *magic* to the underlying stream."""
        write_magic(self._stream, magic)

    def write_message(self, msg: ProtoMessage | bytes | bytearray | memoryview) -> None:
        """Write *msg* preceded by its byte length encoded as a varint.

        Bytes-like input is written as given; any other object is serialized
        through its ``to_bytes()`` method first.
        """
        if isinstance(msg, (bytes, bytearray, memoryview)):
            data = bytes(msg)
        else:
            data = msg.to_bytes()
        write_varint(self._stream, len(data))
        self._stream.write(data)

    def close(self) -> None:
        """Close the underlying stream if it is an ``io.IOBase`` instance."""
        if isinstance(self._stream, io.IOBase):
            self._stream.close()