from __future__ import annotations import io import gzip from typing import BinaryIO from .proto import CompressionAlgorithm, CompressionSettings class CompressionError(Exception): pass def _normalize_algorithm(algorithm: CompressionAlgorithm | int | None) -> CompressionAlgorithm: if algorithm is None: return CompressionAlgorithm.NONE if isinstance(algorithm, CompressionAlgorithm): return algorithm try: return CompressionAlgorithm(int(algorithm)) except ValueError as exc: raise CompressionError(f"unknown compression algorithm: {algorithm}") from exc def _brotli_module(): try: import brotli # type: ignore return brotli except ImportError: try: import brotlicffi as brotli # type: ignore return brotli except ImportError as exc: raise CompressionError("brotli module not available") from exc class _BrotliReader(io.RawIOBase): def __init__(self, raw: BinaryIO): self._raw = raw brotli = _brotli_module() self._decompressor = brotli.Decompressor() self._buffer = bytearray() self._eof = False def readable(self) -> bool: return True def _finish(self) -> bytes: if hasattr(self._decompressor, "finish"): return self._decompressor.finish() if hasattr(self._decompressor, "flush"): return self._decompressor.flush() return b"" def read(self, size: int = -1) -> bytes: if size == 0: return b"" while not self._eof and (size < 0 or len(self._buffer) < size): chunk = self._raw.read(8192) if not chunk: self._buffer.extend(self._finish()) self._eof = True break self._buffer.extend(self._decompressor.process(chunk)) if size < 0: data = bytes(self._buffer) self._buffer.clear() return data data = bytes(self._buffer[:size]) del self._buffer[:size] return data def readinto(self, b) -> int: data = self.read(len(b)) n = len(data) b[:n] = data return n def close(self) -> None: if self.closed: return try: if hasattr(self._raw, "close"): self._raw.close() finally: super().close() class _BrotliWriter(io.RawIOBase): def __init__(self, raw: BinaryIO, quality: int | None): self._raw = raw brotli = _brotli_module() kwargs = {} if quality is not None: kwargs["quality"] = int(quality) self._compressor = brotli.Compressor(**kwargs) def writable(self) -> bool: return True def write(self, b: bytes) -> int: out = self._compressor.process(b) self._raw.write(out) return len(b) def close(self) -> None: if self.closed: return try: tail = self._compressor.finish() if tail: self._raw.write(tail) if hasattr(self._raw, "close"): self._raw.close() finally: super().close() def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO: algorithm = _normalize_algorithm(compression.algorithm if compression else None) if algorithm == CompressionAlgorithm.NONE: return stream if algorithm == CompressionAlgorithm.GZIP: return gzip.GzipFile(fileobj=stream, mode="rb") if algorithm == CompressionAlgorithm.BROTLI: return io.BufferedReader(_BrotliReader(stream)) if algorithm == CompressionAlgorithm.ZSTD: try: import zstandard as zstd # type: ignore except ImportError as exc: raise CompressionError("zstandard module not available") from exc return zstd.ZstdDecompressor().stream_reader(stream) raise CompressionError(f"unsupported compression algorithm: {algorithm}") def open_compressed_writer(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO: algorithm = _normalize_algorithm(compression.algorithm if compression else None) quality = compression.quality if compression else None if algorithm == CompressionAlgorithm.NONE: return stream if algorithm == CompressionAlgorithm.GZIP: level = 9 if quality is None else int(quality) return gzip.GzipFile(fileobj=stream, mode="wb", compresslevel=level) if algorithm == CompressionAlgorithm.BROTLI: return io.BufferedWriter(_BrotliWriter(stream, quality)) if algorithm == CompressionAlgorithm.ZSTD: try: import zstandard as zstd # type: ignore except ImportError as exc: raise CompressionError("zstandard module not available") from exc level = 3 if quality is None else int(quality) return zstd.ZstdCompressor(level=level).stream_writer(stream) raise CompressionError(f"unsupported compression algorithm: {algorithm}")