This commit is contained in:
senstella
2026-01-17 14:13:43 +09:00
commit ca85a52839
12 changed files with 1692 additions and 0 deletions

97
pwr/__init__.py Normal file
View File

@@ -0,0 +1,97 @@
from .apply import (
FilePool,
PatchApplyError,
apply_bsdiff_controls,
apply_patch,
apply_patch_to_folders,
apply_rsync_ops,
)
from .compression import CompressionError
from .formats import FilePatch, ManifestReader, PatchReader, SignatureReader, WoundsReader
from .proto import (
BlockHash,
BsdiffHeader,
CompressionAlgorithm,
CompressionSettings,
Control,
HashAlgorithm,
ManifestBlockHash,
ManifestHeader,
OverlayHeader,
OverlayOp,
OverlayOpType,
PatchHeader,
Sample,
SignatureHeader,
SyncHeader,
SyncHeaderType,
SyncOp,
SyncOpType,
TlcContainer,
TlcDir,
TlcFile,
TlcSymlink,
Wound,
WoundKind,
WoundsHeader,
)
from .wire import (
BLOCK_SIZE,
MANIFEST_MAGIC,
PATCH_MAGIC,
SIGNATURE_MAGIC,
WOUNDS_MAGIC,
ZIP_INDEX_MAGIC,
WireError,
WireReader,
WireWriter,
)
# Explicit public API of the pwr package, mirroring the re-export groups
# above: apply helpers, compression, format readers, proto messages, and
# wire-level constants/primitives.
__all__ = [
    "FilePool",
    "PatchApplyError",
    "apply_bsdiff_controls",
    "apply_patch",
    "apply_patch_to_folders",
    "apply_rsync_ops",
    "CompressionError",
    "FilePatch",
    "ManifestReader",
    "PatchReader",
    "SignatureReader",
    "WoundsReader",
    "BlockHash",
    "BsdiffHeader",
    "CompressionAlgorithm",
    "CompressionSettings",
    "Control",
    "HashAlgorithm",
    "ManifestBlockHash",
    "ManifestHeader",
    "OverlayHeader",
    "OverlayOp",
    "OverlayOpType",
    "PatchHeader",
    "Sample",
    "SignatureHeader",
    "SyncHeader",
    "SyncHeaderType",
    "SyncOp",
    "SyncOpType",
    "TlcContainer",
    "TlcDir",
    "TlcFile",
    "TlcSymlink",
    "Wound",
    "WoundKind",
    "WoundsHeader",
    "BLOCK_SIZE",
    "MANIFEST_MAGIC",
    "PATCH_MAGIC",
    "SIGNATURE_MAGIC",
    "WOUNDS_MAGIC",
    "ZIP_INDEX_MAGIC",
    "WireError",
    "WireReader",
    "WireWriter",
]

198
pwr/apply.py Normal file
View File

@@ -0,0 +1,198 @@
from __future__ import annotations
import os
from typing import BinaryIO
from .formats import FilePatch, PatchReader
from .proto import Control, SyncOp, SyncOpType
from .proto import TlcContainer, TlcDir, TlcFile, TlcSymlink
from .wire import BLOCK_SIZE
class PatchApplyError(Exception):
    """Raised when a patch cannot be applied to the target/output files."""
    pass


# Permission bits (including setuid/setgid/sticky) preserved when applying
# directory modes from a container.
_MODE_MASK = 0o7777
class FilePool:
    """A lazily-opened pool of read-only files addressed by integer index.

    Handles are opened on first use and cached until close() is called.
    """

    def __init__(self, paths: list[str]):
        self._paths = list(paths)
        self._handles: dict[int, BinaryIO] = {}

    def _validate(self, index: int) -> None:
        # Shared bounds check for open() and size().
        if not 0 <= index < len(self._paths):
            raise PatchApplyError(f"file index out of range: {index}")

    def open(self, index: int) -> BinaryIO:
        """Return a cached binary read handle for the file at *index*."""
        self._validate(index)
        if index not in self._handles:
            self._handles[index] = open(self._paths[index], "rb")
        return self._handles[index]

    def size(self, index: int) -> int:
        """Return the current on-disk size, in bytes, of the file at *index*."""
        self._validate(index)
        return os.path.getsize(self._paths[index])

    def close(self) -> None:
        """Close every handle opened so far and forget them."""
        while self._handles:
            _, handle = self._handles.popitem()
            handle.close()
def _copy_range(dst: BinaryIO, src: BinaryIO, length: int, buffer_size: int = 32 * 1024) -> None:
remaining = length
while remaining > 0:
chunk = src.read(min(buffer_size, remaining))
if not chunk:
raise PatchApplyError("unexpected EOF while copying block range")
dst.write(chunk)
remaining -= len(chunk)
def apply_rsync_ops(ops: list[SyncOp], target_pool: FilePool, output: BinaryIO) -> None:
    """Replay rsync-style ops into *output*.

    DATA ops carry literal bytes; BLOCK_RANGE ops copy whole-block spans out
    of a file in *target_pool*. Raises PatchApplyError on malformed or
    unsupported ops.
    """
    for op in ops:
        if op.type == SyncOpType.DATA:
            output.write(op.data or b"")
        elif op.type == SyncOpType.BLOCK_RANGE:
            if op.file_index is None or op.block_index is None or op.block_span is None:
                raise PatchApplyError("missing fields in block range op")
            if op.block_span <= 0:
                raise PatchApplyError("invalid block span in block range op")
            total_size = target_pool.size(op.file_index)
            last_index = op.block_index + op.block_span - 1
            # Every block is BLOCK_SIZE long except possibly the file's last.
            tail = BLOCK_SIZE
            if (last_index + 1) * BLOCK_SIZE > total_size:
                tail = total_size % BLOCK_SIZE
            length = (op.block_span - 1) * BLOCK_SIZE + tail
            handle = target_pool.open(op.file_index)
            handle.seek(op.block_index * BLOCK_SIZE)
            _copy_range(output, handle, length)
        else:
            raise PatchApplyError(f"unsupported sync op type: {op.type}")
