import base64
import hashlib
import random
import string
import struct
import typing
from enum import Enum
from functools import cache

import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.serialization import load_der_public_key
from pypdf import PdfReader, _encryption
from pypdf.generic import (
    ArrayObject,
    ByteStringObject,
    DictionaryObject,
    IndirectObject,
    StreamObject,
    TextStringObject,
    create_string_object,
)

from src.crypto import (
    aes_decrypt,
    aes_encrypt,
    des3_decrypt,
    des3_encrypt,
    makekey_16,
    makekey_32,
    rc4_decrypt,
)
from src.device import Device
from src.drminfo import DRMInfo
from src.utils import unpad_aes


class Algorithm(Enum):
    """Cipher used for PDF object decryption.

    Selected in ``EZDRMEncryption.__init__`` from the ``/VER`` entry of the
    EZDRM encryption dictionary; ``RC4`` is the default.
    """

    RC4 = 0
    AES = 1
    AES256 = 2


class EZDRMEncryption(_encryption.Encryption):
    """pypdf encryption handler for PDFs whose ``/Encrypt`` dictionary has
    ``/Filter /UDOC_EZDRM``.

    The dictionary's ``/INFO`` blob is decrypted locally (with a key derived
    from ``/DID``) into a :class:`DRMInfo`. Its protocol/server/port/path
    fields are used to contact a license server, which supplies the RSA
    public key (:attr:`public_key`) and the document key (:attr:`open_key`).
    """

    def __init__(self, device: Device, docs: PdfReader, encrypt_ref: IndirectObject):
        """Parse the EZDRM encryption dictionary and prepare an HTTP session.

        Args:
            device: Client identity. Supplies the ``User-Agent`` header and the
                parameters for the open-key request.
            docs: Reader for the protected document. Its page count is sent to
                the server when :attr:`open_key` is requested.
            encrypt_ref: Indirect reference to the ``/Encrypt`` dictionary.
                ``decrypt_object`` leaves this object untouched.

        Raises:
            AssertionError: if the dictionary is missing or its ``/Filter`` is
                not ``/UDOC_EZDRM``. This is a bare ``assert``, so the check
                disappears under ``python -O``.
            ValueError: whenever ``int(/VER) == 2`` (see the NOTE below).
        """
        # load drm info
        drm = typing.cast(dict, encrypt_ref.get_object())
        assert drm is not None and drm.get("/Filter") == "/UDOC_EZDRM"

        # /VER is compared numerically below, so it is assumed to be a PDF
        # number object (e.g. 2.1, 3, 4) — confirm against real files.
        version = drm["/VER"]
        info = drm["/INFO"]
        did: str = drm["/DID"]
        algorithm = Algorithm.RC4
        if int(version) == 2:
            # Intent appears to be: the first decimal digit of a 2.x version
            # selects the cipher (2.0 -> RC4, 2.1 -> AES, 2.2 -> AES256).
            # NOTE(review): the argument is wrapped in a list. Enum value lookup
            # never matches a list against the int member values, so this line
            # always raises ValueError. It probably should be
            # Algorithm(int(...)). round() would also be safer than int() on
            # the float product.
            algorithm = Algorithm([int((version - 2) * 10)])

        if version < 4:
            # Pre-v4: /INFO is base64 text, decrypted with a key derived
            # from /DID by makekey_16.
            payload = aes_decrypt(makekey_16(did.encode()), base64.b64decode(info))
        else:
            # v4+: the starting key is the raw /DID bytes when /DID has 16
            # characters, otherwise makekey_16(/DID). It is then hashed
            # 11 times with SHA-256.
            key = did.encode() if len(did) == 16 else makekey_16(did.encode())
            for _ in range(11):
                key = hashlib.sha256(key).digest()
            # NOTE(review): unlike the branch above, /INFO is not
            # base64-decoded here. Confirm that aes_decrypt expects the raw
            # encoded text in this case.
            payload = aes_decrypt(key, info.encode())

        # setup session
        session = requests.Session()
        session.headers.update({"User-Agent": device.user_agent})

        self.docs = docs
        self.version = version
        # The first 4 bytes of the decrypted payload are skipped. Their meaning
        # is not visible here (presumably a header/length field — verify).
        self.info = DRMInfo(payload[4:])
        self.algorithm = algorithm
        self.session = session
        self.encrypt_ref = encrypt_ref
        self.device = device

    def request(self, url: str, params: str) -> str:
        """Send an encrypted request to the license server and return its reply.

        The exchange works as follows:

        - Generate a fresh 32-digit numeric ticket. RSA-encrypt it (PKCS#1
          v1.5) with :attr:`public_key` and send it as ``skx``.
        - Encrypt ``params`` with SHA-256(ticket) as the key (3DES when
          ``version < 4``, AES otherwise) and send it as ``dx``.
        - Both values are upper-case hex.
        - The response body is expected to be hex ciphertext under the same
          key. It is decrypted and decoded as EUC-KR.

        Args:
            url: Endpoint URL. ``&skx=...&dx=...`` is appended verbatim, so the
                URL presumably already carries a ``?`` query — verify callers.
            params: Plain-text request parameters.

        Returns:
            The decrypted server reply.
        """
        # NOTE(review): the ticket is the secret that protects this exchange.
        # `random` is not a cryptographically secure generator; consider
        # `secrets.choice` instead.
        ticket = "".join(random.choices(string.digits, k=0x20))
        skx = self.public_key.encrypt(ticket.encode(), padding.PKCS1v15()).hex().upper()
        sha_ticket = hashlib.sha256(ticket.encode()).digest()
        dx = (
            (
                des3_encrypt(sha_ticket, params.encode())
                if self.version < 4
                else aes_encrypt(sha_ticket, params.encode())
            )
            .hex()
            .upper()
        )
        # NOTE(review): no timeout and no HTTP status check. An error page will
        # surface as a bytes.fromhex / decrypt failure below.
        ciphertext = (
            self.session.get(f"{url}&skx={skx}&dx={dx}").content.strip().decode()
        )
        return (
            des3_decrypt(sha_ticket, bytes.fromhex(ciphertext))
            if self.version < 4
            else aes_decrypt(sha_ticket, bytes.fromhex(ciphertext))
        ).decode("euc-kr")

    # NOTE(review): @cache under @property memoizes in a module-level cache
    # keyed on `self`. That cache keeps every instance alive (ruff B019).
    # functools.cached_property would store the value on the instance instead.
    @property
    @cache
    def public_key(self) -> RSAPublicKey:
        """License-server RSA public key, fetched over HTTP and memoized.

        The key is fetched from ``protocol1 + server1 + ":" + port1 + reserved``
        of the DRM info. The endpoint's body is the DER-encoded key written as
        hex text.
        """
        pkey = self.session.get(
            f"{self.info.protocol1}{self.info.server1}:{self.info.port1}{self.info.reserved}"
        ).content.strip()
        return typing.cast(
            RSAPublicKey,
            load_der_public_key(bytes.fromhex(pkey.decode()), default_backend()),
        )

    # NOTE(review): same @property/@cache instance-retention issue as
    # public_key above.
    @property
    @cache
    def open_key(self) -> bytes:
        """Document key from the license server's "open" endpoint, memoized.

        Sends ``device.open_params(doc_id, page_count, None)`` through
        :meth:`request`. A successful reply starts with ``"ACK,1,"``. Its third
        comma-separated field is turned into the key with ``makekey_16``
        (``version < 4``) or ``makekey_32`` (presumably 16/32-byte keys, going
        by the helper names).

        Raises:
            Exception: ``"docs expired"`` for any reply not starting with
                ``"ACK,1,"``.
        """
        resp = self.request(
            f"{self.info.protocol1}{self.info.server1}:{self.info.port1}{self.info.open}",
            self.device.open_params(self.info.doc_id, len(self.docs.pages), None),
        )
        if not resp.startswith("ACK,1,"):
            raise Exception("docs expired")
        return (
            makekey_16(resp.split(",")[2].encode())
            if self.version < 4
            else makekey_32(resp.split(",")[2].encode())
        )

    # NOTE(review): this copy of the file is DAMAGED here. The body of
    # _perform_decrypt breaks off inside `struct.pack("`. Everything up to the
    # `-> bool:` annotation of the next method is missing, including that
    # method's name and parameters. Cipher, algorithms, modes, rc4_decrypt and
    # unpad_aes are imported but not referenced in the surviving code, so the
    # lost text presumably used them. Restore this section from version
    # control rather than reconstructing it by hand.
    def _perform_decrypt(self, objid: int, genno: int, data: bytes, attrs) -> bytes:
        """Decrypt one string or stream payload of object (objid, genno).

        ``attrs`` is None for strings. For streams it is the stream dictionary
        with stringified keys, as built by ``_dec``. The surviving RC4 branch
        starts its key from :attr:`open_key`.
        """
        if self.algorithm == Algorithm.RC4:
            key = (
                self.open_key
                + struct.pack(" bool:
        return True

    def decrypt_object(
        self, obj: typing.Any, idnum: int, generation: int
    ) -> typing.Any:
        """Decrypt ``obj``, which belongs to indirect object (idnum, generation).

        The ``/Encrypt`` dictionary itself (matched by ``encrypt_ref``'s
        ``idnum``/``generation``) is returned unchanged. Everything else goes
        through ``_dec``, which mutates containers in place.
        """
        if (
            self.encrypt_ref is not None
            and idnum == self.encrypt_ref.idnum
            and generation == self.encrypt_ref.generation
        ):
            return obj
        return self._dec(obj, idnum, generation)

    def _dec(self, obj: typing.Any, idnum: int, generation: int) -> typing.Any:
        """Recursively decrypt strings and streams inside ``obj``.

        How each kind of object is handled:

        - Indirect references are returned as-is (not resolved).
        - Strings are decrypted from their ``original_bytes`` and rebuilt with
          ``create_string_object``.
        - Streams have their data decrypted, then their dictionary values.
        - Dictionaries and arrays are updated element by element, in place.
        - Any other object is returned unchanged.
        """
        if isinstance(obj, IndirectObject):
            return obj
        if isinstance(obj, (ByteStringObject, TextStringObject)):
            plain = self._perform_decrypt(idnum, generation, obj.original_bytes, None)
            return create_string_object(plain)
        # Must be tested before DictionaryObject: in pypdf, StreamObject is
        # presumably a DictionaryObject subclass, so reordering would skip
        # decrypting the stream data.
        if isinstance(obj, StreamObject):
            # attrs is captured before the dictionary values below are
            # decrypted, so any string values it holds are still ciphertext.
            attrs = {str(k): v for k, v in obj.items()}
            obj._data = self._perform_decrypt(idnum, generation, obj._data, attrs)
            for k, v in list(obj.items()):
                obj[k] = self._dec(v, idnum, generation)
            return obj
        if isinstance(obj, DictionaryObject):
            for k, v in list(obj.items()):
                obj[k] = self._dec(v, idnum, generation)
            return obj
        if isinstance(obj, ArrayObject):
            for i in range(len(obj)):
                obj[i] = self._dec(obj[i], idnum, generation)
            return obj
        return obj