159 lines
5.0 KiB
Python
159 lines
5.0 KiB
Python
from __future__ import annotations
|
|
|
|
import io
|
|
import gzip
|
|
from typing import BinaryIO
|
|
|
|
from .proto import CompressionAlgorithm, CompressionSettings
|
|
|
|
|
|
class CompressionError(Exception):
    """Error raised by this module's compression helpers.

    Used when an algorithm identifier is unknown or unsupported, and when an
    optional codec module (brotli/brotlicffi, zstandard) cannot be imported.
    """
|
|
|
|
|
|
def _normalize_algorithm(algorithm: CompressionAlgorithm | int | None) -> CompressionAlgorithm:
    """Coerce *algorithm* into a ``CompressionAlgorithm`` member.

    ``None`` maps to ``CompressionAlgorithm.NONE``; an existing member is
    returned unchanged; anything else is converted with ``int()`` and looked
    up by value.

    Raises:
        CompressionError: if the value cannot be converted to an int, or the
            resulting int does not name a ``CompressionAlgorithm`` member.
    """
    if algorithm is None:
        return CompressionAlgorithm.NONE
    if isinstance(algorithm, CompressionAlgorithm):
        return algorithm
    try:
        return CompressionAlgorithm(int(algorithm))
    except (TypeError, ValueError) as exc:
        # int() raises TypeError for non-numeric objects (e.g. a list) and
        # ValueError for non-numeric strings; the enum lookup is expected to
        # raise ValueError for unknown codes. Report all of them uniformly.
        raise CompressionError(f"unknown compression algorithm: {algorithm}") from exc
|
|
|
|
|
|
def _brotli_module():
    """Import and return a Brotli binding.

    The ``brotli`` package is tried first; ``brotlicffi`` is the fallback.

    Raises:
        CompressionError: if neither binding can be imported.
    """
    try:
        import brotli as module  # type: ignore
    except ImportError:
        try:
            import brotlicffi as module  # type: ignore
        except ImportError as exc:
            raise CompressionError("brotli module not available") from exc
    return module
|
|
|
|
|
|
class _BrotliReader(io.RawIOBase):
    """Raw, read-only stream that Brotli-decompresses another binary stream.

    Compressed bytes are pulled from *raw* in 8 KiB chunks and fed to a
    Brotli ``Decompressor``; decompressed output is staged in an internal
    buffer until requested. Closing this reader also closes *raw*.
    """

    _CHUNK_SIZE = 8192

    def __init__(self, raw: BinaryIO):
        self._raw = raw
        brotli = _brotli_module()
        self._decompressor = brotli.Decompressor()
        # Decompressed bytes produced but not yet handed to the caller.
        self._buffer = bytearray()
        # Set once *raw* is exhausted and the decompressor has been finalized.
        self._eof = False
        # Whether any compressed input was consumed. A completely empty input
        # is treated as an empty payload rather than as a truncated stream.
        self._saw_input = False

    def readable(self) -> bool:
        return True

    def _finish(self) -> bytes:
        """Drain remaining decompressor output once *raw* hits EOF.

        Different Brotli bindings expose ``finish`` or ``flush`` (or neither),
        so each is probed. When the binding offers ``is_finished``, it is used
        to reject streams that ended before the Brotli end marker, instead of
        silently returning partial data.

        Raises:
            CompressionError: if compressed input was consumed but the stream
                is incomplete.
        """
        decompressor = self._decompressor
        if hasattr(decompressor, "finish"):
            tail = decompressor.finish()
        elif hasattr(decompressor, "flush"):
            tail = decompressor.flush()
        else:
            tail = b""
        is_finished = getattr(decompressor, "is_finished", None)
        if self._saw_input and is_finished is not None and not is_finished():
            raise CompressionError("truncated brotli stream")
        return tail

    def read(self, size: int | None = -1) -> bytes:
        """Return up to *size* decompressed bytes.

        A negative *size* (or ``None``, per the ``io`` convention) reads
        everything that remains. ``b""`` is returned when *size* is 0 or the
        stream is exhausted.
        """
        if size is None:
            size = -1
        if size == 0:
            return b""

        while not self._eof and (size < 0 or len(self._buffer) < size):
            chunk = self._raw.read(self._CHUNK_SIZE)
            if not chunk:
                self._buffer.extend(self._finish())
                self._eof = True
                break
            self._saw_input = True
            self._buffer.extend(self._decompressor.process(chunk))

        if size < 0:
            data = bytes(self._buffer)
            self._buffer.clear()
            return data

        data = bytes(self._buffer[:size])
        del self._buffer[:size]
        return data

    def readinto(self, b) -> int:
        """Fill the writable buffer *b* with decompressed bytes; return count.

        *b* is viewed as raw bytes, so any contiguous writable buffer works,
        not only ones whose item size is 1.
        """
        with memoryview(b) as view, view.cast("B") as target:
            data = self.read(len(target))
            n = len(data)
            target[:n] = data
        return n

    def close(self) -> None:
        """Close *raw* (if it has ``close``) and then this stream; idempotent."""
        if self.closed:
            return
        try:
            if hasattr(self._raw, "close"):
                self._raw.close()
        finally:
            super().close()