def apply_bsdiff_controls(controls: list[Control], old: BinaryIO, output: BinaryIO) -> None:
    """Apply bsdiff-style controls against *old*, writing the result to *output*.

    Each control contributes an "add" segment (byte-wise sum, mod 256, of old
    data and the diff bytes), then a literal "copy" segment, then advances the
    old-file cursor by "seek". A control with eof set terminates the stream.
    """
    cursor = 0
    for control in controls:
        if control.eof:
            break
        diff = control.add or b""
        extra = control.copy or b""
        if diff:
            old.seek(cursor)
            base = old.read(len(diff))
            if len(base) != len(diff):
                raise PatchApplyError("unexpected EOF while reading bsdiff add data")
            output.write(bytes((x + y) & 0xFF for x, y in zip(base, diff)))
            cursor += len(diff)
        if extra:
            output.write(extra)
        cursor += control.seek or 0
def _ensure_parent(path: str) -> None:
parent = os.path.dirname(path)
if parent:
os.makedirs(parent, exist_ok=True)
def apply_patch(patch_reader: PatchReader, target_paths: list[str], output_paths: list[str]) -> None:
    """Materialize every file entry from *patch_reader*.

    *target_paths* back a FilePool that serves as the "old" data source;
    each entry's file_index selects its destination in *output_paths*.
    Raises PatchApplyError for malformed or out-of-range entries.
    """
    pool = FilePool(target_paths)
    try:
        for entry in patch_reader.iter_file_entries():
            index = entry.sync_header.file_index
            if index is None:
                raise PatchApplyError("missing file_index in sync header")
            if not 0 <= index < len(output_paths):
                raise PatchApplyError(f"output index out of range: {index}")
            destination = output_paths[index]
            _ensure_parent(destination)
            with open(destination, "wb") as out:
                if entry.is_rsync():
                    if entry.sync_ops is None:
                        raise PatchApplyError("missing rsync ops")
                    apply_rsync_ops(entry.sync_ops, pool, out)
                elif entry.is_bsdiff():
                    if entry.bsdiff_header is None or entry.bsdiff_controls is None:
                        raise PatchApplyError("missing bsdiff data")
                    if entry.bsdiff_header.target_index is None:
                        raise PatchApplyError("missing target_index in bsdiff header")
                    old = pool.open(entry.bsdiff_header.target_index)
                    apply_bsdiff_controls(entry.bsdiff_controls, old, out)
                else:
                    raise PatchApplyError("unknown file patch type")
    finally:
        pool.close()
def _container_file_paths(container: TlcContainer, root: str) -> list[str]:
paths = []
for f in container.files:
if f.path is None:
raise PatchApplyError("container file missing path")
paths.append(os.path.join(root, f.path))
return paths
def _ensure_dirs(container: TlcContainer, root: str) -> None:
for d in container.dirs:
if d.path is None:
continue
path = os.path.join(root, d.path)
os.makedirs(path, exist_ok=True)
if d.mode is not None:
mode = int(d.mode) & _MODE_MASK
try:
os.chmod(path, mode)
except (PermissionError, OSError):
pass
def _ensure_symlinks(container: TlcContainer, root: str) -> None:
for s in container.symlinks:
if s.path is None or s.dest is None:
continue
path = os.path.join(root, s.path)
if os.path.lexists(path):
continue
parent = os.path.dirname(path)
if parent:
os.makedirs(parent, exist_ok=True)
try:
os.symlink(s.dest, path)
except (AttributeError, OSError):
continue
def apply_patch_to_folders(patch_path: str, target_root: str, output_root: str) -> None:
    """Apply the patch file at *patch_path* against two folder trees.

    Old data is read from *target_root* and new files are written under
    *output_root*. Directories and symlinks from the source container are
    created first so file writes land in existing folders.
    """
    decoders = (TlcContainer.from_bytes, TlcContainer.from_bytes)
    with PatchReader.open(patch_path, container_decoder=decoders) as reader:
        target = reader.target_container
        source = reader.source_container
        if target is None or source is None:
            raise PatchApplyError("missing containers in patch")
        _ensure_dirs(source, output_root)
        _ensure_symlinks(source, output_root)
        apply_patch(
            reader,
            _container_file_paths(target, target_root),
            _container_file_paths(source, output_root),
        )

158
pwr/compression.py Normal file
View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import io
import gzip
from typing import BinaryIO
from .proto import CompressionAlgorithm, CompressionSettings
class CompressionError(Exception):
    """Raised when a codec is unknown, its module is unavailable, or it fails."""
    pass
def _normalize_algorithm(algorithm: CompressionAlgorithm | int | None) -> CompressionAlgorithm:
    """Coerce *algorithm* (enum member, raw int, or None) to CompressionAlgorithm.

    None means "no compression". Raises CompressionError for unknown values.
    """
    if algorithm is None:
        return CompressionAlgorithm.NONE
    if not isinstance(algorithm, CompressionAlgorithm):
        try:
            algorithm = CompressionAlgorithm(int(algorithm))
        except ValueError as exc:
            raise CompressionError(f"unknown compression algorithm: {algorithm}") from exc
    return algorithm
