"""Evaluate a trained UTransformer checkpoint on the test split.

Loads the checkpoint into a ``UTransformer`` wrapped in ``RF``. For every item
of the test split returned by ``get_dataset()``, it draws a sample from input
``x0`` and scores it against ground truth ``x1`` with ``benchmark``. It prints
the mean PSNR / SSIM / LPIPS over all items. The first ``MAX_SAVE``
prediction / ground-truth / input triplets are also written to ``SAVE_DIR`` as
PNGs for visual inspection.
"""

import os

import torch
from torchvision.utils import save_image
from tqdm import tqdm

from src.benchmark import benchmark
from src.dataset.cuhk_cr1 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

CHECKPOINT_PATH = "artifact/icy-field-12/checkpoint_epoch_260.pt"
BACKBONE = "facebook/dinov3-vitl16-pretrain-sat493m"
DEVICE = "cuda:2"
SAVE_DIR = "test_images"
BATCH_SIZE = 1
MAX_SAVE = 10  # number of pred/gt/input triplets written to SAVE_DIR


def _load_rf(checkpoint_path, device):
    """Build the UTransformer, load the checkpoint weights, and return it wrapped in RF.

    Args:
        checkpoint_path: path to a checkpoint dict containing "model_state_dict".
        device: torch device string the model and weights are placed on.

    Returns:
        An ``RF`` instance whose ``.model`` has been switched to eval mode.
    """
    model = UTransformer.from_pretrained_backbone(BACKBONE).to(device)
    # NOTE(review): torch.load unpickles the file (unless weights_only is in
    # effect for the installed torch version); only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    rf = RF(model)
    rf.model.eval()
    return rf


def _save_samples(batch, prediction, target, save_dir, start, limit):
    """Write pred/gt/input PNG triplets until ``limit`` images have been saved in total.

    Args:
        batch: the dataset slice; its "x0" entries are saved (denormalized) as inputs.
        prediction: denormalized, clamped predictions for the batch.
        target: denormalized, clamped ground truth for the batch.
        save_dir: directory the PNGs are written into.
        start: number of triplets already saved (used as the file index).
        limit: maximum total number of triplets to save.

    Returns:
        The updated total number of saved triplets.
    """
    saved = start
    for j in range(min(prediction.shape[0], limit - start)):
        save_image(prediction[j], os.path.join(save_dir, f"pred_{saved}.png"))
        save_image(target[j], os.path.join(save_dir, f"gt_{saved}.png"))
        save_image(
            denormalize(batch["x0"][j]).clamp(0, 1),
            os.path.join(save_dir, f"input_{saved}.png"),
        )
        saved += 1
    return saved


def main():
    """Run the evaluation loop and print the averaged metrics.

    Raises:
        RuntimeError: if the test split is empty (no metrics can be averaged).
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    rf = _load_rf(CHECKPOINT_PATH, DEVICE)
    _, test_dataset = get_dataset()

    psnr_sum = 0.0
    ssim_sum = 0.0
    lpips_sum = 0.0
    count = 0
    saved_count = 0

    with torch.no_grad():
        for i in tqdm(range(0, len(test_dataset), BATCH_SIZE), desc="Evaluating"):
            batch = test_dataset[i : i + BATCH_SIZE]
            images = rf.sample(batch["x0"].to(DEVICE))
            # Only the last element of rf.sample's output is scored
            # (presumably the final sampling step -- confirm against RF).
            prediction = denormalize(images[-1]).clamp(0, 1)
            target = denormalize(batch["x1"]).clamp(0, 1)

            if saved_count < MAX_SAVE:
                saved_count = _save_samples(
                    batch, prediction, target, SAVE_DIR, saved_count, MAX_SAVE
                )

            # benchmark returns per-sample metrics; sum them here and divide
            # by the total sample count at the end so uneven batches weigh correctly.
            psnr, ssim, lpips = benchmark(prediction.cpu(), target.cpu())
            psnr_sum += psnr.sum().item()
            ssim_sum += ssim.sum().item()
            lpips_sum += lpips.sum().item()
            count += prediction.shape[0]

    if count == 0:
        raise RuntimeError("Test dataset is empty; no metrics to report.")

    print(f"PSNR: {psnr_sum / count:.4f}")
    print(f"SSIM: {ssim_sum / count:.4f}")
    print(f"LPIPS: {lpips_sum / count:.4f}")


if __name__ == "__main__":
    main()