# Source snapshot: cloud-removal/main.py (2025-11-02 22:55:07 +09:00; 216 lines, 6.5 KiB, Python)

import math
import os
import lovely_tensors as lt
import torch
import torch.optim as optim
from torch.cuda.amp import autocast
from torchvision.utils import make_grid
from tqdm import tqdm
import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
lt.monkey_patch()
train_dataset, test_dataset = get_dataset()

# --- Training configuration -------------------------------------------------
device = "cuda:1"
batch_size = 32
accumulation_steps = 2  # micro-batches per optimizer step
total_epoch = 500
# The LR scheduler is advanced once per *optimizer* step, so the schedule
# length must be counted in optimizer steps: ceil(num_batches /
# accumulation_steps) per epoch, where both the last batch and the last
# accumulation group of an epoch may be partial. Counting micro-batches here
# would stretch the schedule so the cosine never reaches zero when
# accumulation_steps > 1.
num_batches = math.ceil(len(train_dataset) / batch_size)
steps_per_epoch = math.ceil(num_batches / accumulation_steps)
total_steps = steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
grad_norm = 1.0  # max-norm for gradient clipping (clipping is currently disabled)

# NOTE(review): weights are cast to bf16 before the optimizer is built, so
# AdamW's moment buffers are bf16 too; small updates may be lost without fp32
# master weights — confirm this is intended.
model = (
    UTransformer.from_pretrained_backbone("facebook/dinov3-vitl16-pretrain-sat493m")
    .to(device)
    .bfloat16()
)
rf = RF(model, "icfm", "lpips_mse")  # presumably (flow variant, loss name) — see src/rf.py
optimizer = optim.AdamW(model.parameters(), lr=3e-4)  # peak LR, scaled by get_lr()
# LR schedule: linear warmup, then cosine decay to zero.
def get_lr(step: int, warmup: "int | None" = None, total: "int | None" = None) -> float:
    """Return the LR multiplier for scheduler step ``step`` (used by LambdaLR).

    Linear warmup from 0 to 1 over ``warmup`` steps, then cosine decay from 1
    down to 0 at ``total`` steps. Past ``total`` the multiplier stays at 0
    instead of climbing back up along the cosine.

    Args:
        step: Number of scheduler steps taken so far.
        warmup: Warmup length; defaults to the module-level ``warmup_steps``.
        total: Schedule length; defaults to the module-level ``total_steps``.
    """
    if warmup is None:
        warmup = warmup_steps
    if total is None:
        total = total_steps
    if step < warmup:
        return step / warmup
    # Guard against total == warmup (zero-length decay phase).
    decay_span = max(total - warmup, 1)
    # Clamp so steps beyond `total` do not make the cosine rise again.
    progress = min((step - warmup) / decay_span, 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)

# NOTE(review): with resume="allow" and no explicit run id, wandb presumably
# starts a fresh run (with a new name) unless WANDB_RUN_ID is set, in which
# case the checkpoint lookup below finds nothing. Confirm how runs are resumed.
wandb.init(project="cloud-removal-kmu", resume="allow")
# phase 2
# model.requires_grad_(True)
if not (wandb.run and wandb.run.name):
    # The run name keys the artifact directory; we cannot continue without it.
    raise RuntimeError(
        "wandb.init() did not produce a named run; cannot choose an artifact directory"
    )
run_dir = f"artifact/{wandb.run.name}"
os.makedirs(run_dir, exist_ok=True)

# Resume from this run's rolling checkpoint if it exists.
start_epoch = 0
checkpoint_path = f"{run_dir}/checkpoint_final.pt"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Checkpoints without scheduler state (presumably older ones) leave the
    # freshly constructed scheduler at step 0, i.e. the schedule restarts.
    if "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # The stored epoch is the last *completed* (0-based) epoch.
    start_epoch = checkpoint["epoch"] + 1
    # Drop the loaded dict so its on-device tensors are not kept alive for the
    # whole training run (model.load_state_dict copies the values in).
    del checkpoint
eval_every = 50  # epochs between evaluation + checkpointing


def bf16_autocast():
    """Return a bf16 autocast context for the device type of ``device``."""
    # torch.autocast(device_type="cuda", ...) is the non-deprecated spelling
    # of torch.cuda.amp.autocast(...).
    return torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16)


def optimizer_step() -> None:
    """Apply the accumulated gradients, advance the LR schedule, clear grads."""
    # Gradient clipping is disabled. To enable it, insert before step():
    #   total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_norm)
    #   wandb.log({"train/grad_norm": total_norm.item()})
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def save_checkpoint(path: str, epoch: int) -> None:
    """Save model/optimizer/scheduler state to ``path``.

    ``epoch`` is the 0-based index of the last completed epoch; the resume
    logic restarts training at ``epoch + 1``.
    """
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        },
        path,
    )