def _brotli_module():
    """Import and return a brotli implementation (brotli, else brotlicffi).

    Raises CompressionError when neither package is installed.
    """
    try:
        import brotli  # type: ignore
    except ImportError:
        try:
            import brotlicffi as brotli  # type: ignore
        except ImportError as exc:
            raise CompressionError("brotli module not available") from exc
    return brotli
class _BrotliReader(io.RawIOBase):
    """Read-only raw stream that brotli-decompresses another binary stream.

    Intended to be wrapped in io.BufferedReader (see open_decompressed_reader)
    so small reads are efficient.
    """

    def __init__(self, raw: BinaryIO):
        self._raw = raw
        brotli = _brotli_module()
        self._decompressor = brotli.Decompressor()
        # Decompressed bytes produced but not yet handed to the caller.
        self._buffer = bytearray()
        # Set once the underlying compressed stream is exhausted.
        self._eof = False

    def readable(self) -> bool:
        return True

    def _finish(self) -> bytes:
        # Drain the decompressor at EOF. The brotli and brotlicffi packages
        # expose different method names for this, so probe both.
        if hasattr(self._decompressor, "finish"):
            return self._decompressor.finish()
        if hasattr(self._decompressor, "flush"):
            return self._decompressor.flush()
        return b""

    def read(self, size: int = -1) -> bytes:
        """Return up to *size* decompressed bytes (all remaining if size < 0)."""
        if size == 0:
            return b""
        # Pull compressed chunks until enough output is buffered or EOF.
        while not self._eof and (size < 0 or len(self._buffer) < size):
            chunk = self._raw.read(8192)
            if not chunk:
                self._buffer.extend(self._finish())
                self._eof = True
                break
            self._buffer.extend(self._decompressor.process(chunk))
        if size < 0:
            data = bytes(self._buffer)
            self._buffer.clear()
            return data
        data = bytes(self._buffer[:size])
        del self._buffer[:size]
        return data

    def readinto(self, b) -> int:
        # RawIOBase protocol: fill *b* and report how many bytes were written.
        data = self.read(len(b))
        n = len(data)
        b[:n] = data
        return n

    def close(self) -> None:
        # Closing this reader also closes the wrapped stream.
        if self.closed:
            return
        try:
            if hasattr(self._raw, "close"):
                self._raw.close()
        finally:
            super().close()
class _BrotliWriter(io.RawIOBase):
    """Write-only raw stream that brotli-compresses into another stream."""

    def __init__(self, raw: BinaryIO, quality: int | None):
        self._raw = raw
        brotli = _brotli_module()
        kwargs = {}
        if quality is not None:
            # Only forward quality when set so the library default applies.
            kwargs["quality"] = int(quality)
        self._compressor = brotli.Compressor(**kwargs)

    def writable(self) -> bool:
        return True

    def write(self, b: bytes) -> int:
        """Compress *b*, forward any produced output, and report *b* consumed."""
        out = self._compressor.process(b)
        self._raw.write(out)
        return len(b)

    def close(self) -> None:
        # Flush the compressor's tail bytes, then close the wrapped stream.
        if self.closed:
            return
        try:
            tail = self._compressor.finish()
            if tail:
                self._raw.write(tail)
            if hasattr(self._raw, "close"):
                self._raw.close()
        finally:
            super().close()
def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO:
    """Wrap *stream* in a reader that decompresses per *compression*.

    NONE returns the stream unchanged. Raises CompressionError when the
    algorithm is unknown or its codec module is unavailable.
    """
    algorithm = _normalize_algorithm(compression.algorithm if compression else None)
    if algorithm == CompressionAlgorithm.ZSTD:
        try:
            import zstandard as zstd  # type: ignore
        except ImportError as exc:
            raise CompressionError("zstandard module not available") from exc
        return zstd.ZstdDecompressor().stream_reader(stream)
    if algorithm == CompressionAlgorithm.BROTLI:
        return io.BufferedReader(_BrotliReader(stream))
    if algorithm == CompressionAlgorithm.GZIP:
        return gzip.GzipFile(fileobj=stream, mode="rb")
    if algorithm == CompressionAlgorithm.NONE:
        return stream
    raise CompressionError(f"unsupported compression algorithm: {algorithm}")
def open_compressed_writer(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO:
    """Wrap *stream* in a writer that compresses per *compression*.

    NONE returns the stream unchanged; a missing quality falls back to each
    codec's default level (gzip 9, zstd 3). Raises CompressionError for
    unknown algorithms or missing codec modules.
    """
    algorithm = _normalize_algorithm(compression.algorithm if compression else None)
    quality = compression.quality if compression else None
    if algorithm == CompressionAlgorithm.NONE:
        return stream
    if algorithm == CompressionAlgorithm.BROTLI:
        return io.BufferedWriter(_BrotliWriter(stream, quality))
    if algorithm == CompressionAlgorithm.GZIP:
        level = 9 if quality is None else int(quality)
        return gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level)
    if algorithm == CompressionAlgorithm.ZSTD:
        try:
            import zstandard as zstd  # type: ignore
        except ImportError as exc:
            raise CompressionError("zstandard module not available") from exc
        return zstd.ZstdCompressor(level=3 if quality is None else int(quality)).stream_writer(stream)
    raise CompressionError(f"unsupported compression algorithm: {algorithm}")

286
pwr/formats.py Normal file
View File

