init
This commit is contained in:
97
pwr/__init__.py
Normal file
97
pwr/__init__.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from .apply import (
|
||||
FilePool,
|
||||
PatchApplyError,
|
||||
apply_bsdiff_controls,
|
||||
apply_patch,
|
||||
apply_patch_to_folders,
|
||||
apply_rsync_ops,
|
||||
)
|
||||
from .compression import CompressionError
|
||||
from .formats import FilePatch, ManifestReader, PatchReader, SignatureReader, WoundsReader
|
||||
from .proto import (
|
||||
BlockHash,
|
||||
BsdiffHeader,
|
||||
CompressionAlgorithm,
|
||||
CompressionSettings,
|
||||
Control,
|
||||
HashAlgorithm,
|
||||
ManifestBlockHash,
|
||||
ManifestHeader,
|
||||
OverlayHeader,
|
||||
OverlayOp,
|
||||
OverlayOpType,
|
||||
PatchHeader,
|
||||
Sample,
|
||||
SignatureHeader,
|
||||
SyncHeader,
|
||||
SyncHeaderType,
|
||||
SyncOp,
|
||||
SyncOpType,
|
||||
TlcContainer,
|
||||
TlcDir,
|
||||
TlcFile,
|
||||
TlcSymlink,
|
||||
Wound,
|
||||
WoundKind,
|
||||
WoundsHeader,
|
||||
)
|
||||
from .wire import (
|
||||
BLOCK_SIZE,
|
||||
MANIFEST_MAGIC,
|
||||
PATCH_MAGIC,
|
||||
SIGNATURE_MAGIC,
|
||||
WOUNDS_MAGIC,
|
||||
ZIP_INDEX_MAGIC,
|
||||
WireError,
|
||||
WireReader,
|
||||
WireWriter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FilePool",
|
||||
"PatchApplyError",
|
||||
"apply_bsdiff_controls",
|
||||
"apply_patch",
|
||||
"apply_patch_to_folders",
|
||||
"apply_rsync_ops",
|
||||
"CompressionError",
|
||||
"FilePatch",
|
||||
"ManifestReader",
|
||||
"PatchReader",
|
||||
"SignatureReader",
|
||||
"WoundsReader",
|
||||
"BlockHash",
|
||||
"BsdiffHeader",
|
||||
"CompressionAlgorithm",
|
||||
"CompressionSettings",
|
||||
"Control",
|
||||
"HashAlgorithm",
|
||||
"ManifestBlockHash",
|
||||
"ManifestHeader",
|
||||
"OverlayHeader",
|
||||
"OverlayOp",
|
||||
"OverlayOpType",
|
||||
"PatchHeader",
|
||||
"Sample",
|
||||
"SignatureHeader",
|
||||
"SyncHeader",
|
||||
"SyncHeaderType",
|
||||
"SyncOp",
|
||||
"SyncOpType",
|
||||
"TlcContainer",
|
||||
"TlcDir",
|
||||
"TlcFile",
|
||||
"TlcSymlink",
|
||||
"Wound",
|
||||
"WoundKind",
|
||||
"WoundsHeader",
|
||||
"BLOCK_SIZE",
|
||||
"MANIFEST_MAGIC",
|
||||
"PATCH_MAGIC",
|
||||
"SIGNATURE_MAGIC",
|
||||
"WOUNDS_MAGIC",
|
||||
"ZIP_INDEX_MAGIC",
|
||||
"WireError",
|
||||
"WireReader",
|
||||
"WireWriter",
|
||||
]
|
||||
198
pwr/apply.py
Normal file
198
pwr/apply.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import BinaryIO
|
||||
|
||||
from .formats import FilePatch, PatchReader
|
||||
from .proto import Control, SyncOp, SyncOpType
|
||||
from .proto import TlcContainer, TlcDir, TlcFile, TlcSymlink
|
||||
from .wire import BLOCK_SIZE
|
||||
|
||||
|
||||
class PatchApplyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_MODE_MASK = 0o7777
|
||||
|
||||
|
||||
class FilePool:
|
||||
def __init__(self, paths: list[str]):
|
||||
self._paths = list(paths)
|
||||
self._handles: dict[int, BinaryIO] = {}
|
||||
|
||||
def open(self, index: int) -> BinaryIO:
|
||||
if index < 0 or index >= len(self._paths):
|
||||
raise PatchApplyError(f"file index out of range: {index}")
|
||||
handle = self._handles.get(index)
|
||||
if handle is None:
|
||||
handle = open(self._paths[index], "rb")
|
||||
self._handles[index] = handle
|
||||
return handle
|
||||
|
||||
def size(self, index: int) -> int:
|
||||
if index < 0 or index >= len(self._paths):
|
||||
raise PatchApplyError(f"file index out of range: {index}")
|
||||
return os.path.getsize(self._paths[index])
|
||||
|
||||
def close(self) -> None:
|
||||
for handle in self._handles.values():
|
||||
handle.close()
|
||||
self._handles.clear()
|
||||
|
||||
|
||||
def _copy_range(dst: BinaryIO, src: BinaryIO, length: int, buffer_size: int = 32 * 1024) -> None:
|
||||
remaining = length
|
||||
while remaining > 0:
|
||||
chunk = src.read(min(buffer_size, remaining))
|
||||
if not chunk:
|
||||
raise PatchApplyError("unexpected EOF while copying block range")
|
||||
dst.write(chunk)
|
||||
remaining -= len(chunk)
|
||||
|
||||
|
||||
def apply_rsync_ops(ops: list[SyncOp], target_pool: FilePool, output: BinaryIO) -> None:
|
||||
for op in ops:
|
||||
if op.type == SyncOpType.DATA:
|
||||
output.write(op.data or b"")
|
||||
continue
|
||||
|
||||
if op.type != SyncOpType.BLOCK_RANGE:
|
||||
raise PatchApplyError(f"unsupported sync op type: {op.type}")
|
||||
|
||||
if op.file_index is None or op.block_index is None or op.block_span is None:
|
||||
raise PatchApplyError("missing fields in block range op")
|
||||
if op.block_span <= 0:
|
||||
raise PatchApplyError("invalid block span in block range op")
|
||||
|
||||
file_size = target_pool.size(op.file_index)
|
||||
last_block_index = op.block_index + op.block_span - 1
|
||||
last_block_size = BLOCK_SIZE
|
||||
if BLOCK_SIZE * (last_block_index + 1) > file_size:
|
||||
last_block_size = file_size % BLOCK_SIZE
|
||||
op_size = (op.block_span - 1) * BLOCK_SIZE + last_block_size
|
||||
|
||||
src = target_pool.open(op.file_index)
|
||||
src.seek(op.block_index * BLOCK_SIZE)
|
||||
_copy_range(output, src, op_size)
|
||||
|
||||
|
||||
def apply_bsdiff_controls(controls: list[Control], old: BinaryIO, output: BinaryIO) -> None:
|
||||
old_offset = 0
|
||||
for ctrl in controls:
|
||||
if ctrl.eof:
|
||||
break
|
||||
|
||||
add = ctrl.add or b""
|
||||
copy = ctrl.copy or b""
|
||||
seek = ctrl.seek or 0
|
||||
|
||||
if add:
|
||||
old.seek(old_offset)
|
||||
old_chunk = old.read(len(add))
|
||||
if len(old_chunk) != len(add):
|
||||
raise PatchApplyError("unexpected EOF while reading bsdiff add data")
|
||||
out = bytes(((a + b) & 0xFF) for a, b in zip(old_chunk, add))
|
||||
output.write(out)
|
||||
old_offset += len(add)
|
||||
|
||||
if copy:
|
||||
output.write(copy)
|
||||
|
||||
old_offset += seek
|
||||
|
||||
|
||||
def _ensure_parent(path: str) -> None:
|
||||
parent = os.path.dirname(path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
|
||||
def apply_patch(patch_reader: PatchReader, target_paths: list[str], output_paths: list[str]) -> None:
|
||||
pool = FilePool(target_paths)
|
||||
try:
|
||||
for entry in patch_reader.iter_file_entries():
|
||||
if entry.sync_header.file_index is None:
|
||||
raise PatchApplyError("missing file_index in sync header")
|
||||
out_index = entry.sync_header.file_index
|
||||
if out_index < 0 or out_index >= len(output_paths):
|
||||
raise PatchApplyError(f"output index out of range: {out_index}")
|
||||
|
||||
out_path = output_paths[out_index]
|
||||
_ensure_parent(out_path)
|
||||
|
||||
with open(out_path, "wb") as out:
|
||||
if entry.is_rsync():
|
||||
if entry.sync_ops is None:
|
||||
raise PatchApplyError("missing rsync ops")
|
||||
apply_rsync_ops(entry.sync_ops, pool, out)
|
||||
elif entry.is_bsdiff():
|
||||
if entry.bsdiff_header is None or entry.bsdiff_controls is None:
|
||||
raise PatchApplyError("missing bsdiff data")
|
||||
if entry.bsdiff_header.target_index is None:
|
||||
raise PatchApplyError("missing target_index in bsdiff header")
|
||||
old = pool.open(entry.bsdiff_header.target_index)
|
||||
apply_bsdiff_controls(entry.bsdiff_controls, old, out)
|
||||
else:
|
||||
raise PatchApplyError("unknown file patch type")
|
||||
finally:
|
||||
pool.close()
|
||||
|
||||
|
||||
def _container_file_paths(container: TlcContainer, root: str) -> list[str]:
|
||||
paths = []
|
||||
for f in container.files:
|
||||
if f.path is None:
|
||||
raise PatchApplyError("container file missing path")
|
||||
paths.append(os.path.join(root, f.path))
|
||||
return paths
|
||||
|
||||
|
||||
def _ensure_dirs(container: TlcContainer, root: str) -> None:
|
||||
for d in container.dirs:
|
||||
if d.path is None:
|
||||
continue
|
||||
path = os.path.join(root, d.path)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
if d.mode is not None:
|
||||
mode = int(d.mode) & _MODE_MASK
|
||||
try:
|
||||
os.chmod(path, mode)
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_symlinks(container: TlcContainer, root: str) -> None:
|
||||
for s in container.symlinks:
|
||||
if s.path is None or s.dest is None:
|
||||
continue
|
||||
path = os.path.join(root, s.path)
|
||||
if os.path.lexists(path):
|
||||
continue
|
||||
parent = os.path.dirname(path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
try:
|
||||
os.symlink(s.dest, path)
|
||||
except (AttributeError, OSError):
|
||||
continue
|
||||
|
||||
|
||||
def apply_patch_to_folders(patch_path: str, target_root: str, output_root: str) -> None:
|
||||
with PatchReader.open(
|
||||
patch_path,
|
||||
container_decoder=(TlcContainer.from_bytes, TlcContainer.from_bytes),
|
||||
) as reader:
|
||||
if reader.target_container is None or reader.source_container is None:
|
||||
raise PatchApplyError("missing containers in patch")
|
||||
|
||||
target_container: TlcContainer = reader.target_container
|
||||
source_container: TlcContainer = reader.source_container
|
||||
|
||||
_ensure_dirs(source_container, output_root)
|
||||
_ensure_symlinks(source_container, output_root)
|
||||
|
||||
target_paths = _container_file_paths(target_container, target_root)
|
||||
output_paths = _container_file_paths(source_container, output_root)
|
||||
|
||||
apply_patch(reader, target_paths, output_paths)
|
||||
158
pwr/compression.py
Normal file
158
pwr/compression.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import gzip
|
||||
from typing import BinaryIO
|
||||
|
||||
from .proto import CompressionAlgorithm, CompressionSettings
|
||||
|
||||
|
||||
class CompressionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _normalize_algorithm(algorithm: CompressionAlgorithm | int | None) -> CompressionAlgorithm:
|
||||
if algorithm is None:
|
||||
return CompressionAlgorithm.NONE
|
||||
if isinstance(algorithm, CompressionAlgorithm):
|
||||
return algorithm
|
||||
try:
|
||||
return CompressionAlgorithm(int(algorithm))
|
||||
except ValueError as exc:
|
||||
raise CompressionError(f"unknown compression algorithm: {algorithm}") from exc
|
||||
|
||||
|
||||
def _brotli_module():
|
||||
try:
|
||||
import brotli # type: ignore
|
||||
|
||||
return brotli
|
||||
except ImportError:
|
||||
try:
|
||||
import brotlicffi as brotli # type: ignore
|
||||
|
||||
return brotli
|
||||
except ImportError as exc:
|
||||
raise CompressionError("brotli module not available") from exc
|
||||
|
||||
|
||||
class _BrotliReader(io.RawIOBase):
|
||||
def __init__(self, raw: BinaryIO):
|
||||
self._raw = raw
|
||||
brotli = _brotli_module()
|
||||
self._decompressor = brotli.Decompressor()
|
||||
self._buffer = bytearray()
|
||||
self._eof = False
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def _finish(self) -> bytes:
|
||||
if hasattr(self._decompressor, "finish"):
|
||||
return self._decompressor.finish()
|
||||
if hasattr(self._decompressor, "flush"):
|
||||
return self._decompressor.flush()
|
||||
return b""
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if size == 0:
|
||||
return b""
|
||||
|
||||
while not self._eof and (size < 0 or len(self._buffer) < size):
|
||||
chunk = self._raw.read(8192)
|
||||
if not chunk:
|
||||
self._buffer.extend(self._finish())
|
||||
self._eof = True
|
||||
break
|
||||
self._buffer.extend(self._decompressor.process(chunk))
|
||||
|
||||
if size < 0:
|
||||
data = bytes(self._buffer)
|
||||
self._buffer.clear()
|
||||
return data
|
||||
|
||||
data = bytes(self._buffer[:size])
|
||||
del self._buffer[:size]
|
||||
return data
|
||||
|
||||
def readinto(self, b) -> int:
|
||||
data = self.read(len(b))
|
||||
n = len(data)
|
||||
b[:n] = data
|
||||
return n
|
||||
|
||||
def close(self) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
try:
|
||||
if hasattr(self._raw, "close"):
|
||||
self._raw.close()
|
||||
finally:
|
||||
super().close()
|
||||
|
||||
|
||||
class _BrotliWriter(io.RawIOBase):
|
||||
def __init__(self, raw: BinaryIO, quality: int | None):
|
||||
self._raw = raw
|
||||
brotli = _brotli_module()
|
||||
kwargs = {}
|
||||
if quality is not None:
|
||||
kwargs["quality"] = int(quality)
|
||||
self._compressor = brotli.Compressor(**kwargs)
|
||||
|
||||
def writable(self) -> bool:
|
||||
return True
|
||||
|
||||
def write(self, b: bytes) -> int:
|
||||
out = self._compressor.process(b)
|
||||
self._raw.write(out)
|
||||
return len(b)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.closed:
|
||||
return
|
||||
try:
|
||||
tail = self._compressor.finish()
|
||||
if tail:
|
||||
self._raw.write(tail)
|
||||
if hasattr(self._raw, "close"):
|
||||
self._raw.close()
|
||||
finally:
|
||||
super().close()
|
||||
|
||||
|
||||
def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO:
|
||||
algorithm = _normalize_algorithm(compression.algorithm if compression else None)
|
||||
if algorithm == CompressionAlgorithm.NONE:
|
||||
return stream
|
||||
if algorithm == CompressionAlgorithm.GZIP:
|
||||
return gzip.GzipFile(fileobj=stream, mode="rb")
|
||||
if algorithm == CompressionAlgorithm.BROTLI:
|
||||
return io.BufferedReader(_BrotliReader(stream))
|
||||
if algorithm == CompressionAlgorithm.ZSTD:
|
||||
try:
|
||||
import zstandard as zstd # type: ignore
|
||||
except ImportError as exc:
|
||||
raise CompressionError("zstandard module not available") from exc
|
||||
return zstd.ZstdDecompressor().stream_reader(stream)
|
||||
raise CompressionError(f"unsupported compression algorithm: {algorithm}")
|
||||
|
||||
|
||||
def open_compressed_writer(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO:
|
||||
algorithm = _normalize_algorithm(compression.algorithm if compression else None)
|
||||
quality = compression.quality if compression else None
|
||||
if algorithm == CompressionAlgorithm.NONE:
|
||||
return stream
|
||||
if algorithm == CompressionAlgorithm.GZIP:
|
||||
level = 9 if quality is None else int(quality)
|
||||
return gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level)
|
||||
if algorithm == CompressionAlgorithm.BROTLI:
|
||||
return io.BufferedWriter(_BrotliWriter(stream, quality))
|
||||
if algorithm == CompressionAlgorithm.ZSTD:
|
||||
try:
|
||||
import zstandard as zstd # type: ignore
|
||||
except ImportError as exc:
|
||||
raise CompressionError("zstandard module not available") from exc
|
||||
level = 3 if quality is None else int(quality)
|
||||
return zstd.ZstdCompressor(level=level).stream_writer(stream)
|
||||
raise CompressionError(f"unsupported compression algorithm: {algorithm}")
|
||||
286
pwr/formats.py
Normal file
286
pwr/formats.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
from .compression import open_decompressed_reader
|
||||
from .proto import (
|
||||
BlockHash,
|
||||
BsdiffHeader,
|
||||
Control,
|
||||
ManifestBlockHash,
|
||||
ManifestHeader,
|
||||
PatchHeader,
|
||||
SignatureHeader,
|
||||
SyncHeader,
|
||||
SyncHeaderType,
|
||||
SyncOp,
|
||||
SyncOpType,
|
||||
Wound,
|
||||
WoundsHeader,
|
||||
)
|
||||
from .wire import (
|
||||
MANIFEST_MAGIC,
|
||||
PATCH_MAGIC,
|
||||
SIGNATURE_MAGIC,
|
||||
WOUNDS_MAGIC,
|
||||
WireError,
|
||||
WireReader,
|
||||
)
|
||||
|
||||
ContainerDecoder = Callable[[bytes], Any]
|
||||
|
||||
|
||||
def _split_decoder(decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None):
|
||||
if decoder is None:
|
||||
return None, None
|
||||
if isinstance(decoder, tuple):
|
||||
if len(decoder) != 2:
|
||||
raise ValueError("container decoder tuple must have two elements")
|
||||
return decoder
|
||||
return decoder, decoder
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilePatch:
|
||||
sync_header: SyncHeader
|
||||
sync_ops: list[SyncOp] | None = None
|
||||
bsdiff_header: BsdiffHeader | None = None
|
||||
bsdiff_controls: list[Control] | None = None
|
||||
|
||||
def is_rsync(self) -> bool:
|
||||
return self.sync_ops is not None
|
||||
|
||||
def is_bsdiff(self) -> bool:
|
||||
return self.bsdiff_header is not None
|
||||
|
||||
|
||||
class PatchReader:
|
||||
def __init__(self, stream, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None):
|
||||
self._stream = stream
|
||||
self._raw_wire = WireReader(stream)
|
||||
self._wire: WireReader | None = None
|
||||
|
||||
self.header: PatchHeader | None = None
|
||||
self.target_container_raw: bytes | None = None
|
||||
self.source_container_raw: bytes | None = None
|
||||
self.target_container: Any | None = None
|
||||
self.source_container: Any | None = None
|
||||
|
||||
self._init_stream(container_decoder)
|
||||
|
||||
@classmethod
|
||||
def open(cls, path: str, *, container_decoder: ContainerDecoder | tuple[ContainerDecoder | None, ContainerDecoder | None] | None = None):
|
||||
return cls(open(path, "rb"), container_decoder=container_decoder)
|
||||
|
||||
def _init_stream(self, container_decoder):
|
||||
self._raw_wire.expect_magic(PATCH_MAGIC)
|
||||
self.header = self._raw_wire.read_message(PatchHeader)
|
||||
|
||||
decompressed = open_decompressed_reader(self._stream, self.header.compression)
|
||||
self._wire = WireReader(decompressed)
|
||||
|
||||
target_decoder, source_decoder = _split_decoder(container_decoder)
|
||||
self.target_container_raw, self.target_container = self._read_container(target_decoder)
|
||||
self.source_container_raw, self.source_container = self._read_container(source_decoder)
|
||||
|
||||
def _read_container(self, decoder: ContainerDecoder | None):
|
||||
assert self._wire is not None
|
||||
raw = self._wire.read_message_bytes()
|
||||
parsed = decoder(raw) if decoder else None
|
||||
return raw, parsed
|
||||
|
||||
def iter_file_entries(self) -> Iterator[FilePatch]:
|
||||
assert self._wire is not None
|
||||
file_index = 0
|
||||
while True:
|
||||
try:
|
||||
sync_header = self._wire.read_message(SyncHeader)
|
||||
except EOFError:
|
||||
return
|
||||
|
||||
if sync_header.file_index is None:
|
||||
sync_header.file_index = file_index
|
||||
else:
|
||||
file_index = int(sync_header.file_index)
|
||||
|
||||
header_type = sync_header.type
|
||||
if header_type is None:
|
||||
header_type = SyncHeaderType.RSYNC
|
||||
else:
|
||||
try:
|
||||
header_type = SyncHeaderType(int(header_type))
|
||||
except ValueError:
|
||||
raise WireError(f"unknown sync header type: {header_type}")
|
||||
|
||||
if header_type == SyncHeaderType.BSDIFF:
|
||||
bsdiff_header = self._wire.read_message(BsdiffHeader)
|
||||
controls: list[Control] = []
|
||||
while True:
|
||||
ctrl = self._wire.read_message(Control)
|
||||
controls.append(ctrl)
|
||||
if ctrl.eof:
|
||||
break
|
||||
|
||||
sentinel = self._wire.read_message(SyncOp)
|
||||
if sentinel.type != SyncOpType.HEY_YOU_DID_IT:
|
||||
raise WireError("expected BSDIFF sentinel SyncOp")
|
||||
|
||||
yield FilePatch(
|
||||
sync_header=sync_header,
|
||||
bsdiff_header=bsdiff_header,
|
||||
bsdiff_controls=controls,
|
||||
)
|
||||
file_index += 1
|
||||
continue
|
||||
|
||||
if header_type != SyncHeaderType.RSYNC:
|
||||
raise WireError(f"unsupported sync header type: {header_type}")
|
||||
|
||||
ops: list[SyncOp] = []
|
||||
while True:
|
||||
op = self._wire.read_message(SyncOp)
|
||||
if op.type == SyncOpType.HEY_YOU_DID_IT:
|
||||
break
|
||||
ops.append(op)
|
||||
|
||||
yield FilePatch(sync_header=sync_header, sync_ops=ops)
|
||||
file_index += 1
|
||||
|
||||
def close(self) -> None:
|
||||
if self._wire:
|
||||
self._wire.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class SignatureReader:
|
||||
def __init__(self, stream, *, container_decoder: ContainerDecoder | None = None):
|
||||
self._stream = stream
|
||||
self._raw_wire = WireReader(stream)
|
||||
self._wire: WireReader | None = None
|
||||
|
||||
self.header: SignatureHeader | None = None
|
||||
self.container_raw: bytes | None = None
|
||||
self.container: Any | None = None
|
||||
|
||||
self._init_stream(container_decoder)
|
||||
|
||||
@classmethod
|
||||
def open(cls, path: str, *, container_decoder: ContainerDecoder | None = None):
|
||||
return cls(open(path, "rb"), container_decoder=container_decoder)
|
||||
|
||||
def _init_stream(self, container_decoder):
|
||||
self._raw_wire.expect_magic(SIGNATURE_MAGIC)
|
||||
self.header = self._raw_wire.read_message(SignatureHeader)
|
||||
|
||||
decompressed = open_decompressed_reader(self._stream, self.header.compression)
|
||||
self._wire = WireReader(decompressed)
|
||||
|
||||
self.container_raw = self._wire.read_message_bytes()
|
||||
self.container = container_decoder(self.container_raw) if container_decoder else None
|
||||
|
||||
def iter_block_hashes(self) -> Iterator[BlockHash]:
|
||||
assert self._wire is not None
|
||||
while True:
|
||||
try:
|
||||
yield self._wire.read_message(BlockHash)
|
||||
except EOFError:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
if self._wire:
|
||||
self._wire.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class ManifestReader:
|
||||
def __init__(self, stream, *, container_decoder: ContainerDecoder | None = None):
|
||||
self._stream = stream
|
||||
self._raw_wire = WireReader(stream)
|
||||
self._wire: WireReader | None = None
|
||||
|
||||
self.header: ManifestHeader | None = None
|
||||
self.container_raw: bytes | None = None
|
||||
self.container: Any | None = None
|
||||
|
||||
self._init_stream(container_decoder)
|
||||
|
||||
@classmethod
|
||||
def open(cls, path: str, *, container_decoder: ContainerDecoder | None = None):
|
||||
return cls(open(path, "rb"), container_decoder=container_decoder)
|
||||
|
||||
def _init_stream(self, container_decoder):
|
||||
self._raw_wire.expect_magic(MANIFEST_MAGIC)
|
||||
self.header = self._raw_wire.read_message(ManifestHeader)
|
||||
|
||||
decompressed = open_decompressed_reader(self._stream, self.header.compression)
|
||||
self._wire = WireReader(decompressed)
|
||||
|
||||
self.container_raw = self._wire.read_message_bytes()
|
||||
self.container = container_decoder(self.container_raw) if container_decoder else None
|
||||
|
||||
def iter_block_hashes(self) -> Iterator[ManifestBlockHash]:
|
||||
assert self._wire is not None
|
||||
while True:
|
||||
try:
|
||||
yield self._wire.read_message(ManifestBlockHash)
|
||||
except EOFError:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
if self._wire:
|
||||
self._wire.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class WoundsReader:
|
||||
def __init__(self, stream, *, container_decoder: ContainerDecoder | None = None):
|
||||
self._stream = stream
|
||||
self._wire = WireReader(stream)
|
||||
|
||||
self.header: WoundsHeader | None = None
|
||||
self.container_raw: bytes | None = None
|
||||
self.container: Any | None = None
|
||||
|
||||
self._init_stream(container_decoder)
|
||||
|
||||
@classmethod
|
||||
def open(cls, path: str, *, container_decoder: ContainerDecoder | None = None):
|
||||
return cls(open(path, "rb"), container_decoder=container_decoder)
|
||||
|
||||
def _init_stream(self, container_decoder):
|
||||
self._wire.expect_magic(WOUNDS_MAGIC)
|
||||
self.header = self._wire.read_message(WoundsHeader)
|
||||
self.container_raw = self._wire.read_message_bytes()
|
||||
self.container = container_decoder(self.container_raw) if container_decoder else None
|
||||
|
||||
def iter_wounds(self) -> Iterator[Wound]:
|
||||
while True:
|
||||
try:
|
||||
yield self._wire.read_message(Wound)
|
||||
except EOFError:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
self._wire.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
477
pwr/proto.py
Normal file
477
pwr/proto.py
Normal file
@@ -0,0 +1,477 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, ClassVar, Dict, Tuple
|
||||
|
||||
|
||||
class ProtoError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _unsigned(value: int, bits: int) -> int:
|
||||
if value < 0:
|
||||
value = (1 << bits) + value
|
||||
return value
|
||||
|
||||
|
||||
def _signed(value: int, bits: int) -> int:
|
||||
sign_bit = 1 << (bits - 1)
|
||||
if value & sign_bit:
|
||||
return value - (1 << bits)
|
||||
return value
|
||||
|
||||
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
if value < 0:
|
||||
raise ProtoError("varint cannot encode negative values")
|
||||
out = bytearray()
|
||||
while True:
|
||||
to_write = value & 0x7F
|
||||
value >>= 7
|
||||
if value:
|
||||
out.append(to_write | 0x80)
|
||||
else:
|
||||
out.append(to_write)
|
||||
break
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _encode_key(field_number: int, wire_type: int) -> bytes:
|
||||
return _encode_varint((field_number << 3) | wire_type)
|
||||
|
||||
|
||||
class ProtoReader:
|
||||
def __init__(self, data: bytes):
|
||||
self._data = bytes(data)
|
||||
self._pos = 0
|
||||
|
||||
def eof(self) -> bool:
|
||||
return self._pos >= len(self._data)
|
||||
|
||||
def read_varint(self) -> int:
|
||||
shift = 0
|
||||
result = 0
|
||||
while True:
|
||||
if self._pos >= len(self._data):
|
||||
raise ProtoError("unexpected EOF while reading varint")
|
||||
byte = self._data[self._pos]
|
||||
self._pos += 1
|
||||
result |= (byte & 0x7F) << shift
|
||||
if not (byte & 0x80):
|
||||
return result
|
||||
shift += 7
|
||||
if shift > 64:
|
||||
raise ProtoError("varint too long")
|
||||
|
||||
def read_key(self) -> Tuple[int, int]:
|
||||
key = self.read_varint()
|
||||
return key >> 3, key & 0x07
|
||||
|
||||
def read_length_delimited(self) -> bytes:
|
||||
length = self.read_varint()
|
||||
end = self._pos + length
|
||||
if end > len(self._data):
|
||||
raise ProtoError("unexpected EOF while reading length-delimited field")
|
||||
data = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return data
|
||||
|
||||
def skip(self, wire_type: int) -> None:
|
||||
if wire_type == 0:
|
||||
self.read_varint()
|
||||
return
|
||||
if wire_type == 1:
|
||||
self._pos += 8
|
||||
return
|
||||
if wire_type == 2:
|
||||
length = self.read_varint()
|
||||
self._pos += length
|
||||
return
|
||||
if wire_type == 5:
|
||||
self._pos += 4
|
||||
return
|
||||
raise ProtoError(f"unsupported wire type: {wire_type}")
|
||||
|
||||
|
||||
class ProtoMessage:
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
out = bytearray()
|
||||
for field_number in sorted(self.FIELDS):
|
||||
name, field_type, extra = self.FIELDS[field_number]
|
||||
value = getattr(self, name)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
repeated = False
|
||||
if field_type.startswith("repeated_"):
|
||||
repeated = True
|
||||
field_type = field_type[len("repeated_") :]
|
||||
|
||||
values = value if repeated else [value]
|
||||
for item in values:
|
||||
if field_type == "message":
|
||||
if isinstance(item, (bytes, bytearray, memoryview)):
|
||||
raw = bytes(item)
|
||||
else:
|
||||
raw = item.to_bytes()
|
||||
out.extend(_encode_key(field_number, 2))
|
||||
out.extend(_encode_varint(len(raw)))
|
||||
out.extend(raw)
|
||||
continue
|
||||
|
||||
if field_type == "bytes":
|
||||
raw = bytes(item)
|
||||
out.extend(_encode_key(field_number, 2))
|
||||
out.extend(_encode_varint(len(raw)))
|
||||
out.extend(raw)
|
||||
continue
|
||||
|
||||
if field_type == "string":
|
||||
raw = str(item).encode("utf-8")
|
||||
out.extend(_encode_key(field_number, 2))
|
||||
out.extend(_encode_varint(len(raw)))
|
||||
out.extend(raw)
|
||||
continue
|
||||
|
||||
if field_type == "bool":
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(1 if item else 0))
|
||||
continue
|
||||
|
||||
if field_type == "enum":
|
||||
if isinstance(item, IntEnum):
|
||||
item = int(item)
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(_unsigned(int(item), 64)))
|
||||
continue
|
||||
|
||||
if field_type in ("int32", "int64"):
|
||||
bits = 32 if field_type == "int32" else 64
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(_unsigned(int(item), bits)))
|
||||
continue
|
||||
|
||||
if field_type in ("uint32", "uint64"):
|
||||
out.extend(_encode_key(field_number, 0))
|
||||
out.extend(_encode_varint(int(item)))
|
||||
continue
|
||||
|
||||
raise ProtoError(f"unsupported field type: {field_type}")
|
||||
|
||||
return bytes(out)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
reader = ProtoReader(data)
|
||||
obj = cls()
|
||||
while not reader.eof():
|
||||
field_number, wire_type = reader.read_key()
|
||||
spec = cls.FIELDS.get(field_number)
|
||||
if spec is None:
|
||||
reader.skip(wire_type)
|
||||
continue
|
||||
name, field_type, extra = spec
|
||||
|
||||
repeated = False
|
||||
if field_type.startswith("repeated_"):
|
||||
repeated = True
|
||||
field_type = field_type[len("repeated_") :]
|
||||
|
||||
if field_type == "message":
|
||||
raw = reader.read_length_delimited()
|
||||
if extra is None:
|
||||
value = raw
|
||||
else:
|
||||
value = extra.from_bytes(raw)
|
||||
elif field_type == "bytes":
|
||||
value = reader.read_length_delimited()
|
||||
elif field_type == "string":
|
||||
value = reader.read_length_delimited().decode("utf-8")
|
||||
elif field_type == "bool":
|
||||
value = bool(reader.read_varint())
|
||||
elif field_type == "enum":
|
||||
raw_value = reader.read_varint()
|
||||
if extra is None:
|
||||
value = raw_value
|
||||
else:
|
||||
try:
|
||||
value = extra(raw_value)
|
||||
except ValueError:
|
||||
value = raw_value
|
||||
elif field_type == "int32":
|
||||
value = _signed(reader.read_varint(), 32)
|
||||
elif field_type == "int64":
|
||||
value = _signed(reader.read_varint(), 64)
|
||||
elif field_type == "uint32":
|
||||
value = reader.read_varint()
|
||||
elif field_type == "uint64":
|
||||
value = reader.read_varint()
|
||||
else:
|
||||
raise ProtoError(f"unsupported field type: {field_type}")
|
||||
|
||||
if repeated:
|
||||
current = getattr(obj, name, None)
|
||||
if current is None:
|
||||
current = []
|
||||
setattr(obj, name, current)
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(obj, name, value)
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class CompressionAlgorithm(IntEnum):
|
||||
NONE = 0
|
||||
BROTLI = 1
|
||||
GZIP = 2
|
||||
ZSTD = 3
|
||||
|
||||
|
||||
class HashAlgorithm(IntEnum):
|
||||
SHAKE128_32 = 0
|
||||
CRC32C = 1
|
||||
|
||||
|
||||
class WoundKind(IntEnum):
|
||||
FILE = 0
|
||||
SYMLINK = 1
|
||||
DIR = 2
|
||||
CLOSED_FILE = 3
|
||||
|
||||
|
||||
class SyncHeaderType(IntEnum):
|
||||
RSYNC = 0
|
||||
BSDIFF = 1
|
||||
|
||||
|
||||
class SyncOpType(IntEnum):
|
||||
BLOCK_RANGE = 0
|
||||
DATA = 1
|
||||
HEY_YOU_DID_IT = 2049
|
||||
|
||||
|
||||
class OverlayOpType(IntEnum):
|
||||
SKIP = 0
|
||||
FRESH = 1
|
||||
HEY_YOU_DID_IT = 2040
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionSettings(ProtoMessage):
|
||||
algorithm: CompressionAlgorithm | int | None = None
|
||||
quality: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("algorithm", "enum", CompressionAlgorithm),
|
||||
2: ("quality", "int32", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchHeader(ProtoMessage):
|
||||
compression: CompressionSettings | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("compression", "message", CompressionSettings),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncHeader(ProtoMessage):
|
||||
type: SyncHeaderType | int | None = None
|
||||
file_index: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("type", "enum", SyncHeaderType),
|
||||
16: ("file_index", "int64", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BsdiffHeader(ProtoMessage):
|
||||
target_index: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("target_index", "int64", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncOp(ProtoMessage):
|
||||
type: SyncOpType | int | None = None
|
||||
file_index: int | None = None
|
||||
block_index: int | None = None
|
||||
block_span: int | None = None
|
||||
data: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("type", "enum", SyncOpType),
|
||||
2: ("file_index", "int64", None),
|
||||
3: ("block_index", "int64", None),
|
||||
4: ("block_span", "int64", None),
|
||||
5: ("data", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignatureHeader(ProtoMessage):
|
||||
compression: CompressionSettings | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("compression", "message", CompressionSettings),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockHash(ProtoMessage):
|
||||
weak_hash: int | None = None
|
||||
strong_hash: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("weak_hash", "uint32", None),
|
||||
2: ("strong_hash", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManifestHeader(ProtoMessage):
|
||||
compression: CompressionSettings | None = None
|
||||
algorithm: HashAlgorithm | int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("compression", "message", CompressionSettings),
|
||||
2: ("algorithm", "enum", HashAlgorithm),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManifestBlockHash(ProtoMessage):
|
||||
hash: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("hash", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WoundsHeader(ProtoMessage):
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wound(ProtoMessage):
|
||||
index: int | None = None
|
||||
start: int | None = None
|
||||
end: int | None = None
|
||||
kind: WoundKind | int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("index", "int64", None),
|
||||
2: ("start", "int64", None),
|
||||
3: ("end", "int64", None),
|
||||
4: ("kind", "enum", WoundKind),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Control(ProtoMessage):
|
||||
add: bytes | None = None
|
||||
copy: bytes | None = None
|
||||
seek: int | None = None
|
||||
eof: bool | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("add", "bytes", None),
|
||||
2: ("copy", "bytes", None),
|
||||
3: ("seek", "int64", None),
|
||||
4: ("eof", "bool", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OverlayHeader(ProtoMessage):
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OverlayOp(ProtoMessage):
|
||||
type: OverlayOpType | int | None = None
|
||||
len: int | None = None
|
||||
data: bytes | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("type", "enum", OverlayOpType),
|
||||
2: ("len", "int64", None),
|
||||
3: ("data", "bytes", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sample(ProtoMessage):
|
||||
data: bytes | None = None
|
||||
number: int | None = None
|
||||
eof: bool | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("data", "bytes", None),
|
||||
2: ("number", "int64", None),
|
||||
3: ("eof", "bool", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcDir(ProtoMessage):
|
||||
path: str | None = None
|
||||
mode: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("path", "string", None),
|
||||
2: ("mode", "uint32", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcFile(ProtoMessage):
|
||||
path: str | None = None
|
||||
mode: int | None = None
|
||||
size: int | None = None
|
||||
offset: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("path", "string", None),
|
||||
2: ("mode", "uint32", None),
|
||||
3: ("size", "int64", None),
|
||||
4: ("offset", "int64", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcSymlink(ProtoMessage):
|
||||
path: str | None = None
|
||||
mode: int | None = None
|
||||
dest: str | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("path", "string", None),
|
||||
2: ("mode", "uint32", None),
|
||||
3: ("dest", "string", None),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlcContainer(ProtoMessage):
|
||||
files: list[TlcFile] = field(default_factory=list)
|
||||
dirs: list[TlcDir] = field(default_factory=list)
|
||||
symlinks: list[TlcSymlink] = field(default_factory=list)
|
||||
size: int | None = None
|
||||
|
||||
FIELDS: ClassVar[Dict[int, Tuple[str, str, Any]]] = {
|
||||
1: ("files", "repeated_message", TlcFile),
|
||||
2: ("dirs", "repeated_message", TlcDir),
|
||||
3: ("symlinks", "repeated_message", TlcSymlink),
|
||||
16: ("size", "int64", None),
|
||||
}
|
||||
140
pwr/wire.py
Normal file
140
pwr/wire.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
from typing import BinaryIO, Type, TypeVar
|
||||
|
||||
from .proto import ProtoError, ProtoMessage
|
||||
|
||||
PATCH_MAGIC = 0xFEF5F00
|
||||
SIGNATURE_MAGIC = PATCH_MAGIC + 1
|
||||
MANIFEST_MAGIC = PATCH_MAGIC + 2
|
||||
WOUNDS_MAGIC = PATCH_MAGIC + 3
|
||||
ZIP_INDEX_MAGIC = PATCH_MAGIC + 4
|
||||
|
||||
BLOCK_SIZE = 64 * 1024
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _read_exact(stream: BinaryIO, size: int) -> bytes:
|
||||
if size <= 0:
|
||||
return b""
|
||||
try:
|
||||
data = stream.read(size)
|
||||
if data is not None:
|
||||
return data
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
readinto = getattr(stream, "readinto", None)
|
||||
if readinto is None:
|
||||
raise WireError("stream does not support read or readinto")
|
||||
|
||||
buf = bytearray(size)
|
||||
view = memoryview(buf)
|
||||
total = 0
|
||||
while total < size:
|
||||
n = readinto(view[total:])
|
||||
if n is None:
|
||||
n = 0
|
||||
if n == 0:
|
||||
break
|
||||
total += n
|
||||
return bytes(view[:total])
|
||||
|
||||
|
||||
def read_varint(stream: BinaryIO) -> int:
|
||||
shift = 0
|
||||
result = 0
|
||||
while True:
|
||||
b = _read_exact(stream, 1)
|
||||
if not b:
|
||||
raise EOFError("unexpected EOF while reading varint")
|
||||
byte = b[0]
|
||||
result |= (byte & 0x7F) << shift
|
||||
if not (byte & 0x80):
|
||||
return result
|
||||
shift += 7
|
||||
if shift > 64:
|
||||
raise WireError("varint too long")
|
||||
|
||||
|
||||
def write_varint(stream: BinaryIO, value: int) -> None:
|
||||
if value < 0:
|
||||
raise WireError("varint cannot encode negative values")
|
||||
while True:
|
||||
to_write = value & 0x7F
|
||||
value >>= 7
|
||||
if value:
|
||||
stream.write(bytes([to_write | 0x80]))
|
||||
else:
|
||||
stream.write(bytes([to_write]))
|
||||
break
|
||||
|
||||
|
||||
def read_magic(stream: BinaryIO) -> int:
|
||||
data = _read_exact(stream, 4)
|
||||
if len(data) != 4:
|
||||
raise EOFError("unexpected EOF while reading magic")
|
||||
return struct.unpack("<i", data)[0]
|
||||
|
||||
|
||||
def write_magic(stream: BinaryIO, magic: int) -> None:
|
||||
stream.write(struct.pack("<i", int(magic)))
|
||||
|
||||
|
||||
T = TypeVar("T", bound=ProtoMessage)
|
||||
|
||||
|
||||
class WireReader:
|
||||
def __init__(self, stream: BinaryIO):
|
||||
self._stream = stream
|
||||
|
||||
def read_magic(self) -> int:
|
||||
return read_magic(self._stream)
|
||||
|
||||
def expect_magic(self, magic: int) -> None:
|
||||
found = self.read_magic()
|
||||
if found != magic:
|
||||
raise WireError(f"wrong magic: expected {magic}, got {found}")
|
||||
|
||||
def read_message_bytes(self) -> bytes:
|
||||
length = read_varint(self._stream)
|
||||
data = _read_exact(self._stream, length)
|
||||
if len(data) != length:
|
||||
raise EOFError("unexpected EOF while reading message")
|
||||
return data
|
||||
|
||||
def read_message(self, msg_cls: Type[T]) -> T:
|
||||
data = self.read_message_bytes()
|
||||
try:
|
||||
return msg_cls.from_bytes(data)
|
||||
except ProtoError as exc:
|
||||
raise WireError(str(exc)) from exc
|
||||
|
||||
def close(self) -> None:
|
||||
if isinstance(self._stream, io.IOBase):
|
||||
self._stream.close()
|
||||
|
||||
|
||||
class WireWriter:
|
||||
def __init__(self, stream: BinaryIO):
|
||||
self._stream = stream
|
||||
|
||||
def write_magic(self, magic: int) -> None:
|
||||
write_magic(self._stream, magic)
|
||||
|
||||
def write_message(self, msg: ProtoMessage | bytes | bytearray | memoryview) -> None:
|
||||
if isinstance(msg, (bytes, bytearray, memoryview)):
|
||||
data = bytes(msg)
|
||||
else:
|
||||
data = msg.to_bytes()
|
||||
write_varint(self._stream, len(data))
|
||||
self._stream.write(data)
|
||||
|
||||
def close(self) -> None:
|
||||
if isinstance(self._stream, io.IOBase):
|
||||
self._stream.close()
|
||||
Reference in New Issue
Block a user