Files
py-pwr/pwr/proto.py
2026-01-18 03:29:42 +09:00

478 lines
13 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, ClassVar, Dict, Tuple
class ProtoError(Exception):
    """Raised when protobuf data cannot be encoded or decoded by this module."""
def _unsigned(value: int, bits: int) -> int:
if value < 0:
value = (1 << bits) + value
return value
def _signed(value: int, bits: int) -> int:
sign_bit = 1 << (bits - 1)
if value & sign_bit:
return value - (1 << bits)
return value
def _encode_varint(value: int) -> bytes:
if value < 0:
raise ProtoError("varint cannot encode negative values")
out = bytearray()
while True:
to_write = value & 0x7F
value >>= 7
if value:
out.append(to_write | 0x80)
else:
out.append(to_write)
break
return bytes(out)
def _encode_key(field_number: int, wire_type: int) -> bytes:
    """Encode a field key: ``(field_number << 3) | wire_type`` as a varint."""
    tag = field_number << 3 | wire_type
    return _encode_varint(tag)
class ProtoReader:
    """Sequential cursor over a protobuf-encoded byte string.

    Each ``read_*``/``skip`` call consumes bytes from the current position.
    All reads are bounds-checked and raise ProtoError on truncated or
    malformed input instead of running past the end of the buffer.
    """

    def __init__(self, data: bytes):
        # Copy into immutable bytes so later slices are stable.
        self._data = bytes(data)
        self._pos = 0

    def eof(self) -> bool:
        """Return True once every input byte has been consumed."""
        return self._pos >= len(self._data)

    def _advance(self, count: int, what: str) -> int:
        """Move the cursor ``count`` bytes forward and return the old position.

        Raises:
            ProtoError: if fewer than ``count`` bytes remain.
        """
        start = self._pos
        end = start + count
        if end > len(self._data):
            raise ProtoError(f"unexpected EOF while reading {what}")
        self._pos = end
        return start

    def read_varint(self) -> int:
        """Read one base-128 varint (at most 10 bytes / 64 bits).

        Raises:
            ProtoError: on EOF mid-varint, on more than 10 bytes, or when the
                10th byte carries bits beyond bit 63.
        """
        shift = 0
        result = 0
        while True:
            if self._pos >= len(self._data):
                raise ProtoError("unexpected EOF while reading varint")
            byte = self._data[self._pos]
            self._pos += 1
            result |= (byte & 0x7F) << shift
            if not byte & 0x80:
                # The 10th byte (shift 63) may only contribute bit 63.
                if shift == 63 and byte > 1:
                    raise ProtoError("varint overflows 64 bits")
                return result
            shift += 7
            if shift > 64:
                raise ProtoError("varint too long")

    def read_key(self) -> Tuple[int, int]:
        """Read a field key and return ``(field_number, wire_type)``."""
        key = self.read_varint()
        return key >> 3, key & 0x07

    def read_length_delimited(self) -> bytes:
        """Read a varint length prefix followed by that many payload bytes."""
        length = self.read_varint()
        start = self._advance(length, "length-delimited field")
        return self._data[start : self._pos]

    def skip(self, wire_type: int) -> None:
        """Consume and discard one value encoded with ``wire_type``.

        Supports varint (0), fixed64 (1), length-delimited (2) and fixed32 (5).

        Raises:
            ProtoError: for other wire types (groups are not supported) or if
                the value extends past the end of the input.
        """
        if wire_type == 0:
            self.read_varint()
        elif wire_type == 1:
            self._advance(8, "fixed64 field")
        elif wire_type == 2:
            self._advance(self.read_varint(), "length-delimited field")
        elif wire_type == 5:
            self._advance(4, "fixed32 field")
        else:
            raise ProtoError(f"unsupported wire type: {wire_type}")
