141 lines
3.6 KiB
Python
141 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
import io
|
|
import struct
|
|
from typing import BinaryIO, Type, TypeVar
|
|
|
|
from .proto import ProtoError, ProtoMessage
|
|
|
|
# Magic numbers. Going by their names, there is one per stream kind
# (patch, signature, manifest, wounds, zip index); the variants are
# consecutive offsets from PATCH_MAGIC.
PATCH_MAGIC = 0x0FEF5F00
SIGNATURE_MAGIC, MANIFEST_MAGIC, WOUNDS_MAGIC, ZIP_INDEX_MAGIC = (
    PATCH_MAGIC + 1,
    PATCH_MAGIC + 2,
    PATCH_MAGIC + 3,
    PATCH_MAGIC + 4,
)

# 64 KiB. NOTE(review): not referenced in this section; presumably the
# block size used elsewhere in the format — confirm at the call sites.
BLOCK_SIZE = 1 << 16
|
|
|
|
|
|
class WireError(Exception):
    """Raised for malformed wire data, unusable streams, or unencodable values."""
|
|
|
|
|
|
# Upper bound on a single read()/readinto() request, so a corrupt or hostile
# length prefix cannot force a huge allocation before any data has arrived.
_MAX_READ_CHUNK = 1 << 20


def _read_exact(stream: BinaryIO, size: int) -> bytes:
    """Read up to ``size`` bytes from ``stream``, retrying short reads until EOF.

    ``stream.read`` is called repeatedly because raw (unbuffered) streams,
    pipes and sockets may return fewer bytes than requested without being at
    EOF. Requests are capped at ``_MAX_READ_CHUNK`` bytes so memory grows with
    the data actually received, not with the requested size. If the first
    ``read`` raises ``NotImplementedError`` or returns ``None``, the stream's
    ``readinto`` method is used instead.

    Args:
        stream: Source binary stream.
        size: Number of bytes wanted; ``<= 0`` returns ``b""`` without reading.

    Returns:
        ``size`` bytes, or fewer if the stream reached EOF (or had no data
        ready) first. Callers compare the length to detect truncation.

    Raises:
        WireError: If the fallback is needed but the stream has no
            ``readinto`` method.
    """
    if size <= 0:
        return b""
    try:
        chunk = stream.read(min(size, _MAX_READ_CHUNK))
    except NotImplementedError:
        chunk = None
    if chunk is None:
        return _read_exact_via_readinto(stream, size)

    chunks = [chunk]
    remaining = size - len(chunk)
    # Keep going after a short read; stop on EOF (b"") or when a non-blocking
    # stream has nothing ready (None).
    while chunk and remaining > 0:
        chunk = stream.read(min(remaining, _MAX_READ_CHUNK))
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _read_exact_via_readinto(stream: BinaryIO, size: int) -> bytes:
    """Fallback for ``_read_exact``: fill up to ``size`` bytes via ``stream.readinto``."""
    readinto = getattr(stream, "readinto", None)
    if readinto is None:
        raise WireError("stream does not support read or readinto")
    out = bytearray()
    # Reusable bounded scratch buffer instead of a bytearray(size) up front.
    scratch = memoryview(bytearray(min(size, _MAX_READ_CHUNK)))
    while len(out) < size:
        n = readinto(scratch[: min(size - len(out), len(scratch))])
        if not n:  # 0 means EOF; None means no data available right now
            break
        out += scratch[:n]
    return bytes(out)
|
|
|
|
|
|
def read_varint(stream: BinaryIO) -> int:
    """Read one unsigned base-128 varint from ``stream``.

    Each byte contributes its low 7 bits, least-significant group first; a
    set high bit (0x80) means another byte follows. At most 10 bytes are
    accepted, and the decoded value must fit in 64 bits.

    Args:
        stream: Binary stream positioned at the first byte of the varint.

    Returns:
        The decoded non-negative integer (< 2**64).

    Raises:
        EOFError: If the stream ends before the terminating byte.
        WireError: If the encoding runs past 10 bytes ("varint too long") or
            its value does not fit in 64 bits.
    """
    result = 0
    for index in range(10):
        chunk = _read_exact(stream, 1)
        if not chunk:
            raise EOFError("unexpected EOF while reading varint")
        byte = chunk[0]
        shift = 7 * index
        if byte < 0x80:
            # Terminating byte. The 10th byte lands at bit 63, so only its
            # lowest bit can still fit in a 64-bit value.
            if index == 9 and byte > 1:
                raise WireError("varint overflows a 64-bit integer")
            return result | (byte << shift)
        result |= (byte & 0x7F) << shift
    raise WireError("varint too long")
