import math
import os

import lovely_tensors as lt
import torch
import torch.optim as optim
from torch.amp import autocast
from torchvision.utils import make_grid
from tqdm import tqdm

import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

lt.monkey_patch()

train_dataset, test_dataset = get_dataset()

device = "cuda:1"
batch_size = 16
accumulation_steps = 2
total_epoch = 500
# The inner loop yields a final partial batch, so the micro-batch count is a ceil.
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
# scheduler.step() runs once per optimizer step (every accumulation_steps
# micro-batches), so the schedule length is measured in optimizer steps.
total_steps = math.ceil(steps_per_epoch / accumulation_steps) * total_epoch
warmup_steps = int(0.05 * total_steps)
grad_norm = 1.0

model = (
    UTransformer.from_pretrained_backbone("facebook/dinov3-vitl16-pretrain-sat493m")
    .to(device)
    .bfloat16()
)
rf = RF(model, "icfm", "lpips_mse")
optimizer = optim.AdamW(model.parameters(), lr=3e-4)


# Linear warmup followed by cosine decay, as a multiplicative LR factor.
def get_lr(step: int) -> float:
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))


scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)

wandb.init(project="cloud-removal-kmu", resume="allow")

# phase 2
# model.requires_grad_(True)

if not (wandb.run and wandb.run.name):
    raise RuntimeError("wandb run has no name; cannot create the artifact directory")

os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)

start_epoch = 0
checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

for epoch in range(start_epoch, total_epoch):
    # Per-timestep loss histogram: losses are binned by the flow-matching time t.
    lossbin = {i: 0.0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}
    train_dataset = train_dataset.shuffle(seed=epoch)
    for i in tqdm(
        range(0, len(train_dataset), batch_size),
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    ):
        batch = train_dataset[i : i + batch_size]
        cloud = batch["cloud"].to(device)
        gt = batch["gt"].to(device)

        with autocast("cuda", dtype=torch.bfloat16):
            loss, blsct, loss_list = rf.forward(gt, cloud)
        loss = loss / accumulation_steps
        loss.backward()

        if (i // batch_size + 1) % accumulation_steps == 0:
            # total_norm = torch.nn.utils.clip_grad_norm_(
            #     model.parameters(), max_norm=grad_norm
            # )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            # wandb.log(
            #     {
            #         "train/grad_norm": total_norm.item(),
            #     }
            # )

        wandb.log(
            {
                "train/loss": loss.item() * accumulation_steps,
                "train/lr": scheduler.get_last_lr()[0],
            }
        )
        wandb.log(loss_list)

        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

    # Flush leftover gradients when the number of micro-batches in the epoch
    # is not divisible by accumulation_steps.
    if steps_per_epoch % accumulation_steps != 0:
        # total_norm = torch.nn.utils.clip_grad_norm_(
        #     model.parameters(), max_norm=grad_norm
        # )
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # wandb.log(
        #     {
        #         "train/grad_norm": total_norm.item(),
        #     }
        # )

    epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)

    if (epoch + 1) % 50 == 0:
        rf.model.eval()
        psnr_sum = 0.0
        ssim_sum = 0.0
        lpips_sum = 0.0
        count = 0
        with torch.no_grad():
            for i in tqdm(
                range(0, len(test_dataset), batch_size),
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                batch = test_dataset[i : i + batch_size]
                images = rf.sample(batch["cloud"].to(device))
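                # rf.sample is assumed to return the list of intermediate
                # sampling states; images[-1] is taken below as the final
                # restored image.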
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["gt"]).clamp(0, 1)

                if i == 0:
                    # Log the first and last sampling states of the first
                    # test batch (renamed from `images` to avoid shadowing
                    # the sampled trajectory).
                    for step, demo in enumerate([images[0], images[-1]]):
                        grid = wandb.Image(
                            make_grid(
                                denormalize(demo).clamp(0, 1).float()[:4], nrow=2
                            ),
                            caption=f"step {step}",
                        )
                        wandb.log({"viz/decoded": grid})

                psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
                psnr_sum += psnr.sum().item()
                ssim_sum += ssim.sum().item()
                lpips_sum += lpips.sum().item()
                count += image.shape[0]

        avg_psnr = psnr_sum / count
        avg_ssim = ssim_sum / count
        avg_lpips = lpips_sum / count

        wandb.log(
            {
                "eval/psnr": avg_psnr,
                "eval/ssim": avg_ssim,
                "eval/lpips": avg_lpips,
                "epoch": epoch + 1,
            }
        )
        rf.model.train()

        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            },
            f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
        )
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            },
            checkpoint_path,
        )

torch.save(
    {
        "epoch": epoch,  # type: ignore
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    },
    f"artifact/{wandb.run.name}/checkpoint_final.pt",
)

wandb.finish()
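
# A minimal inference sketch, assuming the save format above and the same
# model/RF construction; `cloudy` is a hypothetical normalized cloudy-image
# batch, not a variable defined in this script:
#
#     checkpoint = torch.load(checkpoint_path, map_location=device)
#     model.load_state_dict(checkpoint["model_state_dict"])
#     rf.model.eval()
#     with torch.no_grad():
#         restored = denormalize(rf.sample(cloudy.to(device))[-1]).clamp(0, 1)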