"""Evaluate a UTransformer/RF checkpoint on the test split from ``get_dataset``.

For every test batch the script samples a prediction from ``batch["cloud"]``
with ``RF.sample`` and compares its final step against ``batch["gt"]`` using
``benchmark`` (PSNR / SSIM / LPIPS). The averages are printed at the end.

The first ``max_save`` samples are also written to ``save_dir``:
- ``pred_*.png``: the prediction.
- ``gt_*.png``: the target.
- ``input_*.png``: the input.
- ``transform_*.gif``: an animation over every step returned by ``RF.sample``.
"""

import os

import torch
from PIL import Image
from torchvision.utils import save_image
from tqdm import tqdm

from src.benchmark import benchmark
from src.dataset.cuhk_cr1 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

checkpoint_path = "artifact/firm-darkness-98/checkpoint_final.pt"
device = "cuda:1"
save_dir = "test_images"
os.makedirs(save_dir, exist_ok=True)


def _to_uint8_hwc(img):
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 numpy array.

    Uses the same round-half-up quantisation as ``torchvision.utils.save_image``
    (``mul(255).add(0.5).clamp(0, 255)``), so GIF frames match the saved PNGs.
    Plain ``astype("uint8")`` would truncate instead.
    """
    return (
        img.mul(255)
        .add(0.5)
        .clamp(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )


def _save_sample(index, pred, gt, cloud, trajectory):
    """Write the pred/gt/input PNGs and a trajectory GIF for one sample.

    Args:
        index: Suffix used in the output file names.
        pred: Final prediction, CHW in [0, 1].
        gt: Ground-truth image, CHW in [0, 1].
        cloud: Model input image, CHW in [0, 1].
        trajectory: One CHW [0, 1] frame per sampling step, in order.
    """
    save_image(pred, f"{save_dir}/pred_{index}.png")
    save_image(gt, f"{save_dir}/gt_{index}.png")
    save_image(cloud, f"{save_dir}/input_{index}.png")

    frames = [Image.fromarray(_to_uint8_hwc(frame)) for frame in trajectory]
    frames[0].save(
        f"{save_dir}/transform_{index}.gif",
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0,
    )


model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vitl16-pretrain-sat493m"
).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
rf = RF(model)
rf.model.eval()

_, test_dataset = get_dataset()

batch_size = 8 * 4
psnr_sum = 0
ssim_sum = 0
lpips_sum = 0
count = 0
saved_count = 0
max_save = 10

with torch.no_grad():
    progress = tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating")
    for i in progress:
        batch = test_dataset[i : i + batch_size]
        # rf.sample returns a sequence of intermediate images; the last
        # entry is the final prediction.
        images = rf.sample(batch["cloud"].to(device), 1)
        image = denormalize(images[-1]).clamp(0, 1)
        original = denormalize(batch["gt"]).clamp(0, 1)

        if saved_count < max_save:
            for j in range(min(image.shape[0], max_save - saved_count)):
                _save_sample(
                    saved_count,
                    image[j],
                    original[j],
                    denormalize(batch["cloud"][j]).clamp(0, 1),
                    [denormalize(step_img[j]).clamp(0, 1) for step_img in images],
                )
                saved_count += 1

        psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
        # NOTE(review): summing here and dividing by the sample count below
        # assumes benchmark() returns per-sample values (not a batch mean).
        # TODO: confirm against src.benchmark.
        psnr_sum += psnr.sum().item()
        ssim_sum += ssim.sum().item()
        lpips_sum += lpips.sum().item()
        count += image.shape[0]

        # Show running means on the progress bar instead of dumping raw
        # per-batch tensors to stdout, which garbles the tqdm output.
        progress.set_postfix(
            psnr=f"{psnr_sum / count:.3f}",
            ssim=f"{ssim_sum / count:.4f}",
            lpips=f"{lpips_sum / count:.4f}",
        )

if count == 0:
    # Avoid a bare ZeroDivisionError with no context.
    raise RuntimeError("Test dataset is empty; no metrics to report.")

avg_psnr = psnr_sum / count
avg_ssim = ssim_sum / count
avg_lpips = lpips_sum / count

print(f"PSNR: {avg_psnr:.4f}")
print(f"SSIM: {avg_ssim:.4f}")
print(f"LPIPS: {avg_lpips:.4f}")