def evaluate(epoch: int) -> None:
    """Sample the whole test set and log mean PSNR/SSIM/LPIPS to wandb.

    For the first test batch, also logs a 2x2 preview grid of the first and
    last entries returned by ``rf.sample``.
    """
    rf.model.eval()
    psnr_sum = ssim_sum = lpips_sum = 0.0
    count = 0
    with torch.no_grad(), bf16_autocast():
        for start in tqdm(
            range(0, len(test_dataset), batch_size),
            desc=f"Benchmark {epoch + 1}/{total_epoch}",
        ):
            batch = test_dataset[start : start + batch_size]
            trajectory = rf.sample(batch["cloud"].to(device))
            # The model weights are bf16, so the samples may be too; score in
            # float32 to avoid bf16's ~3 significant digits skewing metrics.
            restored = denormalize(trajectory[-1]).clamp(0, 1).float()
            original = denormalize(batch["gt"]).clamp(0, 1).float()
            if start == 0:
                for step, demo in enumerate([trajectory[0], trajectory[-1]]):
                    grid = wandb.Image(
                        make_grid(denormalize(demo).clamp(0, 1).float()[:4], nrow=2),
                        caption=f"step {step}",
                    )
                    wandb.log({"viz/decoded": grid})
            psnr, ssim, lpips, _flawed_lpips = benchmark(restored.cpu(), original.cpu())
            psnr_sum += psnr.sum().item()
            ssim_sum += ssim.sum().item()
            lpips_sum += lpips.sum().item()
            count += restored.shape[0]
    rf.model.train()
    if count == 0:
        return  # empty test set: nothing to average
    wandb.log(
        {
            "eval/psnr": psnr_sum / count,
            "eval/ssim": ssim_sum / count,
            "eval/lpips": lpips_sum / count,
            "epoch": epoch + 1,
        }
    )


num_batches = len(range(0, len(train_dataset), batch_size))
base_train_dataset = train_dataset
last_epoch = None  # last epoch completed in this process; None if none ran
for epoch in range(start_epoch, total_epoch):
    lossbin = {b: 0 for b in range(10)}
    losscnt = {b: 1e-6 for b in range(10)}  # epsilon avoids 0/0 for empty bins
    # Shuffle the original dataset (not the previous epoch's shuffled view) so
    # the order of epoch k depends only on seed=k and is identical after resume.
    epoch_dataset = base_train_dataset.shuffle(seed=epoch)
    for batch_idx in tqdm(range(num_batches), desc=f"Epoch {epoch + 1}/{total_epoch}"):
        start = batch_idx * batch_size
        batch = epoch_dataset[start : start + batch_size]
        cloud = batch["cloud"].to(device)
        gt = batch["gt"].to(device)
        with bf16_autocast():
            loss, blsct, loss_list = rf.forward(gt, cloud)
        # Average gradients over the micro-batches of this accumulation group;
        # the last group of an epoch may hold fewer than accumulation_steps.
        group_start = batch_idx - batch_idx % accumulation_steps
        group_size = min(accumulation_steps, num_batches - group_start)
        (loss / group_size).backward()
        if (batch_idx + 1) % accumulation_steps == 0 or batch_idx + 1 == num_batches:
            optimizer_step()
            wandb.log(
                {
                    "train/loss": loss.item(),
                    "train/lr": scheduler.get_last_lr()[0],
                }
            )
        wandb.log(loss_list)
        # Bucket the (t, loss) pairs from rf.forward into 10 bins by
        # int(t * 10), clamped to bin 9 — assumes t lies in [0, 1].
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1
    epoch_metrics = {f"lossbin/lossbin_{b}": lossbin[b] / losscnt[b] for b in range(10)}
    # NOTE(review): this logs the 0-based epoch index while evaluate() logs
    # epoch + 1 under the same "epoch" key.
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)
    if (epoch + 1) % eval_every == 0:
        evaluate(epoch)
        save_checkpoint(f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt", epoch)
        save_checkpoint(checkpoint_path, epoch)
    last_epoch = epoch

# Persist the final state unless the last completed epoch was just saved above.
# If no epoch ran (e.g. resuming an already finished run) there is nothing new
# to save, and the loop variable `epoch` would not exist.
if last_epoch is not None and (last_epoch + 1) % eval_every != 0:
    save_checkpoint(checkpoint_path, last_epoch)
wandb.finish()