@@ -0,0 +1,286 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Iterator
from .compression import open_decompressed_reader
from .proto import (
BlockHash,
BsdiffHeader,
Control,
ManifestBlockHash,
ManifestHeader,
PatchHeader,
SignatureHeader,
SyncHeader,
SyncHeaderType,
SyncOp,
SyncOpType,
Wound,
WoundsHeader,
)
from .wire import (
MANIFEST_MAGIC,
PATCH_MAGIC,
SIGNATURE_MAGIC,
WOUNDS_MAGIC,
WireError,
WireReader,
)
ContainerDecoder = Callable[[bytes], Any]
def _split_decoder(decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None):
if decoder is None:
return None, None
if isinstance(decoder, tuple):
if len(decoder) != 2:
raise ValueError("container decoder tuple must have two elements")
return decoder
return decoder, decoder
@dataclass
class FilePatch:
    """One decoded file entry from a patch.

    Alongside the mandatory sync_header, either sync_ops (rsync encoding) or
    bsdiff_header + bsdiff_controls (bsdiff encoding) is populated.
    """

    sync_header: SyncHeader
    sync_ops: list[SyncOp] | None = None
    bsdiff_header: BsdiffHeader | None = None
    bsdiff_controls: list[Control] | None = None

    def is_rsync(self) -> bool:
        """True when this entry carries rsync-style sync ops."""
        return self.sync_ops is not None

    def is_bsdiff(self) -> bool:
        """True when this entry carries a bsdiff header."""
        return self.bsdiff_header is not None
class PatchReader:
    """Streaming reader for PWR patch files.

    Construction consumes the uncompressed prelude (magic + PatchHeader),
    switches to a view decompressed per the header's compression settings,
    and reads the target and source container blobs. File entries are then
    consumed lazily via iter_file_entries().
    """

    def __init__(self, stream, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None):
        self._stream = stream
        # Wire view over the raw stream; only used for the uncompressed prelude.
        self._raw_wire = WireReader(stream)
        # Wire view over the decompressed payload; created in _init_stream.
        self._wire: WireReader | None = None
        self.header: PatchHeader | None = None
        self.target_container_raw: bytes | None = None
        self.source_container_raw: bytes | None = None
        self.target_container: Any | None = None
        self.source_container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(cls, path: str, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None):
        """Open the file at *path* and return a PatchReader over it."""
        return cls(open(path, "rb"), container_decoder=container_decoder)

    def _init_stream(self, container_decoder):
        # The prelude is never compressed: magic, then PatchHeader.
        self._raw_wire.expect_magic(PATCH_MAGIC)
        self.header = self._raw_wire.read_message(PatchHeader)
        # Everything after the header honours the header's compression settings.
        decompressed = open_decompressed_reader(self._stream, self.header.compression)
        self._wire = WireReader(decompressed)
        target_decoder, source_decoder = _split_decoder(container_decoder)
        self.target_container_raw, self.target_container = self._read_container(target_decoder)
        self.source_container_raw, self.source_container = self._read_container(source_decoder)

    def _read_container(self, decoder: ContainerDecoder | None):
        # Read one container blob; decode it only when a decoder was supplied.
        assert self._wire is not None
        raw = self._wire.read_message_bytes()
        parsed = decoder(raw) if decoder else None
        return raw, parsed

    def iter_file_entries(self) -> Iterator[FilePatch]:
        """Yield one FilePatch per file entry until the stream ends.

        Entries without an explicit file_index default to one past the
        previous entry's index, starting at 0.
        """
        assert self._wire is not None
        file_index = 0
        while True:
            try:
                sync_header = self._wire.read_message(SyncHeader)
            except EOFError:
                # Clean end of stream between entries.
                return
            if sync_header.file_index is None:
                sync_header.file_index = file_index
            else:
                file_index = int(sync_header.file_index)
            header_type = sync_header.type
            if header_type is None:
                # An absent type means the default RSYNC encoding.
                header_type = SyncHeaderType.RSYNC
            else:
                try:
                    header_type = SyncHeaderType(int(header_type))
                except ValueError:
                    raise WireError(f"unknown sync header type: {header_type}")
            if header_type == SyncHeaderType.BSDIFF:
                bsdiff_header = self._wire.read_message(BsdiffHeader)
                controls: list[Control] = []
                # Controls end with an explicit eof control (kept in the list)...
                while True:
                    ctrl = self._wire.read_message(Control)
                    controls.append(ctrl)
                    if ctrl.eof:
                        break
                # ...followed by a terminating SyncOp sentinel.
                sentinel = self._wire.read_message(SyncOp)
                if sentinel.type != SyncOpType.HEY_YOU_DID_IT:
                    raise WireError("expected BSDIFF sentinel SyncOp")
                yield FilePatch(
                    sync_header=sync_header,
                    bsdiff_header=bsdiff_header,
                    bsdiff_controls=controls,
                )
                file_index += 1
                continue
            if header_type != SyncHeaderType.RSYNC:
                raise WireError(f"unsupported sync header type: {header_type}")
            ops: list[SyncOp] = []
            # Rsync ops end with the HEY_YOU_DID_IT sentinel (not yielded).
            while True:
                op = self._wire.read_message(SyncOp)
                if op.type == SyncOpType.HEY_YOU_DID_IT:
                    break
                ops.append(op)
            yield FilePatch(sync_header=sync_header, sync_ops=ops)
            file_index += 1

    def close(self) -> None:
        """Close the wire view over the (possibly decompressed) payload."""
        if self._wire:
            self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
