import lpips
from pytorch_msssim import ssim
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

# PSNR with data_range=1.0 (inputs presumably scaled to [0, 1] — TODO confirm).
# reduction="none" together with dim=(1, 2, 3) reduces over every axis except
# the first, so one PSNR value is produced per item along dim 0.
psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))

# LPIPS perceptual distance using the AlexNet backbone.
lp = lpips.LPIPS(net="alex")


def benchmark(image1, image2):
    """Compute PSNR, SSIM and LPIPS between two image batches.

    Args:
        image1: First image batch. Assumed to be a (N, C, H, W) tensor with
            values in [0, 1] (matching data_range=1.0 below) — TODO confirm
            against callers.
        image2: Second image batch, same shape and range as ``image1``.

    Returns:
        A 3-tuple ``(psnr, ssim, lpips)``:
        - PSNR per item along dim 0 (reduced over dims 1-3).
        - SSIM from ``pytorch_msssim.ssim`` with ``size_average=False``, i.e.
          not averaged over the batch (presumably shape (N,) — verify).
        - LPIPS distance as returned by the ``lpips.LPIPS`` module
          (NOTE(review): likely (N, 1, 1, 1), not (N,) — confirm before
          stacking with the other two).
    """
    psnr_value = psnr(image1, image2)
    # Metric.forward merges every batch into the metric's accumulated state.
    # Because dim is set, those states are lists that grow by one entry per
    # call; without a reset, repeated benchmark() calls leak memory.
    psnr.reset()
    return (
        psnr_value,
        ssim(
            image1,
            image2,
            data_range=1.0,
            size_average=False,
        ),
        # LPIPS is fed inputs rescaled from [0, 1] to [-1, 1].
        lp(image1 * 2 - 1, image2 * 2 - 1),
    )