115 lines
3.2 KiB
Python
115 lines
3.2 KiB
Python
import os
|
|
|
|
import torch
|
|
import torch.optim as optim
|
|
from tqdm import tqdm
|
|
|
|
import wandb
|
|
from src.benchmark import benchmark
|
|
from src.dataset.cuhk_cr1 import get_dataset
|
|
from src.dataset.preprocess import denormalize
|
|
from src.model.utransformer import UTransformer
|
|
from src.rf import RF
|
|
|
|
device = "cuda:3"
|
|
|
|
model = UTransformer.from_pretrained_backbone(
|
|
"facebook/dinov3-vits16-pretrain-lvd1689m"
|
|
).to(device)
|
|
rf = RF(model)
|
|
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
|
|
|
|
train_dataset, test_dataset = get_dataset()
|
|
|
|
wandb.init(project="cloud-removal-kmu")
|
|
|
|
if not (wandb.run and wandb.run.name):
|
|
raise Exception("nope")
|
|
|
|
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
|
|
|
|
batch_size = 32
|
|
for epoch in range(100):
|
|
lossbin = {i: 0 for i in range(10)}
|
|
losscnt = {i: 1e-6 for i in range(10)}
|
|
|
|
train_dataset = train_dataset.shuffle(seed=epoch)
|
|
|
|
for i in tqdm(
|
|
range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100"
|
|
):
|
|
batch = train_dataset[i : i + batch_size]
|
|
x0 = batch["x0"].to(device)
|
|
x1 = batch["x1"].to(device)
|
|
|
|
optimizer.zero_grad()
|
|
loss, blsct = rf.forward(x0, x1)
|
|
loss.backward()
|
|
optimizer.step()
|
|
wandb.log({"loss": loss.item()})
|
|
|
|
for t, lss in blsct:
|
|
bin_idx = min(int(t * 10), 9)
|
|
lossbin[bin_idx] += lss
|
|
losscnt[bin_idx] += 1
|
|
|
|
epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
|
|
epoch_metrics["epoch"] = epoch
|
|
wandb.log(epoch_metrics)
|
|
|
|
if (epoch + 1) % 10 == 0:
|
|
# bench
|
|
rf.model.eval()
|
|
psnr_sum = 0
|
|
ssim_sum = 0
|
|
lpips_sum = 0
|
|
count = 0
|
|
|
|
with torch.no_grad():
|
|
for i in tqdm(
|
|
range(0, len(test_dataset), batch_size),
|
|
desc=f"Benchmark {epoch + 1}/100",
|
|
):
|
|
batch = test_dataset[i : i + batch_size]
|
|
images = rf.sample(batch["x0"].to(device))
|
|
image = denormalize(images[-1]).clamp(0, 1)
|
|
original = denormalize(batch["x1"]).clamp(0, 1)
|
|
|
|
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
|
psnr_sum += psnr.sum().item()
|
|
ssim_sum += ssim.sum().item()
|
|
lpips_sum += lpips.sum().item()
|
|
count += image.shape[0]
|
|
|
|
avg_psnr = psnr_sum / count
|
|
avg_ssim = ssim_sum / count
|
|
avg_lpips = lpips_sum / count
|
|
wandb.log(
|
|
{
|
|
"eval/psnr": avg_psnr,
|
|
"eval/ssim": avg_ssim,
|
|
"eval/lpips": avg_lpips,
|
|
"epoch": epoch + 1,
|
|
}
|
|
)
|
|
rf.model.train()
|
|
|
|
torch.save(
|
|
{
|
|
"epoch": epoch,
|
|
"model_state_dict": model.state_dict(),
|
|
"optimizer_state_dict": optimizer.state_dict(),
|
|
},
|
|
f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
|
|
)
|
|
|
|
torch.save(
|
|
{
|
|
"epoch": 100,
|
|
"model_state_dict": model.state_dict(),
|
|
"optimizer_state_dict": optimizer.state_dict(),
|
|
},
|
|
f"artifact/{wandb.run.name}/checkpoint_final.pt",
|
|
)
|
|
wandb.finish()
|