class SignatureReader:
    """Streaming reader for PWR signature files (container + block hashes)."""

    def __init__(self, stream, *, container_decoder: ContainerDecoder | None = None):
        self._stream = stream
        self._raw_wire = WireReader(stream)
        self._wire: WireReader | None = None
        self.header: SignatureHeader | None = None
        self.container_raw: bytes | None = None
        self.container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(cls, path: str, *, container_decoder: ContainerDecoder | None = None):
        """Open *path* for reading and wrap it in a SignatureReader."""
        return cls(open(path, "rb"), container_decoder=container_decoder)

    def _init_stream(self, container_decoder):
        # Magic and header are stored uncompressed; everything after the
        # header honours the header's compression settings.
        self._raw_wire.expect_magic(SIGNATURE_MAGIC)
        self.header = self._raw_wire.read_message(SignatureHeader)
        self._wire = WireReader(open_decompressed_reader(self._stream, self.header.compression))
        self.container_raw = self._wire.read_message_bytes()
        self.container = container_decoder(self.container_raw) if container_decoder else None

    def iter_block_hashes(self) -> Iterator[BlockHash]:
        """Yield BlockHash messages until the stream is exhausted."""
        assert self._wire is not None
        while True:
            try:
                message = self._wire.read_message(BlockHash)
            except EOFError:
                return
            yield message

    def close(self) -> None:
        if self._wire:
            self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
class ManifestReader:
    """Streaming reader for PWR manifest files (container + manifest hashes)."""

    def __init__(self, stream, *, container_decoder: ContainerDecoder | None = None):
        self._stream = stream
        self._raw_wire = WireReader(stream)
        self._wire: WireReader | None = None
        self.header: ManifestHeader | None = None
        self.container_raw: bytes | None = None
        self.container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(cls, path: str, *, container_decoder: ContainerDecoder | None = None):
        """Open *path* for reading and wrap it in a ManifestReader."""
        return cls(open(path, "rb"), container_decoder=container_decoder)

    def _init_stream(self, container_decoder):
        # Magic and header are stored uncompressed; everything after the
        # header honours the header's compression settings.
        self._raw_wire.expect_magic(MANIFEST_MAGIC)
        self.header = self._raw_wire.read_message(ManifestHeader)
        self._wire = WireReader(open_decompressed_reader(self._stream, self.header.compression))
        self.container_raw = self._wire.read_message_bytes()
        self.container = container_decoder(self.container_raw) if container_decoder else None

    def iter_block_hashes(self) -> Iterator[ManifestBlockHash]:
        """Yield ManifestBlockHash messages until the stream is exhausted."""
        assert self._wire is not None
        while True:
            try:
                message = self._wire.read_message(ManifestBlockHash)
            except EOFError:
                return
            yield message

    def close(self) -> None:
        if self._wire:
            self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
class WoundsReader:
    """Streaming reader for PWR wounds files (stored without compression)."""

    def __init__(self, stream, *, container_decoder: ContainerDecoder | None = None):
        self._stream = stream
        self._wire = WireReader(stream)
        self.header: WoundsHeader | None = None
        self.container_raw: bytes | None = None
        self.container: Any | None = None
        self._init_stream(container_decoder)

    @classmethod
    def open(cls, path: str, *, container_decoder: ContainerDecoder | None = None):
        """Open *path* for reading and wrap it in a WoundsReader."""
        return cls(open(path, "rb"), container_decoder=container_decoder)

    def _init_stream(self, container_decoder):
        self._wire.expect_magic(WOUNDS_MAGIC)
        self.header = self._wire.read_message(WoundsHeader)
        self.container_raw = self._wire.read_message_bytes()
        self.container = container_decoder(self.container_raw) if container_decoder else None

    def iter_wounds(self) -> Iterator[Wound]:
        """Yield Wound messages until the stream is exhausted."""
        while True:
            try:
                wound = self._wire.read_message(Wound)
            except EOFError:
                return
            yield wound

    def close(self) -> None:
        self._wire.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

477
pwr/proto.py Normal file
View File

@@ -0,0 +1,477 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, ClassVar, Dict, Tuple
class ProtoError(Exception):
    """Raised when protobuf-style encoding or decoding fails."""
    pass
def _unsigned(value: int, bits: int) -> int:
if value < 0:
value = (1 << bits) + value
return value
def _signed(value: int, bits: int) -> int:
sign_bit = 1 << (bits - 1)
if value & sign_bit:
return value - (1 << bits)
return value
def _encode_varint(value: int) -> bytes:
if value < 0:
raise ProtoError("varint cannot encode negative values")
out = bytearray()
while True:
to_write = value & 0x7F
value >>= 7
if value:
out.append(to_write | 0x80)
else:
out.append(to_write)
break
return bytes(out)
def _encode_key(field_number: int, wire_type: int) -> bytes:
return _encode_varint((field_number << 3) | wire_type)
class ProtoReader:
    """Cursor over a protobuf-encoded byte string with primitive readers."""

    def __init__(self, data: bytes):
        self._data = bytes(data)
        self._pos = 0

    def eof(self) -> bool:
        """True once every byte has been consumed."""
        return self._pos >= len(self._data)

    def read_varint(self) -> int:
        """Decode one base-128 varint; raises ProtoError on overrun or overflow."""
        result = 0
        shift = 0
        while True:
            if self.eof():
                raise ProtoError("unexpected EOF while reading varint")
            byte = self._data[self._pos]
            self._pos += 1
            result |= (byte & 0x7F) << shift
            if byte < 0x80:
                return result
            shift += 7
            if shift > 64:
                raise ProtoError("varint too long")

    def read_key(self) -> Tuple[int, int]:
        """Decode one field key into (field_number, wire_type)."""
        key = self.read_varint()
        return key >> 3, key & 0x07

    def read_length_delimited(self) -> bytes:
        """Read a varint length followed by that many raw bytes."""
        size = self.read_varint()
        stop = self._pos + size
        if stop > len(self._data):
            raise ProtoError("unexpected EOF while reading length-delimited field")
        chunk = self._data[self._pos:stop]
        self._pos = stop
        return chunk

    def skip(self, wire_type: int) -> None:
        """Advance past one field of *wire_type* without decoding its value."""
        if wire_type == 0:
            self.read_varint()
        elif wire_type == 1:
            self._pos += 8  # fixed64
        elif wire_type == 2:
            length = self.read_varint()
            self._pos += length
        elif wire_type == 5:
            self._pos += 4  # fixed32
        else:
            raise ProtoError(f"unsupported wire type: {wire_type}")
