things
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -10,7 +11,7 @@ from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt"
|
||||
checkpoint_path = "artifact/firm-darkness-98/checkpoint_final.pt"
|
||||
device = "cuda:1"
|
||||
save_dir = "test_images"
|
||||
|
||||
@@ -28,7 +29,7 @@ rf.model.eval()
|
||||
|
||||
_, test_dataset = get_dataset()
|
||||
|
||||
batch_size = 8
|
||||
batch_size = 8 * 4
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
@@ -39,7 +40,7 @@ max_save = 10
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample_heun(batch["cloud"].to(device), 1)
|
||||
images = rf.sample(batch["cloud"].to(device), 1)
|
||||
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
@@ -52,6 +53,23 @@ with torch.no_grad():
|
||||
denormalize(batch["cloud"][j]).clamp(0, 1),
|
||||
f"{save_dir}/input_{saved_count}.png",
|
||||
)
|
||||
|
||||
frames = []
|
||||
for step_img in images:
|
||||
frame = denormalize(step_img[j]).clamp(0, 1)
|
||||
frame_np = (frame.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
||||
"uint8"
|
||||
)
|
||||
frames.append(Image.fromarray(frame_np))
|
||||
|
||||
frames[0].save(
|
||||
f"{save_dir}/transform_{saved_count}.gif",
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=100,
|
||||
loop=0,
|
||||
)
|
||||
|
||||
saved_count += 1
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
|
||||
Reference in New Issue
Block a user