|
|
|
|
|
|
class _BrotliWriter(io.RawIOBase):
    """Raw, write-only stream that Brotli-compresses into another binary stream.

    Each ``write`` feeds the compressor and forwards whatever output it emits
    to *raw*. ``close`` finalizes the Brotli stream, writes the trailer and
    closes *raw*.
    """

    def __init__(self, raw: BinaryIO, quality: int | None):
        self._raw = raw
        brotli = _brotli_module()
        kwargs = {}
        if quality is not None:
            # When quality is None the binding's own default quality applies.
            kwargs["quality"] = int(quality)
        self._compressor = brotli.Compressor(**kwargs)

    def writable(self) -> bool:
        return True

    def write(self, b: bytes) -> int:
        """Compress *b*, forward any produced output, and report it consumed."""
        out = self._compressor.process(b)
        self._raw.write(out)
        return len(b)

    def close(self) -> None:
        """Finalize the Brotli stream and close *raw*; idempotent.

        *raw* is closed even if finalizing or writing the trailer fails, so
        the underlying stream is not leaked on error.
        """
        if self.closed:
            return
        try:
            try:
                tail = self._compressor.finish()
                if tail:
                    self._raw.write(tail)
            finally:
                if hasattr(self._raw, "close"):
                    self._raw.close()
        finally:
            super().close()
|
|
|
|
|
|
def open_decompressed_reader(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO:
    """Wrap *stream* so that reading from the result yields decompressed bytes.

    The algorithm is taken from ``compression.algorithm``. When *compression*
    is falsy (e.g. ``None``) or the algorithm is ``NONE``, *stream* itself is
    returned unchanged.

    Raises:
        CompressionError: if the algorithm is unknown or unsupported, or if
            the required codec module cannot be imported.
    """
    requested = compression.algorithm if compression else None
    algorithm = _normalize_algorithm(requested)

    if algorithm == CompressionAlgorithm.NONE:
        reader = stream
    elif algorithm == CompressionAlgorithm.GZIP:
        reader = gzip.GzipFile(fileobj=stream, mode="rb")
    elif algorithm == CompressionAlgorithm.BROTLI:
        reader = io.BufferedReader(_BrotliReader(stream))
    elif algorithm == CompressionAlgorithm.ZSTD:
        # zstandard is optional; import lazily so other codecs work without it.
        try:
            import zstandard as zstd  # type: ignore
        except ImportError as exc:
            raise CompressionError("zstandard module not available") from exc
        reader = zstd.ZstdDecompressor().stream_reader(stream)
    else:
        raise CompressionError(f"unsupported compression algorithm: {algorithm}")
    return reader
|
|
|
|
|
|
def open_compressed_writer(stream: BinaryIO, compression: CompressionSettings | None) -> BinaryIO:
    """Wrap *stream* so that bytes written to the result are compressed.

    The algorithm and level come from ``compression.algorithm`` and
    ``compression.quality``. Without a quality, gzip uses level 9, zstd uses
    level 3, and Brotli uses its binding's default. When *compression* is
    falsy (e.g. ``None``) or the algorithm is ``NONE``, *stream* is returned
    unchanged.

    Raises:
        CompressionError: if the algorithm is unknown or unsupported, or if
            the required codec module cannot be imported.
    """
    if compression:
        algorithm = _normalize_algorithm(compression.algorithm)
        quality = compression.quality
    else:
        algorithm = _normalize_algorithm(None)
        quality = None

    if algorithm == CompressionAlgorithm.NONE:
        return stream
    if algorithm == CompressionAlgorithm.GZIP:
        return gzip.GzipFile(
            fileobj=stream,
            mode="wb",
            compresslevel=9 if quality is None else int(quality),
        )
    if algorithm == CompressionAlgorithm.BROTLI:
        return io.BufferedWriter(_BrotliWriter(stream, quality))
    if algorithm != CompressionAlgorithm.ZSTD:
        raise CompressionError(f"unsupported compression algorithm: {algorithm}")

    # zstandard is optional; import lazily so other codecs work without it.
    try:
        import zstandard as zstd  # type: ignore
    except ImportError as exc:
        raise CompressionError("zstandard module not available") from exc
    level = 3 if quality is None else int(quality)
    return zstd.ZstdCompressor(level=level).stream_writer(stream)
|