|
|
|
|
|
|
def write_varint(stream: BinaryIO, value: int) -> None:
    """Write ``value`` to ``stream`` as an unsigned base-128 varint.

    Seven bits are emitted per byte, least-significant group first, with the
    high bit (0x80) set on every byte except the last. The complete encoding
    is passed to ``stream.write`` in a single call rather than byte by byte.

    Args:
        stream: Destination binary stream.
        value: Integer in ``[0, 2**64)``.

    Raises:
        WireError: If ``value`` is negative, or needs more than 64 bits.
            ``read_varint`` gives up after 10 bytes, so sufficiently large
            values could not be read back.
    """
    if value < 0:
        raise WireError("varint cannot encode negative values")
    if value >> 64:
        raise WireError("varint cannot encode values wider than 64 bits")
    encoded = bytearray()
    while value > 0x7F:
        encoded.append((value & 0x7F) | 0x80)
        value >>= 7
    encoded.append(value)
    stream.write(bytes(encoded))
|
|
|
|
|
|
def read_magic(stream: BinaryIO) -> int:
    """Read a 4-byte little-endian signed integer magic number from ``stream``.

    Raises:
        EOFError: If the stream does not yield exactly four bytes.
    """
    raw = _read_exact(stream, 4)
    if len(raw) != 4:
        raise EOFError("unexpected EOF while reading magic")
    (magic,) = struct.unpack("<i", raw)
    return magic
|
|
|
|
|
|
def write_magic(stream: BinaryIO, magic: int) -> None:
    """Write ``magic`` to ``stream`` as a 4-byte little-endian signed integer."""
    write = stream.write
    write(struct.pack("<i", int(magic)))
|
|
|
|
|
|
# Generic message type for WireReader.read_message: ties its return type to
# whichever ProtoMessage subclass the caller passes as ``msg_cls``.
T = TypeVar("T", bound=ProtoMessage)
|
|
|
|
|
|
class WireReader:
    """Reads magic numbers and varint-length-prefixed messages from a stream.

    Can be used as a context manager; leaving the ``with`` block calls
    :meth:`close`, even when an exception is propagating.
    """

    def __init__(self, stream: BinaryIO):
        self._stream = stream

    def __enter__(self) -> WireReader:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()

    def read_magic(self) -> int:
        """Read a 4-byte little-endian signed magic number from the stream.

        Raises:
            EOFError: If fewer than four bytes are available.
        """
        return read_magic(self._stream)

    def expect_magic(self, magic: int) -> None:
        """Read a magic number and verify it equals ``magic``.

        Raises:
            WireError: If the magic read from the stream differs.
            EOFError: If the stream ends before four bytes are read.
        """
        found = self.read_magic()
        if found != magic:
            raise WireError(f"wrong magic: expected {magic}, got {found}")

    def read_message_bytes(self) -> bytes:
        """Read a varint length prefix followed by that many payload bytes.

        Raises:
            EOFError: If the stream ends inside the prefix or the payload.
            WireError: If the length prefix is not a valid varint.
        """
        length = read_varint(self._stream)
        data = _read_exact(self._stream, length)
        if len(data) != length:
            raise EOFError("unexpected EOF while reading message")
        return data

    def read_message(self, msg_cls: Type[T]) -> T:
        """Read one length-prefixed message and decode it via ``msg_cls.from_bytes``.

        Raises:
            WireError: If decoding raises ``ProtoError`` (chained as the cause),
                or the length prefix is malformed.
            EOFError: If the stream ends early.
        """
        data = self.read_message_bytes()
        try:
            return msg_cls.from_bytes(data)
        except ProtoError as exc:
            raise WireError(str(exc)) from exc

    def close(self) -> None:
        """Close the underlying stream if it is an ``io.IOBase``.

        Stream objects that are not ``io.IOBase`` instances are left open.
        """
        if isinstance(self._stream, io.IOBase):
            self._stream.close()
|
|
|
|
|
|
class WireWriter:
    """Writes magic numbers and varint-length-prefixed messages to a stream.

    Can be used as a context manager; leaving the ``with`` block calls
    :meth:`close`, even when an exception is propagating.
    """

    def __init__(self, stream: BinaryIO):
        self._stream = stream

    def __enter__(self) -> WireWriter:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()

    def write_magic(self, magic: int) -> None:
        """Write ``magic`` as a 4-byte little-endian signed integer."""
        write_magic(self._stream, magic)

    def write_message(self, msg: ProtoMessage | bytes | bytearray | memoryview) -> None:
        """Write ``msg`` as a varint byte-length prefix followed by its bytes.

        Args:
            msg: Raw bytes-like payload, or a message object serialized via
                its ``to_bytes()`` method.

        Raises:
            WireError: If the length cannot be encoded as a varint.
        """
        if isinstance(msg, (bytes, bytearray, memoryview)):
            # bytes() also flattens multi-byte memoryview formats, so len()
            # below counts bytes rather than items.
            data = bytes(msg)
        else:
            data = msg.to_bytes()
        write_varint(self._stream, len(data))
        self._stream.write(data)

    def close(self) -> None:
        """Close the underlying stream if it is an ``io.IOBase``.

        Stream objects that are not ``io.IOBase`` instances are left open.
        """
        if isinstance(self._stream, io.IOBase):
            self._stream.close()
|