Files
cloud-removal/quick_eval.py
2025-10-13 23:14:44 +09:00

89 lines
2.6 KiB
Python

import os
import torch
from PIL import Image
from torchvision.utils import save_image
from tqdm import tqdm
from src.benchmark import benchmark
from src.dataset.cuhk_cr1 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
# --- Run configuration ------------------------------------------------------
checkpoint_path = "artifact/firm-darkness-98/checkpoint_final.pt"
device = "cuda:1"
save_dir = "test_images"
# Output folder for qualitative examples; exist_ok makes reruns harmless.
os.makedirs(save_dir, exist_ok=True)

# --- Model: build the backbone, then restore the trained weights -------------
backbone_id = "facebook/dinov3-vitl16-pretrain-sat493m"
model = UTransformer.from_pretrained_backbone(backbone_id).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
rf = RF(model)
rf.model.eval()

# --- Data and running totals ---------------------------------------------------
_, test_dataset = get_dataset()  # only the second (test) split is used here
batch_size = 8 * 4
psnr_sum = ssim_sum = lpips_sum = 0
count = saved_count = 0
max_save = 10  # number of qualitative examples written to save_dir
def _to_pil_frame(img):
    """Convert a CHW tensor with values in [0, 1] into a PIL image.

    Quantizes exactly like ``torchvision.utils.save_image`` (scale by 255,
    add 0.5, clamp, cast to uint8). Pixels are therefore rounded rather than
    truncated and agree with the PNGs written by ``save_image``.
    """
    arr = (
        img.mul(255)
        .add(0.5)
        .clamp(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(arr)


def _save_example(idx, pred, gt, cloud, trajectory):
    """Write the prediction, ground truth, input and sampling GIF for one sample.

    ``pred``, ``gt`` and ``cloud`` are single images already mapped to [0, 1].
    ``trajectory`` is a non-empty list of PIL frames, one per entry returned
    by ``rf.sample``. Files are named with ``idx`` inside ``save_dir``.
    """
    save_image(pred, f"{save_dir}/pred_{idx}.png")
    save_image(gt, f"{save_dir}/gt_{idx}.png")
    save_image(cloud, f"{save_dir}/input_{idx}.png")
    trajectory[0].save(
        f"{save_dir}/transform_{idx}.gif",
        save_all=True,
        append_images=trajectory[1:],
        duration=100,
        loop=0,
    )


# Evaluate the test split in slices of batch_size. The first max_save samples
# are written as qualitative examples, and per-image metric sums are accumulated.
with torch.no_grad():
    for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
        batch = test_dataset[i : i + batch_size]
        # rf.sample returns a sequence of intermediate images, and the last one
        # is the prediction. NOTE(review): the second argument is presumably
        # the number of sampling steps; confirm against RF.sample.
        images = rf.sample(batch["cloud"].to(device), 1)
        image = denormalize(images[-1]).clamp(0, 1)
        original = denormalize(batch["gt"]).clamp(0, 1)

        # Once saved_count reaches max_save, the bound is 0 and nothing is saved.
        for j in range(min(image.shape[0], max_save - saved_count)):
            frames = [
                _to_pil_frame(denormalize(step_img[j]).clamp(0, 1))
                for step_img in images
            ]
            _save_example(
                saved_count,
                image[j],
                original[j],
                denormalize(batch["cloud"][j]).clamp(0, 1),
                frames,
            )
            saved_count += 1

        # NOTE(review): the .sum() calls below assume benchmark returns one
        # score per image. If it returns batch means, the final averages are
        # wrong. Confirm against src.benchmark.
        psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
        # tqdm.write prints the same text as print() without breaking the bar.
        tqdm.write(" ".join(str(metric) for metric in (psnr, ssim, lpips)))
        psnr_sum += psnr.sum().item()
        ssim_sum += ssim.sum().item()
        lpips_sum += lpips.sum().item()
        count += image.shape[0]
# Report the mean of each metric over every evaluated image.
if count == 0:
    # An empty test split would otherwise surface as a bare ZeroDivisionError.
    raise RuntimeError("No test images were evaluated; cannot compute averages.")
avg_psnr = psnr_sum / count
avg_ssim = ssim_sum / count
avg_lpips = lpips_sum / count
print(f"PSNR: {avg_psnr:.4f}")
print(f"SSIM: {avg_ssim:.4f}")
print(f"LPIPS: {avg_lpips:.4f}")