fix wrong average of psnr
This commit is contained in:
21
main.py
21
main.py
@@ -2,16 +2,17 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torchvision.utils import make_grid
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from src.benchmark import benchmark
|
||||
from src.dataset.cuhk_cr1 import get_dataset
|
||||
from src.dataset.cuhk_cr2 import get_dataset
|
||||
from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
device = "cuda:2"
|
||||
device = "cuda:1"
|
||||
|
||||
model = UTransformer.from_pretrained_backbone(
|
||||
"facebook/dinov3-vitl16-pretrain-sat493m"
|
||||
@@ -21,7 +22,7 @@ optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
||||
|
||||
train_dataset, test_dataset = get_dataset()
|
||||
|
||||
wandb.init(project="cloud-removal-kmu", id="icy-field-12", resume="allow")
|
||||
wandb.init(project="cloud-removal-kmu", id="dashing-moon-31", resume="allow")
|
||||
|
||||
if not (wandb.run and wandb.run.name):
|
||||
raise Exception("nope")
|
||||
@@ -36,7 +37,7 @@ if os.path.exists(checkpoint_path):
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
start_epoch = checkpoint["epoch"] + 1
|
||||
|
||||
batch_size = 4
|
||||
batch_size = 8
|
||||
accumulation_steps = 8
|
||||
total_epoch = 1000
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
@@ -89,10 +90,20 @@ for epoch in range(start_epoch, total_epoch):
|
||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample(batch["cloud"].to(device))
|
||||
images = rf.sample_heun(batch["cloud"].to(device))
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
|
||||
if i == 0:
|
||||
for step, demo in enumerate([images[0], images[-1]]):
|
||||
images = wandb.Image(
|
||||
make_grid(
|
||||
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
|
||||
),
|
||||
caption=f"step {step}",
|
||||
)
|
||||
wandb.log({"viz/decoded": images})
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
|
||||
Reference in New Issue
Block a user