Files
py-pwr/pwr/wire.py
senstella ca85a52839 init
2026-01-17 14:13:43 +09:00

141 lines
3.6 KiB
Python

from __future__ import annotations
import io
import struct
from typing import BinaryIO, Type, TypeVar
from .proto import ProtoError, ProtoMessage
# Magic numbers identifying each kind of wire stream. They are consecutive
# offsets from PATCH_MAGIC and are encoded on the wire as little-endian
# signed 32-bit integers (see read_magic / write_magic below).
PATCH_MAGIC = 0xFEF5F00
SIGNATURE_MAGIC = PATCH_MAGIC + 1
MANIFEST_MAGIC = PATCH_MAGIC + 2
WOUNDS_MAGIC = PATCH_MAGIC + 3
ZIP_INDEX_MAGIC = PATCH_MAGIC + 4
# Block size in bytes (64 KiB). NOTE(review): not referenced by the code in
# this module; presumably consumed by signature/patch code elsewhere — confirm.
BLOCK_SIZE = 64 * 1024
class WireError(Exception):
    """Error in the wire layer.

    Raised for malformed input (wrong magic, over-long varint, a message
    that fails to decode), for streams that support neither ``read`` nor
    ``readinto``, and for values that cannot be encoded (negative varints).
    """
def _read_exact(stream: BinaryIO, size: int) -> bytes:
    """Read *size* bytes from *stream*, stopping early only at end of stream.

    ``read`` is called repeatedly, because raw/unbuffered streams (and
    buffered streams over interactive sources such as pipes or sockets) may
    legally return fewer bytes than requested without being at EOF.  If
    ``read`` raises NotImplementedError, or returns None (a non-blocking raw
    stream with no data ready), the remainder is filled via ``readinto``.

    Args:
        stream: Binary stream to read from.
        size: Number of bytes wanted; ``<= 0`` returns ``b""`` immediately.

    Returns:
        The bytes read. The result is shorter than *size* only if the
        stream ended or ``readinto`` made no progress; callers are expected
        to check the length.

    Raises:
        WireError: if ``read`` could not be used and the stream has no
            ``readinto`` method.
    """
    if size <= 0:
        return b""

    chunks = []
    remaining = size
    use_readinto = False
    while remaining > 0:
        try:
            chunk = stream.read(remaining)
        except NotImplementedError:
            use_readinto = True
            break
        if chunk is None:
            # Non-blocking raw stream reported "no data yet"; defer to readinto.
            use_readinto = True
            break
        if not chunk:
            break  # EOF
        chunks.append(chunk)
        remaining -= len(chunk)

    if not use_readinto:
        # Common case: one read satisfied the request, so return it uncopied.
        if len(chunks) == 1:
            return chunks[0]
        return b"".join(chunks)

    readinto = getattr(stream, "readinto", None)
    if readinto is None:
        raise WireError("stream does not support read or readinto")

    # Preserve any bytes already obtained via read(), then fill the rest.
    head = b"".join(chunks)
    total = len(head)
    buf = bytearray(size)
    buf[:total] = head
    view = memoryview(buf)
    while total < size:
        n = readinto(view[total:])
        if not n:  # 0 means EOF; None means a non-blocking stream would block
            break
        total += n
    return bytes(view[:total])
def read_varint(stream: BinaryIO) -> int:
    """Decode one unsigned base-128 varint from *stream*.

    Each byte contributes its low 7 bits, least-significant group first.
    A set high bit (0x80) means another byte follows. At most ten bytes
    are consumed.

    Raises:
        EOFError: if the stream ends before the terminating byte.
        WireError: if the tenth byte still has its continuation bit set.
    """
    value = 0
    for shift in range(0, 64, 7):  # ten 7-bit groups: shifts 0, 7, ..., 63
        raw = _read_exact(stream, 1)
        if not raw:
            raise EOFError("unexpected EOF while reading varint")
        octet = raw[0]
        value |= (octet & 0x7F) << shift
        if (octet & 0x80) == 0:
            return value
    raise WireError("varint too long")
