import os

import torch
import torch.optim as optim
from tqdm import tqdm

import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr1 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

device = "cuda:3"

model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vits16-pretrain-lvd1689m"
).to(device)
rf = RF(model)
optimizer = optim.AdamW(model.parameters(), lr=5e-4)

train_dataset, test_dataset = get_dataset()

wandb.init(project="cloud-removal-kmu")
if not (wandb.run and wandb.run.name):
    raise RuntimeError("wandb.init() did not produce a named run; the run name is needed for the artifact directory")
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)

batch_size = 32

for epoch in range(100):
    # Track the average loss in 10 buckets over the flow timestep t in [0, 1).
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}  # small epsilon avoids division by zero

    train_dataset = train_dataset.shuffle(seed=epoch)
    for i in tqdm(
        range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100"
    ):
        batch = train_dataset[i : i + batch_size]
        x0 = batch["x0"].to(device)  # cloudy input
        x1 = batch["x1"].to(device)  # clean target

        optimizer.zero_grad()
        loss, blsct = rf.forward(x0, x1)
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss.item()})

        # Accumulate per-sample losses into their timestep bins.
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

    epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)

    if (epoch + 1) % 10 == 0:
        # Benchmark on the test split every 10 epochs.
        rf.model.eval()
        psnr_sum = 0.0
        ssim_sum = 0.0
        lpips_sum = 0.0
        count = 0
        with torch.no_grad():
            for i in tqdm(
                range(0, len(test_dataset), batch_size),
                desc=f"Benchmark {epoch + 1}/100",
            ):
                batch = test_dataset[i : i + batch_size]
                images = rf.sample(batch["x0"].to(device))
                # The last element of the sampling trajectory is the final prediction.
                image = denormalize(images[-1]).clamp(0, 1) * 255
                original = denormalize(batch["x1"]).clamp(0, 1) * 255
                psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
                psnr_sum += psnr.sum().item()
                ssim_sum += ssim.sum().item()
                lpips_sum += lpips.sum().item()
                count += image.shape[0]

        avg_psnr = psnr_sum / count
        avg_ssim = ssim_sum / count
        avg_lpips = lpips_sum / count
        wandb.log(
            {
                "eval/psnr": avg_psnr,
                "eval/ssim": avg_ssim,
                "eval/lpips": avg_lpips,
                "epoch": epoch + 1,
            }
        )
        rf.model.train()

        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
        )

torch.save(
    {
        "epoch": 100,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    f"artifact/{wandb.run.name}/checkpoint_final.pt",
)
wandb.finish()
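

# ---------------------------------------------------------------------------
# Hedged reference sketch, not used by the script above: src/rf.py is not
# shown here, so the class below is an assumption about the interface the
# training loop relies on, namely that rf.forward(x0, x1) returns
# (mean loss, [(t, per-sample loss), ...]) and rf.sample(x0) returns the
# integration trajectory whose last element is the prediction. It illustrates
# one plausible rectified-flow wrapper for the paired cloudy-to-clean setting;
# the model signature model(x_t, t) and the step count are likewise assumed.
class RFSketch:
    def __init__(self, model, sample_steps=50):
        self.model = model
        self.sample_steps = sample_steps

    def forward(self, x0, x1):
        # x_t linearly interpolates the cloudy input x0 and the clean target x1.
        t = torch.rand(x0.shape[0], device=x0.device)
        xt = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * x1
        v_pred = self.model(xt, t)
        # The regression target is the constant velocity of the straight path.
        per_sample = ((v_pred - (x1 - x0)) ** 2).mean(dim=(1, 2, 3))
        return per_sample.mean(), list(zip(t.tolist(), per_sample.tolist()))

    @torch.no_grad()
    def sample(self, x0):
        # Euler integration from t=0 (cloudy) to t=1 (clean), keeping the trajectory.
        x, dt = x0, 1.0 / self.sample_steps
        trajectory = [x]
        for k in range(self.sample_steps):
            t = torch.full((x.shape[0],), k * dt, device=x.device)
            x = x + self.model(x, t) * dt
            trajectory.append(x)
        return trajectory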