upscale on middle
This commit is contained in:
81
main.py
81
main.py
@@ -137,50 +137,51 @@ for epoch in range(start_epoch, total_epoch):
|
||||
wandb.log(epoch_metrics)
|
||||
|
||||
if (epoch + 1) % 50 == 0:
|
||||
rf.model.eval()
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
count = 0
|
||||
with autocast(dtype=torch.bfloat16):
|
||||
rf.model.eval()
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
count = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for i in tqdm(
|
||||
range(0, len(test_dataset), batch_size),
|
||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample(batch["cloud"].to(device))
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
with torch.no_grad():
|
||||
for i in tqdm(
|
||||
range(0, len(test_dataset), batch_size),
|
||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample(batch["cloud"].to(device))
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
|
||||
if i == 0:
|
||||
for step, demo in enumerate([images[0], images[-1]]):
|
||||
images = wandb.Image(
|
||||
make_grid(
|
||||
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
|
||||
),
|
||||
caption=f"step {step}",
|
||||
)
|
||||
wandb.log({"viz/decoded": images})
|
||||
if i == 0:
|
||||
for step, demo in enumerate([images[0], images[-1]]):
|
||||
images = wandb.Image(
|
||||
make_grid(
|
||||
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
|
||||
),
|
||||
caption=f"step {step}",
|
||||
)
|
||||
wandb.log({"viz/decoded": images})
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
count += image.shape[0]
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
count += image.shape[0]
|
||||
|
||||
avg_psnr = psnr_sum / count
|
||||
avg_ssim = ssim_sum / count
|
||||
avg_lpips = lpips_sum / count
|
||||
wandb.log(
|
||||
{
|
||||
"eval/psnr": avg_psnr,
|
||||
"eval/ssim": avg_ssim,
|
||||
"eval/lpips": avg_lpips,
|
||||
"epoch": epoch + 1,
|
||||
}
|
||||
)
|
||||
rf.model.train()
|
||||
avg_psnr = psnr_sum / count
|
||||
avg_ssim = ssim_sum / count
|
||||
avg_lpips = lpips_sum / count
|
||||
wandb.log(
|
||||
{
|
||||
"eval/psnr": avg_psnr,
|
||||
"eval/ssim": avg_ssim,
|
||||
"eval/lpips": avg_lpips,
|
||||
"epoch": epoch + 1,
|
||||
}
|
||||
)
|
||||
rf.model.train()
|
||||
|
||||
torch.save(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user