class ProtoMessage:
    """Base class for hand-rolled protobuf-compatible messages.

    Subclasses declare FIELDS mapping field number -> (attribute name, field
    type string, extra), where extra is the nested-message class for
    "message" fields, the enum class for "enum" fields, and None otherwise.
    A "repeated_" prefix on the type string marks repeated fields.
    """

    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}

    def to_bytes(self) -> bytes:
        """Serialize to protobuf wire format; fields whose value is None are omitted."""
        out = bytearray()
        # Deterministic output: emit fields in ascending field-number order.
        for field_number in sorted(self.FIELDS):
            name, field_type, extra = self.FIELDS[field_number]
            value = getattr(self, name)
            if value is None:
                continue
            repeated = False
            if field_type.startswith("repeated_"):
                repeated = True
                field_type = field_type[len("repeated_") :]
            # Treat scalars as one-element lists so both cases share one loop.
            values = value if repeated else [value]
            for item in values:
                if field_type == "message":
                    # Nested messages may be pre-encoded bytes or live objects.
                    if isinstance(item, (bytes, bytearray, memoryview)):
                        raw = bytes(item)
                    else:
                        raw = item.to_bytes()
                    out.extend(_encode_key(field_number, 2))
                    out.extend(_encode_varint(len(raw)))
                    out.extend(raw)
                    continue
                if field_type == "bytes":
                    raw = bytes(item)
                    out.extend(_encode_key(field_number, 2))
                    out.extend(_encode_varint(len(raw)))
                    out.extend(raw)
                    continue
                if field_type == "string":
                    raw = str(item).encode("utf-8")
                    out.extend(_encode_key(field_number, 2))
                    out.extend(_encode_varint(len(raw)))
                    out.extend(raw)
                    continue
                if field_type == "bool":
                    out.extend(_encode_key(field_number, 0))
                    out.extend(_encode_varint(1 if item else 0))
                    continue
                if field_type == "enum":
                    if isinstance(item, IntEnum):
                        item = int(item)
                    out.extend(_encode_key(field_number, 0))
                    # Enums travel as 64-bit two's-complement varints.
                    out.extend(_encode_varint(_unsigned(int(item), 64)))
                    continue
                if field_type in ("int32", "int64"):
                    bits = 32 if field_type == "int32" else 64
                    out.extend(_encode_key(field_number, 0))
                    # Negative ints are two's-complement encoded (not zigzag).
                    out.extend(_encode_varint(_unsigned(int(item), bits)))
                    continue
                if field_type in ("uint32", "uint64"):
                    out.extend(_encode_key(field_number, 0))
                    out.extend(_encode_varint(int(item)))
                    continue
                raise ProtoError(f"unsupported field type: {field_type}")
        return bytes(out)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse *data* into a new instance; unknown field numbers are skipped."""
        reader = ProtoReader(data)
        obj = cls()
        while not reader.eof():
            field_number, wire_type = reader.read_key()
            spec = cls.FIELDS.get(field_number)
            if spec is None:
                # Unknown field: skip its payload and continue parsing.
                reader.skip(wire_type)
                continue
            name, field_type, extra = spec
            repeated = False
            if field_type.startswith("repeated_"):
                repeated = True
                field_type = field_type[len("repeated_") :]
            if field_type == "message":
                raw = reader.read_length_delimited()
                if extra is None:
                    # No message class declared: surface the raw bytes.
                    value = raw
                else:
                    value = extra.from_bytes(raw)
            elif field_type == "bytes":
                value = reader.read_length_delimited()
            elif field_type == "string":
                value = reader.read_length_delimited().decode("utf-8")
            elif field_type == "bool":
                value = bool(reader.read_varint())
            elif field_type == "enum":
                raw_value = reader.read_varint()
                if extra is None:
                    value = raw_value
                else:
                    try:
                        value = extra(raw_value)
                    except ValueError:
                        # Unknown enum member: preserve the raw integer.
                        value = raw_value
            elif field_type == "int32":
                value = _signed(reader.read_varint(), 32)
            elif field_type == "int64":
                value = _signed(reader.read_varint(), 64)
            elif field_type == "uint32":
                value = reader.read_varint()
            elif field_type == "uint64":
                value = reader.read_varint()
            else:
                raise ProtoError(f"unsupported field type: {field_type}")
            if repeated:
                # Append to the existing list, creating it on first sight.
                current = getattr(obj, name, None)
                if current is None:
                    current = []
                    setattr(obj, name, current)
                current.append(value)
            else:
                setattr(obj, name, value)
        return obj
class CompressionAlgorithm(IntEnum):
    """Compression codecs supported for patch/signature/manifest payloads."""
    NONE = 0
    BROTLI = 1
    GZIP = 2
    ZSTD = 3


class HashAlgorithm(IntEnum):
    """Hash algorithms selectable in a manifest header."""
    SHAKE128_32 = 0
    CRC32C = 1


class WoundKind(IntEnum):
    """Which kind of container entry a Wound refers to."""
    FILE = 0
    SYMLINK = 1
    DIR = 2
    CLOSED_FILE = 3


class SyncHeaderType(IntEnum):
    """How a file entry in a patch is encoded."""
    RSYNC = 0
    BSDIFF = 1


class SyncOpType(IntEnum):
    """Rsync-style ops; HEY_YOU_DID_IT is the per-file terminator sentinel."""
    BLOCK_RANGE = 0
    DATA = 1
    HEY_YOU_DID_IT = 2049


class OverlayOpType(IntEnum):
    """Overlay ops; HEY_YOU_DID_IT is the terminator sentinel."""
    SKIP = 0
    FRESH = 1
    HEY_YOU_DID_IT = 2040
@dataclass
class CompressionSettings(ProtoMessage):
    """Codec and quality level used for a file section's compression."""
    algorithm: CompressionAlgorithm | int | None = None
    quality: int | None = None
    # Wire schema: field number -> (attribute, field type, enum/message class).
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("algorithm", "enum", CompressionAlgorithm),
        2: ("quality", "int32", None),
    }


