init
This commit is contained in:
477
pwr/proto.py
Normal file
477
pwr/proto.py
Normal file
@@ -0,0 +1,477 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, ClassVar, Dict, Tuple
|
||||
|
||||
|
||||
class ProtoError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _unsigned(value: int, bits: int) -> int:
|
||||
if value < 0:
|
||||
value = (1 << bits) + value
|
||||
return value
|
||||
|
||||
|
||||
def _signed(value: int, bits: int) -> int:
|
||||
sign_bit = 1 << (bits - 1)
|
||||
if value & sign_bit:
|
||||
return value - (1 << bits)
|
||||
return value
|
||||
|
||||
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
if value < 0:
|
||||
raise ProtoError("varint cannot encode negative values")
|
||||
out = bytearray()
|
||||
while True:
|
||||
to_write = value & 0x7F
|
||||
value >>= 7
|
||||
if value:
|
||||
out.append(to_write | 0x80)
|
||||
else:
|
||||
out.append(to_write)
|
||||
break
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _encode_key(field_number: int, wire_type: int) -> bytes:
|
||||
return _encode_varint((field_number << 3) | wire_type)
|
||||
|
||||
|
||||
class ProtoReader:
|
||||
def __init__(self, data: bytes):
|
||||
self._data = bytes(data)
|
||||
self._pos = 0
|
||||
|
||||
def eof(self) -> bool:
|
||||
return self._pos >= len(self._data)
|
||||
|
||||
def read_varint(self) -> int:
|
||||
shift = 0
|
||||
result = 0
|
||||
while True:
|
||||
if self._pos >= len(self._data):
|
||||
raise ProtoError("unexpected EOF while reading varint")
|
||||
byte = self._data[self._pos]
|
||||
self._pos += 1
|
||||
result |= (byte & 0x7F) << shift
|
||||
if not (byte & 0x80):
|
||||
return result
|
||||
shift += 7
|
||||
if shift > 64:
|
||||
raise ProtoError("varint too long")
|
||||
|
||||
def read_key(self) -> Tuple[int, int]:
|
||||
key = self.read_varint()
|
||||
return key >> 3, key & 0x07
|
||||
|
||||
def read_length_delimited(self) -> bytes:
|
||||
length = self.read_varint()
|
||||
end = self._pos + length
|
||||
if end > len(self._data):
|
||||
raise ProtoError("unexpected EOF while reading length-delimited field")
|
||||
data = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return data
|
||||
|
||||
def skip(self, wire_type: int) -> None:
|
||||
if wire_type == 0:
|
||||
self.read_varint()
|
||||
return
|
||||
if wire_type == 1:
|
||||
self._pos += 8
|
||||
return
|
||||
if wire_type == 2:
|
||||
length = self.read_varint()
|
||||
self._pos += length
|
||||
return
|
||||
if wire_type == 5:
|
||||
self._pos += 4
|
||||
return
|
||||
raise ProtoError(f"unsupported wire type: {wire_type}")
|
||||
|
||||
|
||||
class ProtoMessage:
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
out = bytearray()
|
||||
for field_number in sorted(self.FIELDS):
|
||||
name, field_type, extra = self.FIELDS[field_number]
|
||||
value = getattr(self, name)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
repeated = False
|
||||
if field_type.startswith("repeated_"):
|
||||
repeated = True
|
||||
field_type = field_type[len("repeated_") :]
|
||||
|
||||
values = value if repeated else [value]
|
||||
for item in values:
|
||||
if field_type == "message":
|
||||
if isinstance(item, (bytes, bytearray, memoryview)):
|
||||
raw = bytes(item)
|
||||
else:
|
||||
raw = item.to_bytes()
|
||||
out.extend(_encode_key(field_number, 2))
|
||||
out.extend(_encode_varint(len(raw)))
|
||||
out.extend(raw)
|
||||
continue
|
||||
|
||||
if field_type == "bytes":
|
||||
raw = bytes(item)
|
||||
out.extend(_encode_key(field_number, 2))
|
||||
out.extend(_encode_varint(len(raw)))
|
||||
out.extend(raw)
|
||||
continue
|
||||
|
||||
if field_type == "string":
|
||||
raw = str(item).encode("utf-8")
|
||||
out.extend(_encode_key(field_number, 2))
|
||||
out.extend(_encode_varint(len(raw)))
|
||||
out.extend(raw)
|
||||
continue
|
||||
|
||||
if field_type == "bool":
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(1 if item else 0))
|
||||
continue
|
||||
|
||||
if field_type == "enum":
|
||||
if isinstance(item, IntEnum):
|
||||
item = int(item)
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(_unsigned(int(item), 64)))
|
||||
continue
|
||||
|
||||
if field_type in ("int32", "int64"):
|
||||
bits = 32 if field_type == "int32" else 64
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(_unsigned(int(item), bits)))
|
||||
continue
|
||||
|
||||
if field_type in ("uint32", "uint64"):
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(int(item)))
|
||||
continue
|
||||
|
||||
raise ProtoError(f"unsupported field type: {field_type}")
|
||||
|
||||
return bytes(out)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
reader = ProtoReader(data)
|
||||
obj = cls()
|
||||
while not reader.eof():
|
||||
field_number, wire_type = reader.read_key()
|
||||
spec = cls.FIELDS.get(field_number)
|
||||
if spec is None:
|
||||
reader.skip(wire_type)
|
||||
continue
|
||||
name, field_type, extra = spec
|
||||
|
||||
repeated = False
|
||||
if field_type.startswith("repeated_"):
|
||||
repeated = True
|
||||
field_type = field_type[len("repeated_") :]
|
||||
|
||||
if field_type == "message":
|
||||
raw = reader.read_length_delimited()
|
||||
if extra is None:
|
||||
value = raw
|
||||
else:
|
||||
value = extra.from_bytes(raw)
|
||||
elif field_type == "bytes":
|
||||
value = reader.read_length_delimited()
|
||||
elif field_type == "string":
|
||||
value = reader.read_length_delimited().decode("utf-8")
|
||||
elif field_type == "bool":
|
||||
value = bool(reader.read_varint())
|
||||
elif field_type == "enum":
|
||||
raw_value = reader.read_varint()
|
||||
if extra is None:
|
||||
value = raw_value
|
||||
else:
|
||||
try:
|
||||
value = extra(raw_value)
|
||||
except ValueError:
|
||||
value = raw_value
|
||||
elif field_type == "int32":
|
||||
value = _signed(reader.read_varint(), 32)
|
||||
elif field_type == "int64":
|
||||
value = _signed(reader.read_varint(), 64)
|
||||
elif field_type == "uint32":
|
||||
value = reader.read_varint()
|
||||
elif field_type == "uint64":
|
||||
value = reader.read_varint()
|
||||
else:
|
||||
raise ProtoError(f"unsupported field type: {field_type}")
|
||||
|
||||
if repeated:
|
||||
current = getattr(obj, name, None)
|
||||
if current is None:
|
||||
current = []
|
||||
setattr(obj, name, current)
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(obj, name, value)
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class CompressionAlgorithm(IntEnum):
|
||||
NONE = 0
|
||||
BROTLI = 1
|
||||
GZIP = 2
|
||||
ZSTD = 3
|
||||
|
||||
|
||||
class HashAlgorithm(IntEnum):
|
||||
SHAKE128_32 = 0
|
||||
CRC32C = 1
|
||||
|
||||
|
||||
class WoundKind(IntEnum):
|
||||
FILE = 0
|
||||
SYMLINK = 1
|
||||
DIR = 2
|
||||
CLOSED_FILE = 3
|
||||
|
||||
|
||||
class SyncHeaderType(IntEnum):
|
||||
RSYNC = 0
|
||||
BSDIFF = 1
|
||||
|
||||
|
||||
class SyncOpType(IntEnum):
|
||||
BLOCK_RANGE = 0
|
||||
DATA = 1
|
||||
HEY_YOU_DID_IT = 2049
|
||||
|
||||
|
||||
class OverlayOpType(IntEnum):
|
||||
SKIP = 0
|
||||
FRESH = 1
|
||||
HEY_YOU_DID_IT = 2040
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionSettings(ProtoMessage):
|
||||
algorithm: CompressionAlgorithm | int | None = None
|
||||
quality: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("algorithm", "enum", CompressionAlgorithm),
|
||||
2: ("quality", "int32", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchHeader(ProtoMessage):
|
||||
compression: CompressionSettings | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("compression", "message", CompressionSettings),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncHeader(ProtoMessage):
|
||||
type: SyncHeaderType | int | None = None
|
||||
file_index: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("type", "enum", SyncHeaderType),
|
||||
16: ("file_index", "int64", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BsdiffHeader(ProtoMessage):
|
||||
target_index: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("target_index", "int64", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncOp(ProtoMessage):
|
||||
type: SyncOpType | int | None = None
|
||||
file_index: int | None = None
|
||||
block_index: int | None = None
|
||||
block_span: int | None = None
|
||||
data: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("type", "enum", SyncOpType),
|
||||
2: ("file_index", "int64", None),
|
||||
3: ("block_index", "int64", None),
|
||||
4: ("block_span", "int64", None),
|
||||
5: ("data", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignatureHeader(ProtoMessage):
|
||||
compression: CompressionSettings | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("compression", "message", CompressionSettings),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockHash(ProtoMessage):
|
||||
weak_hash: int | None = None
|
||||
strong_hash: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("weak_hash", "uint32", None),
|
||||
2: ("strong_hash", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManifestHeader(ProtoMessage):
|
||||
compression: CompressionSettings | None = None
|
||||
algorithm: HashAlgorithm | int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("compression", "message", CompressionSettings),
|
||||
2: ("algorithm", "enum", HashAlgorithm),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManifestBlockHash(ProtoMessage):
|
||||
hash: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("hash", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WoundsHeader(ProtoMessage):
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wound(ProtoMessage):
|
||||
index: int | None = None
|
||||
start: int | None = None
|
||||
end: int | None = None
|
||||
kind: WoundKind | int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("index", "int64", None),
|
||||
2: ("start", "int64", None),
|
||||
3: ("end", "int64", None),
|
||||
4: ("kind", "enum", WoundKind),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Control(ProtoMessage):
|
||||
add: bytes | None = None
|
||||
copy: bytes | None = None
|
||||
seek: int | None = None
|
||||
eof: bool | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("add", "bytes", None),
|
||||
2: ("copy", "bytes", None),
|
||||
3: ("seek", "int64", None),
|
||||
4: ("eof", "bool", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OverlayHeader(ProtoMessage):
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OverlayOp(ProtoMessage):
|
||||
type: OverlayOpType | int | None = None
|
||||
len: int | None = None
|
||||
data: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("type", "enum", OverlayOpType),
|
||||
2: ("len", "int64", None),
|
||||
3: ("data", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sample(ProtoMessage):
|
||||
data: bytes | None = None
|
||||
number: int | None = None
|
||||
eof: bool | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("data", "bytes", None),
|
||||
2: ("number", "int64", None),
|
||||
3: ("eof", "bool", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcDir(ProtoMessage):
|
||||
path: str | None = None
|
||||
mode: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("path", "string", None),
|
||||
2: ("mode", "uint32", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcFile(ProtoMessage):
|
||||
path: str | None = None
|
||||
mode: int | None = None
|
||||
size: int | None = None
|
||||
offset: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("path", "string", None),
|
||||
2: ("mode", "uint32", None),
|
||||
3: ("size", "int64", None),
|
||||
4: ("offset", "int64", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcSymlink(ProtoMessage):
|
||||
path: str | None = None
|
||||
mode: int | None = None
|
||||
dest: str | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("path", "string", None),
|
||||
2: ("mode", "uint32", None),
|
||||
3: ("dest", "string", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcContainer(ProtoMessage):
|
||||
files: list[TlcFile] = field(default_factory=list)
|
||||
dirs: list[TlcDir] = field(default_factory=list)
|
||||
symlinks: list[TlcSymlink] = field(default_factory=list)
|
||||
size: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("files", "repeated_message", TlcFile),
|
||||
2: ("dirs", "repeated_message", TlcDir),
|
||||
3: ("symlinks", "repeated_message", TlcSymlink),
|
||||
16: ("size", "int64", None),
|
||||
}
|
||||
Reference in New Issue
Block a user