test
This commit is contained in:
69
quick_eval.py
Normal file
69
quick_eval.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.benchmark import benchmark
|
||||
from src.dataset.cuhk_cr1 import get_dataset
|
||||
from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
checkpoint_path = "artifact/wild-wave-3/checkpoint_epoch_100.pt"
|
||||
device = "cuda:0"
|
||||
save_dir = "test_images"
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
model = UTransformer.from_pretrained_backbone(
|
||||
"facebook/dinov3-vits16-pretrain-lvd1689m"
|
||||
).to(device)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
rf = RF(model)
|
||||
rf.model.eval()
|
||||
|
||||
_, test_dataset = get_dataset()
|
||||
|
||||
batch_size = 32
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
count = 0
|
||||
saved_count = 0
|
||||
max_save = 10
|
||||
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample(batch["x0"].to(device))
|
||||
|
||||
image = denormalize(images[-1]).clamp(0, 1) * 255
|
||||
original = denormalize(batch["x1"]).clamp(0, 1) * 255
|
||||
|
||||
if saved_count < max_save:
|
||||
for j in range(min(image.shape[0], max_save - saved_count)):
|
||||
save_image(image[j] / 255, f"{save_dir}/pred_{saved_count}.png")
|
||||
save_image(original[j] / 255, f"{save_dir}/gt_{saved_count}.png")
|
||||
save_image(
|
||||
denormalize(batch["x0"][j]).clamp(0, 1),
|
||||
f"{save_dir}/input_{saved_count}.png",
|
||||
)
|
||||
saved_count += 1
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
count += image.shape[0]
|
||||
|
||||
avg_psnr = psnr_sum / count
|
||||
avg_ssim = ssim_sum / count
|
||||
avg_lpips = lpips_sum / count
|
||||
|
||||
print(f"PSNR: {avg_psnr:.4f}")
|
||||
print(f"SSIM: {avg_ssim:.4f}")
|
||||
print(f"LPIPS: {avg_lpips:.4f}")
|
||||
Reference in New Issue
Block a user