class ProtoMessage:
    """Base class for messages serialized with a small schema-driven protobuf codec.

    Subclasses declare ``FIELDS``, mapping each field number to
    ``(attribute_name, field_type, extra)``:

    * ``field_type`` is one of ``"message"``, ``"bytes"``, ``"string"``,
      ``"bool"``, ``"enum"``, ``"int32"``, ``"int64"``, ``"uint32"`` or
      ``"uint64"``. It may carry a ``"repeated_"`` prefix, in which case the
      attribute is a list.
    * ``extra`` depends on ``field_type``:
      - for ``"message"`` fields it is the nested message class (``None``
        keeps the raw payload bytes);
      - for ``"enum"`` fields it is the ``IntEnum`` class (``None`` keeps the
        raw int);
      - it is ignored for all other types.

    Subclasses must be constructible with no arguments, because
    ``from_bytes`` starts from ``cls()`` and then sets attributes.
    Packed repeated scalars are not supported.
    """

    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}

    # Wire type each supported field type travels with
    # (0 = varint, 2 = length-delimited).
    _WIRE_TYPES: ClassVar[Dict[str, int]] = {
        "message": 2,
        "bytes": 2,
        "string": 2,
        "bool": 0,
        "enum": 0,
        "int32": 0,
        "int64": 0,
        "uint32": 0,
        "uint64": 0,
    }

    @staticmethod
    def _split_type(field_type: str) -> Tuple[str, bool]:
        """Strip an optional ``repeated_`` prefix; return ``(base_type, repeated)``."""
        prefix = "repeated_"
        if field_type.startswith(prefix):
            return field_type[len(prefix) :], True
        return field_type, False

    @staticmethod
    def _encode_item(out: bytearray, field_number: int, field_type: str, item: Any) -> None:
        """Append the key and value for one non-list ``item`` to ``out``.

        Raises:
            ProtoError: if ``field_type`` is unsupported or ``item`` cannot be
                encoded (e.g. a negative unsigned value).
        """
        if field_type in ("message", "bytes", "string"):
            if field_type == "string":
                raw = str(item).encode("utf-8")
            elif field_type == "bytes" or isinstance(item, (bytes, bytearray, memoryview)):
                # A "message" value given as bytes is embedded verbatim.
                raw = bytes(item)
            else:
                raw = item.to_bytes()
            out.extend(_encode_key(field_number, 2))
            out.extend(_encode_varint(len(raw)))
            out.extend(raw)
            return
        if field_type == "bool":
            number = 1 if item else 0
        elif field_type in ("enum", "int32", "int64"):
            # proto3 sign-extends negative int32/enum values to 64 bits (a
            # ten-byte varint). Using 64 bits for all three keeps the output
            # identical to the reference encoding.
            number = _unsigned(int(item), 64)
        elif field_type in ("uint32", "uint64"):
            number = int(item)
        else:
            raise ProtoError(f"unsupported field type: {field_type}")
        out.extend(_encode_key(field_number, 0))
        out.extend(_encode_varint(number))

    @staticmethod
    def _decode_value(reader: ProtoReader, field_type: str, extra: Any) -> Any:
        """Read one ``field_type`` value whose key has already been consumed.

        Raises:
            ProtoError: on malformed input, invalid UTF-8 in a string, or an
                unsupported ``field_type``.
        """
        if field_type == "message":
            raw = reader.read_length_delimited()
            return raw if extra is None else extra.from_bytes(raw)
        if field_type == "bytes":
            return reader.read_length_delimited()
        if field_type == "string":
            raw = reader.read_length_delimited()
            try:
                return raw.decode("utf-8")
            except UnicodeDecodeError as err:
                raise ProtoError("invalid UTF-8 in string field") from err
        number = reader.read_varint()
        if field_type == "bool":
            return bool(number)
        if field_type in ("int32", "enum"):
            # Both are 32-bit. Wider varints, such as the ten-byte
            # sign-extended form, are truncated to their low 32 bits.
            value = _signed(number & 0xFFFFFFFF, 32)
            if field_type == "int32" or extra is None:
                return value
            try:
                return extra(value)
            except ValueError:
                # Values outside the known enum members stay plain ints.
                return value
        if field_type == "int64":
            return _signed(number & 0xFFFFFFFFFFFFFFFF, 64)
        if field_type == "uint32":
            return number & 0xFFFFFFFF
        if field_type == "uint64":
            return number & 0xFFFFFFFFFFFFFFFF
        raise ProtoError(f"unsupported field type: {field_type}")

    def to_bytes(self) -> bytes:
        """Serialize this message, emitting fields in ascending field-number order.

        Attributes set to ``None`` are omitted. Each element of a repeated
        field gets its own key/value pair, so an empty list emits nothing.

        Raises:
            ProtoError: if a field has an unsupported type or an unencodable value.
        """
        out = bytearray()
        for field_number in sorted(self.FIELDS):
            name, field_type, _extra = self.FIELDS[field_number]
            value = getattr(self, name)
            if value is None:
                continue
            base_type, repeated = self._split_type(field_type)
            items = value if repeated else (value,)
            for item in items:
                self._encode_item(out, field_number, base_type, item)
        return bytes(out)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse ``data`` into a new instance of ``cls``.

        Some fields are skipped rather than decoded:
        - unknown field numbers;
        - known fields whose wire type does not match the declared type.

        For a non-repeated field, the last occurrence wins. Repeated fields
        accumulate values in order.

        Raises:
            ProtoError: on truncated/malformed input or an unsupported field type.
        """
        reader = ProtoReader(data)
        obj = cls()
        while not reader.eof():
            field_number, wire_type = reader.read_key()
            spec = cls.FIELDS.get(field_number)
            if spec is None:
                reader.skip(wire_type)
                continue
            name, field_type, extra = spec
            base_type, repeated = cls._split_type(field_type)
            expected = cls._WIRE_TYPES.get(base_type)
            if expected is None:
                raise ProtoError(f"unsupported field type: {base_type}")
            if wire_type != expected:
                # Standard parsers treat a known field with the wrong wire type
                # as unknown data. Decoding it anyway would desynchronise the
                # reader.
                reader.skip(wire_type)
                continue
            value = cls._decode_value(reader, base_type, extra)
            if repeated:
                current = getattr(obj, name, None)
                if current is None:
                    current = []
                    setattr(obj, name, current)
                current.append(value)
            else:
                setattr(obj, name, value)
        return obj
class CompressionAlgorithm(IntEnum):
    """Codec identifiers stored in ``CompressionSettings.algorithm`` (field 1)."""

    NONE = 0
    BROTLI = 1
    GZIP = 2
    ZSTD = 3
class HashAlgorithm(IntEnum):
    """Hash algorithm identifiers stored in ``ManifestHeader.algorithm`` (field 2)."""

    SHAKE128_32 = 0
    CRC32C = 1
class WoundKind(IntEnum):
    """Kinds of entries a ``Wound`` can refer to (``Wound.kind``, field 4)."""

    FILE = 0
    SYMLINK = 1
    DIR = 2
    CLOSED_FILE = 3
class SyncHeaderType(IntEnum):
    """Values of ``SyncHeader.type`` (field 1).

    NOTE(review): presumably selects whether a file's patch data is
    rsync-style ``SyncOp`` records or bsdiff ``BsdiffHeader``/``Control``
    records. Confirm against the patch reader.
    """

    RSYNC = 0
    BSDIFF = 1
class SyncOpType(IntEnum):
    """Values of ``SyncOp.type`` (field 1)."""

    BLOCK_RANGE = 0
    DATA = 1
    # NOTE(review): value 2049 looks like an end-of-stream sentinel; confirm
    # the exact number against the upstream schema before changing it.
    HEY_YOU_DID_IT = 2049
class OverlayOpType(IntEnum):
    """Values of ``OverlayOp.type`` (field 1)."""

    SKIP = 0
    FRESH = 1
    # NOTE(review): 2040 differs from SyncOpType's 2049. It looks like an
    # end-of-stream sentinel; confirm against the upstream schema.
    HEY_YOU_DID_IT = 2040
@dataclass
class CompressionSettings(ProtoMessage):
    """Compression parameters embedded in the patch/signature/manifest headers.

    Wire schema: field 1 ``algorithm`` (enum), field 2 ``quality`` (int32).
    Fields left as ``None`` are omitted when encoding.
    """

    # Field 1. Decoded as a plain int if the value is not a known member.
    algorithm: CompressionAlgorithm | int | None = None
    # Field 2.
    quality: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("algorithm", "enum", CompressionAlgorithm),
        2: ("quality", "int32", None),
    }
@dataclass
class PatchHeader(ProtoMessage):
    """Patch header message: field 1 ``compression`` (``CompressionSettings``)."""

    # Field 1 (nested message).
    compression: CompressionSettings | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("compression", "message", CompressionSettings),
    }
@dataclass
class SyncHeader(ProtoMessage):
    """Per-file sync header message.

    Wire schema: field 1 ``type`` (enum ``SyncHeaderType``), field 16
    ``file_index`` (int64). The jump to 16 presumably mirrors the upstream
    schema; keep it as-is.
    """

    # Field 1.
    type: SyncHeaderType | int | None = None
    # Field 16.
    file_index: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("type", "enum", SyncHeaderType),
        16: ("file_index", "int64", None),
    }
@dataclass
class BsdiffHeader(ProtoMessage):
    """Bsdiff section header message: field 1 ``target_index`` (int64)."""

    # Field 1.
    target_index: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("target_index", "int64", None),
    }
@dataclass
class SyncOp(ProtoMessage):
    """One rsync-style operation record.

    Wire schema:
    - field 1: ``type`` (enum ``SyncOpType``)
    - fields 2-4: ``file_index``, ``block_index``, ``block_span`` (int64)
    - field 5: ``data`` (bytes)

    NOTE(review): presumably BLOCK_RANGE ops use the index/span fields and
    DATA ops use ``data``; confirm against the patch applier.
    """

    type: SyncOpType | int | None = None  # field 1
    file_index: int | None = None  # field 2
    block_index: int | None = None  # field 3
    block_span: int | None = None  # field 4
    data: bytes | None = None  # field 5
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("type", "enum", SyncOpType),
        2: ("file_index", "int64", None),
        3: ("block_index", "int64", None),
        4: ("block_span", "int64", None),
        5: ("data", "bytes", None),
    }
@dataclass
class SignatureHeader(ProtoMessage):
    """Signature header message: field 1 ``compression`` (``CompressionSettings``)."""

    # Field 1 (nested message).
    compression: CompressionSettings | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("compression", "message", CompressionSettings),
    }
@dataclass
class BlockHash(ProtoMessage):
    """Per-block hash pair message.

    Wire schema: field 1 ``weak_hash`` (uint32), field 2 ``strong_hash`` (bytes).
    """

    weak_hash: int | None = None  # field 1
    strong_hash: bytes | None = None  # field 2
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("weak_hash", "uint32", None),
        2: ("strong_hash", "bytes", None),
    }
@dataclass
class ManifestHeader(ProtoMessage):
    """Manifest header message.

    Wire schema: field 1 ``compression`` (``CompressionSettings``), field 2
    ``algorithm`` (enum ``HashAlgorithm``).
    """

    compression: CompressionSettings | None = None  # field 1
    algorithm: HashAlgorithm | int | None = None  # field 2
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("compression", "message", CompressionSettings),
        2: ("algorithm", "enum", HashAlgorithm),
    }
@dataclass
class ManifestBlockHash(ProtoMessage):
    """Manifest block hash message: field 1 ``hash`` (bytes)."""

    # Field 1.
    hash: bytes | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("hash", "bytes", None),
    }
@dataclass
class WoundsHeader(ProtoMessage):
    """Wounds header message with no fields; it encodes to zero bytes."""

    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
@dataclass
class Wound(ProtoMessage):
    """A wound record.

    Wire schema:
    - fields 1-3: ``index``, ``start``, ``end`` (int64)
    - field 4: ``kind`` (enum ``WoundKind``)

    NOTE(review): ``start``/``end`` presumably delimit a byte range within
    entry ``index``; confirm with the code that produces wounds.
    """

    index: int | None = None  # field 1
    start: int | None = None  # field 2
    end: int | None = None  # field 3
    kind: WoundKind | int | None = None  # field 4
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("index", "int64", None),
        2: ("start", "int64", None),
        3: ("end", "int64", None),
        4: ("kind", "enum", WoundKind),
    }
@dataclass
class Control(ProtoMessage):
    """Bsdiff control record.

    Wire schema:
    - fields 1-2: ``add``, ``copy`` (bytes)
    - field 3: ``seek`` (int64)
    - field 4: ``eof`` (bool)
    """

    add: bytes | None = None  # field 1
    copy: bytes | None = None  # field 2
    seek: int | None = None  # field 3
    eof: bool | None = None  # field 4
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("add", "bytes", None),
        2: ("copy", "bytes", None),
        3: ("seek", "int64", None),
        4: ("eof", "bool", None),
    }
@dataclass
class OverlayHeader(ProtoMessage):
    """Overlay header message with no fields; it encodes to zero bytes."""

    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
@dataclass
class OverlayOp(ProtoMessage):
    """One overlay operation.

    Wire schema:
    - field 1: ``type`` (enum ``OverlayOpType``)
    - field 2: ``len`` (int64)
    - field 3: ``data`` (bytes)
    """

    type: OverlayOpType | int | None = None  # field 1
    len: int | None = None  # field 2
    data: bytes | None = None  # field 3
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("type", "enum", OverlayOpType),
        2: ("len", "int64", None),
        3: ("data", "bytes", None),
    }
@dataclass
class Sample(ProtoMessage):
    """Sample message.

    Wire schema:
    - field 1: ``data`` (bytes)
    - field 2: ``number`` (int64)
    - field 3: ``eof`` (bool)
    """

    data: bytes | None = None  # field 1
    number: int | None = None  # field 2
    eof: bool | None = None  # field 3
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("data", "bytes", None),
        2: ("number", "int64", None),
        3: ("eof", "bool", None),
    }
@dataclass
class TlcDir(ProtoMessage):
    """Directory entry of a ``TlcContainer``.

    Wire schema: field 1 ``path`` (string), field 2 ``mode`` (uint32).
    """

    path: str | None = None  # field 1
    mode: int | None = None  # field 2; presumably POSIX permission bits
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("path", "string", None),
        2: ("mode", "uint32", None),
    }
@dataclass
class TlcFile(ProtoMessage):
    """File entry of a ``TlcContainer``.

    Wire schema:
    - field 1: ``path`` (string)
    - field 2: ``mode`` (uint32)
    - fields 3-4: ``size``, ``offset`` (int64)

    NOTE(review): ``offset`` presumably locates the file within the
    container's concatenated data; confirm.
    """

    path: str | None = None  # field 1
    mode: int | None = None  # field 2
    size: int | None = None  # field 3
    offset: int | None = None  # field 4
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("path", "string", None),
        2: ("mode", "uint32", None),
        3: ("size", "int64", None),
        4: ("offset", "int64", None),
    }
@dataclass
class TlcSymlink(ProtoMessage):
    """Symlink entry of a ``TlcContainer``.

    Wire schema:
    - field 1: ``path`` (string)
    - field 2: ``mode`` (uint32)
    - field 3: ``dest`` (string)
    """

    path: str | None = None  # field 1
    mode: int | None = None  # field 2
    dest: str | None = None  # field 3; the link target
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("path", "string", None),
        2: ("mode", "uint32", None),
        3: ("dest", "string", None),
    }
@dataclass
class TlcContainer(ProtoMessage):
    """Container listing message.

    Wire schema:
    - field 1: repeated ``files`` (``TlcFile``)
    - field 2: repeated ``dirs`` (``TlcDir``)
    - field 3: repeated ``symlinks`` (``TlcSymlink``)
    - field 16: ``size`` (int64)

    The list fields default to fresh empty lists, so each instance owns its
    own containers.
    """

    files: list[TlcFile] = field(default_factory=list)  # field 1
    dirs: list[TlcDir] = field(default_factory=list)  # field 2
    symlinks: list[TlcSymlink] = field(default_factory=list)  # field 3
    # Field 16. Presumably the total size of all files; confirm.
    size: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("files", "repeated_message", TlcFile),
        2: ("dirs", "repeated_message", TlcDir),
        3: ("symlinks", "repeated_message", TlcSymlink),
        16: ("size", "int64", None),
    }