@dataclass
class PatchHeader(ProtoMessage):
    """Uncompressed header at the start of a patch file."""
    compression: CompressionSettings | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("compression", "message", CompressionSettings),
    }


@dataclass
class SyncHeader(ProtoMessage):
    """Starts one per-file entry in a patch; type selects RSYNC or BSDIFF."""
    type: SyncHeaderType | int | None = None
    file_index: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("type", "enum", SyncHeaderType),
        16: ("file_index", "int64", None),
    }


@dataclass
class BsdiffHeader(ProtoMessage):
    """Precedes bsdiff controls; target_index picks the old file to diff against."""
    target_index: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("target_index", "int64", None),
    }


@dataclass
class SyncOp(ProtoMessage):
    """One rsync op: literal DATA, a BLOCK_RANGE copy, or the terminator sentinel."""
    type: SyncOpType | int | None = None
    file_index: int | None = None
    block_index: int | None = None
    block_span: int | None = None
    data: bytes | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("type", "enum", SyncOpType),
        2: ("file_index", "int64", None),
        3: ("block_index", "int64", None),
        4: ("block_span", "int64", None),
        5: ("data", "bytes", None),
    }
@dataclass
class SignatureHeader(ProtoMessage):
    """Uncompressed header at the start of a signature file."""
    compression: CompressionSettings | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("compression", "message", CompressionSettings),
    }


@dataclass
class BlockHash(ProtoMessage):
    """Weak (rolling) plus strong hash for one block in a signature."""
    weak_hash: int | None = None
    strong_hash: bytes | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("weak_hash", "uint32", None),
        2: ("strong_hash", "bytes", None),
    }


@dataclass
class ManifestHeader(ProtoMessage):
    """Manifest file header: compression settings plus the hash algorithm."""
    compression: CompressionSettings | None = None
    algorithm: HashAlgorithm | int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("compression", "message", CompressionSettings),
        2: ("algorithm", "enum", HashAlgorithm),
    }


@dataclass
class ManifestBlockHash(ProtoMessage):
    """A single block hash entry within a manifest."""
    hash: bytes | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("hash", "bytes", None),
    }


@dataclass
class WoundsHeader(ProtoMessage):
    """Header of a wounds file; carries no fields."""
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}


@dataclass
class Wound(ProtoMessage):
    """A byte range [start, end) flagged in the container entry at *index*;
    kind says which entry list (file/symlink/dir) the index refers to."""
    index: int | None = None
    start: int | None = None
    end: int | None = None
    kind: WoundKind | int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("index", "int64", None),
        2: ("start", "int64", None),
        3: ("end", "int64", None),
        4: ("kind", "enum", WoundKind),
    }


@dataclass
class Control(ProtoMessage):
    """One bsdiff control: add (diff bytes), copy (literal bytes), seek
    (old-file cursor delta); eof terminates the control stream."""
    add: bytes | None = None
    copy: bytes | None = None
    seek: int | None = None
    eof: bool | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("add", "bytes", None),
        2: ("copy", "bytes", None),
        3: ("seek", "int64", None),
        4: ("eof", "bool", None),
    }
@dataclass
class OverlayHeader(ProtoMessage):
    """Header of an overlay stream; carries no fields."""
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}


@dataclass
class OverlayOp(ProtoMessage):
    """One overlay op: SKIP *len* bytes or write FRESH *data* (see OverlayOpType)."""
    type: OverlayOpType | int | None = None
    len: int | None = None
    data: bytes | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("type", "enum", OverlayOpType),
        2: ("len", "int64", None),
        3: ("data", "bytes", None),
    }


@dataclass
class Sample(ProtoMessage):
    """A numbered chunk of data; eof marks the final sample of a sequence."""
    data: bytes | None = None
    number: int | None = None
    eof: bool | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("data", "bytes", None),
        2: ("number", "int64", None),
        3: ("eof", "bool", None),
    }


@dataclass
class TlcDir(ProtoMessage):
    """A directory entry in a container: relative path plus permission mode."""
    path: str | None = None
    mode: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("path", "string", None),
        2: ("mode", "uint32", None),
    }


@dataclass
class TlcFile(ProtoMessage):
    """A file entry in a container: relative path, mode, size, and offset."""
    path: str | None = None
    mode: int | None = None
    size: int | None = None
    offset: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("path", "string", None),
        2: ("mode", "uint32", None),
        3: ("size", "int64", None),
        4: ("offset", "int64", None),
    }


@dataclass
class TlcSymlink(ProtoMessage):
    """A symlink entry in a container: relative path, mode, and link target."""
    path: str | None = None
    mode: int | None = None
    dest: str | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("path", "string", None),
        2: ("mode", "uint32", None),
        3: ("dest", "string", None),
    }


