fix wrong average of psnr

This commit is contained in:
neulus
2025-10-02 19:40:00 +09:00
parent a601dc6095
commit 6bb6c09638
17 changed files with 221 additions and 39 deletions

View File

@@ -10,8 +10,8 @@ from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
checkpoint_path = "artifact/icy-field-12/checkpoint_epoch_260.pt"
device = "cuda:2"
checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt"
device = "cuda:1"
save_dir = "test_images"
os.makedirs(save_dir, exist_ok=True)
@@ -28,7 +28,7 @@ rf.model.eval()
_, test_dataset = get_dataset()
batch_size = 1
batch_size = 8
psnr_sum = 0
ssim_sum = 0
lpips_sum = 0
@@ -39,7 +39,7 @@ max_save = 10
with torch.no_grad():
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
batch = test_dataset[i : i + batch_size]
images = rf.sample(batch["cloud"].to(device))
images = rf.sample_heun(batch["cloud"].to(device), 1)
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
@@ -49,12 +49,13 @@ with torch.no_grad():
save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
save_image(
denormalize(batch["x0"][j]).clamp(0, 1),
denormalize(batch["cloud"][j]).clamp(0, 1),
f"{save_dir}/input_{saved_count}.png",
)
saved_count += 1
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
print(psnr, ssim, lpips)
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()