def write_varint(stream: BinaryIO, value: int) -> None:
    """Encode *value* as an unsigned base-128 varint and write it to *stream*.

    Each output byte carries 7 bits of the value, least-significant group
    first. The high bit (0x80) is set on every byte except the last. This
    is the inverse of ``read_varint``.

    The encoded bytes are built first and written with a single
    ``stream.write`` call, rather than one call (and possibly one syscall
    on unbuffered streams) per byte.

    Raises:
        WireError: if *value* is negative.
    """
    if value < 0:
        raise WireError("varint cannot encode negative values")
    encoded = bytearray()
    while True:
        group = value & 0x7F
        value >>= 7
        if not value:
            encoded.append(group)
            break
        encoded.append(group | 0x80)
    stream.write(bytes(encoded))
def read_magic(stream: BinaryIO) -> int:
    """Read a magic number: four bytes decoded as a little-endian signed int32.

    Raises:
        EOFError: if the stream does not yield exactly four bytes.
    """
    raw = _read_exact(stream, 4)
    if len(raw) == 4:
        return int.from_bytes(raw, "little", signed=True)
    raise EOFError("unexpected EOF while reading magic")
def write_magic(stream: BinaryIO, magic: int) -> None:
    """Write *magic* to *stream* as a little-endian signed 32-bit integer.

    ``magic`` is converted with ``int()`` first. ``struct.error`` is raised
    if it does not fit in 32 signed bits.
    """
    encoded = struct.pack("<i", int(magic))
    stream.write(encoded)
# Type variable for WireReader.read_message: the decoded result has the type
# of the ProtoMessage subclass passed in as msg_cls.
T = TypeVar("T", bound=ProtoMessage)
class WireReader:
    """Read magic numbers and varint-length-prefixed messages from a binary stream."""

    def __init__(self, stream: BinaryIO):
        # Stream that every read is delegated to.
        self._stream = stream

    def read_magic(self) -> int:
        """Read the next 4-byte little-endian signed magic number."""
        return read_magic(self._stream)

    def expect_magic(self, magic: int) -> None:
        """Read a magic number and check that it equals *magic*.

        Raises:
            WireError: if the value read differs from *magic*.
            EOFError: if the stream ends before four bytes are read.
        """
        actual = self.read_magic()
        if actual != magic:
            raise WireError(f"wrong magic: expected {magic}, got {actual}")

    def read_message_bytes(self) -> bytes:
        """Read a varint length prefix followed by that many payload bytes.

        Raises:
            EOFError: if the stream ends inside the prefix or the payload.
            WireError: if the length prefix is an over-long varint.
        """
        expected = read_varint(self._stream)
        payload = _read_exact(self._stream, expected)
        if len(payload) == expected:
            return payload
        raise EOFError("unexpected EOF while reading message")

    def read_message(self, msg_cls: Type[T]) -> T:
        """Read one length-prefixed message and decode it with ``msg_cls.from_bytes``.

        Raises:
            WireError: if decoding raises ProtoError (the original is chained
                as ``__cause__``).
        """
        payload = self.read_message_bytes()
        try:
            decoded = msg_cls.from_bytes(payload)
        except ProtoError as err:
            raise WireError(str(err)) from err
        return decoded

    def close(self) -> None:
        """Close the wrapped stream if it is an ``io.IOBase`` instance.

        Other stream objects are left untouched, even if they define ``close``.
        """
        stream = self._stream
        if not isinstance(stream, io.IOBase):
            return
        stream.close()
class WireWriter:
    """Write magic numbers and varint-length-prefixed messages to a binary stream."""

    def __init__(self, stream: BinaryIO):
        # Stream that every write is delegated to.
        self._stream = stream

    def write_magic(self, magic: int) -> None:
        """Write *magic* as a 4-byte little-endian signed integer."""
        write_magic(self._stream, magic)

    def write_message(self, msg: ProtoMessage | bytes | bytearray | memoryview) -> None:
        """Write one message as a varint byte-length prefix followed by its bytes.

        Bytes-like *msg* values are written as-is (after conversion to
        ``bytes``, so the prefix is always a byte count). Any other value is
        serialized with ``msg.to_bytes()``.
        """
        payload = (
            bytes(msg)
            if isinstance(msg, (bytes, bytearray, memoryview))
            else msg.to_bytes()
        )
        write_varint(self._stream, len(payload))
        self._stream.write(payload)

    def close(self) -> None:
        """Close the wrapped stream if it is an ``io.IOBase`` instance.

        Other stream objects are left untouched, even if they define ``close``.
        """
        stream = self._stream
        if not isinstance(stream, io.IOBase):
            return
        stream.close()