@dataclass
class TlcContainer(ProtoMessage):
    """A container: the full listing of files, dirs, and symlinks, plus total size."""
    files: list[TlcFile] = field(default_factory=list)
    dirs: list[TlcDir] = field(default_factory=list)
    symlinks: list[TlcSymlink] = field(default_factory=list)
    size: int | None = None
    FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
        1: ("files", "repeated_message", TlcFile),
        2: ("dirs", "repeated_message", TlcDir),
        3: ("symlinks", "repeated_message", TlcSymlink),
        16: ("size", "int64", None),
    }

140
pwr/wire.py Normal file
View File

@@ -0,0 +1,140 @@
from __future__ import annotations
import io
import struct
from typing import BinaryIO, Type, TypeVar
from .proto import ProtoError, ProtoMessage
# File-format magic numbers; each sibling format is a fixed offset from the
# patch magic.
PATCH_MAGIC = 0xFEF5F00
SIGNATURE_MAGIC = PATCH_MAGIC + 1
MANIFEST_MAGIC = PATCH_MAGIC + 2
WOUNDS_MAGIC = PATCH_MAGIC + 3
ZIP_INDEX_MAGIC = PATCH_MAGIC + 4
# Block size, in bytes, shared by signature hashing and patch block ranges.
BLOCK_SIZE = 64 * 1024


class WireError(Exception):
    """Raised when a wire-level frame is malformed or carries a bad magic."""
    pass
def _read_exact(stream: BinaryIO, size: int) -> bytes:
if size <= 0:
return b""
try:
data = stream.read(size)
if data is not None:
return data
except NotImplementedError:
pass
readinto = getattr(stream, "readinto", None)
if readinto is None:
raise WireError("stream does not support read or readinto")
buf = bytearray(size)
view = memoryview(buf)
total = 0
while total < size:
n = readinto(view[total:])
if n is None:
n = 0
if n == 0:
break
total += n
return bytes(view[:total])
def read_varint(stream: BinaryIO) -> int:
    """Decode one base-128 varint from *stream*.

    Raises EOFError when the stream ends mid-varint and WireError when the
    encoding exceeds the tolerated 64-bit-range length.
    """
    value = 0
    shift = 0
    while True:
        raw = _read_exact(stream, 1)
        if not raw:
            raise EOFError("unexpected EOF while reading varint")
        group = raw[0]
        value |= (group & 0x7F) << shift
        if group < 0x80:
            return value
        shift += 7
        if shift > 64:
            raise WireError("varint too long")
def write_varint(stream: BinaryIO, value: int) -> None:
    """Encode *value* onto *stream* as a base-128 varint.

    Raises WireError for negative values, which varints cannot represent.
    """
    if value < 0:
        raise WireError("varint cannot encode negative values")
    encoded = bytearray()
    while value > 0x7F:
        encoded.append((value & 0x7F) | 0x80)
        value >>= 7
    encoded.append(value)
    stream.write(bytes(encoded))
def read_magic(stream: BinaryIO) -> int:
    """Read a 4-byte little-endian signed magic number from *stream*."""
    data = _read_exact(stream, 4)
    if len(data) != 4:
        raise EOFError("unexpected EOF while reading magic")
    return struct.unpack("<i", data)[0]


def write_magic(stream: BinaryIO, magic: int) -> None:
    """Write *magic* to *stream* as a 4-byte little-endian signed integer."""
    stream.write(struct.pack("<i", int(magic)))


# Type variable for WireReader.read_message: the message class passed in is
# also the return type.
T = TypeVar("T", bound=ProtoMessage)
class WireReader:
    """Reads magics and length-prefixed proto messages from a binary stream."""

    def __init__(self, stream: BinaryIO):
        self._stream = stream

    def read_magic(self) -> int:
        """Read and return the stream's 4-byte magic number."""
        return read_magic(self._stream)

    def expect_magic(self, magic: int) -> None:
        """Read the next magic and raise WireError unless it equals *magic*."""
        actual = self.read_magic()
        if actual != magic:
            raise WireError(f"wrong magic: expected {magic}, got {actual}")

    def read_message_bytes(self) -> bytes:
        """Read one varint-length-prefixed payload; EOFError if truncated."""
        length = read_varint(self._stream)
        payload = _read_exact(self._stream, length)
        if len(payload) != length:
            raise EOFError("unexpected EOF while reading message")
        return payload

    def read_message(self, msg_cls: Type[T]) -> T:
        """Read one message and decode it as *msg_cls*; ProtoError becomes WireError."""
        raw = self.read_message_bytes()
        try:
            return msg_cls.from_bytes(raw)
        except ProtoError as exc:
            raise WireError(str(exc)) from exc

    def close(self) -> None:
        """Close the underlying stream when it is a real io object."""
        if isinstance(self._stream, io.IOBase):
            self._stream.close()
class WireWriter:
    """Writes magics and length-prefixed proto messages to a binary stream."""

    def __init__(self, stream: BinaryIO):
        self._stream = stream

    def write_magic(self, magic: int) -> None:
        """Write a 4-byte little-endian magic number."""
        write_magic(self._stream, magic)

    def write_message(self, msg: ProtoMessage | bytes | bytearray | memoryview) -> None:
        """Serialize *msg* (raw bytes pass through) behind a varint length prefix."""
        if isinstance(msg, (bytes, bytearray, memoryview)):
            payload = bytes(msg)
        else:
            payload = msg.to_bytes()
        write_varint(self._stream, len(payload))
        self._stream.write(payload)

    def close(self) -> None:
        """Close the underlying stream when it is a real io object."""
        if isinstance(self._stream, io.IOBase):
            self._stream.close()