upscale on middle

This commit is contained in:
neulus
2025-10-26 15:42:35 +09:00
parent 6ab33ceb83
commit 09c1c2220a
2 changed files with 83 additions and 72 deletions

81
main.py
View File

@@ -137,50 +137,51 @@ for epoch in range(start_epoch, total_epoch):
wandb.log(epoch_metrics)
if (epoch + 1) % 50 == 0:
rf.model.eval()
psnr_sum = 0
ssim_sum = 0
lpips_sum = 0
count = 0
with autocast(dtype=torch.bfloat16):
rf.model.eval()
psnr_sum = 0
ssim_sum = 0
lpips_sum = 0
count = 0
with torch.no_grad():
for i in tqdm(
range(0, len(test_dataset), batch_size),
desc=f"Benchmark {epoch + 1}/{total_epoch}",
):
batch = test_dataset[i : i + batch_size]
images = rf.sample(batch["cloud"].to(device))
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
with torch.no_grad():
for i in tqdm(
range(0, len(test_dataset), batch_size),
desc=f"Benchmark {epoch + 1}/{total_epoch}",
):
batch = test_dataset[i : i + batch_size]
images = rf.sample(batch["cloud"].to(device))
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
if i == 0:
for step, demo in enumerate([images[0], images[-1]]):
images = wandb.Image(
make_grid(
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
),
caption=f"step {step}",
)
wandb.log({"viz/decoded": images})
if i == 0:
for step, demo in enumerate([images[0], images[-1]]):
images = wandb.Image(
make_grid(
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
),
caption=f"step {step}",
)
wandb.log({"viz/decoded": images})
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()
count += image.shape[0]
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()
count += image.shape[0]
avg_psnr = psnr_sum / count
avg_ssim = ssim_sum / count
avg_lpips = lpips_sum / count
wandb.log(
{
"eval/psnr": avg_psnr,
"eval/ssim": avg_ssim,
"eval/lpips": avg_lpips,
"epoch": epoch + 1,
}
)
rf.model.train()
avg_psnr = psnr_sum / count
avg_ssim = ssim_sum / count
avg_lpips = lpips_sum / count
wandb.log(
{
"eval/psnr": avg_psnr,
"eval/ssim": avg_ssim,
"eval/lpips": avg_lpips,
"epoch": epoch + 1,
}
)
rf.model.train()
torch.save(
{