import os

import torch
import torch.optim as optim
from tqdm import tqdm

import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr1 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

device = "cuda:0"

# Build the UTransformer on top of the pretrained DINOv3 satellite backbone
# and wrap it in the RF training helper.
model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vitl16-pretrain-sat493m"
).to(device)
rf = RF(model)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

train_dataset, test_dataset = get_dataset()

# Resume the existing wandb run if it is still available, otherwise start a new one.
wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
if not (wandb.run and wandb.run.name):
    raise RuntimeError("wandb did not initialize a named run; cannot create the artifact directory")
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)

# Resume training from the latest checkpoint if one exists.
start_epoch = 0
checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

batch_size = 4
accumulation_steps = 4  # effective batch size = batch_size * accumulation_steps
total_epoch = 1000

for epoch in range(start_epoch, total_epoch):
    # Per-timestep loss accumulators: 10 bins over t in [0, 1).
    # losscnt starts at 1e-6 to avoid division by zero for empty bins.
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}

    train_dataset = train_dataset.shuffle(seed=epoch)
    num_batches = len(range(0, len(train_dataset), batch_size))

    for i in tqdm(
        range(0, len(train_dataset), batch_size),
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    ):
        batch = train_dataset[i : i + batch_size]
        x0 = batch["x0"].to(device)  # source image (cloudy)
        x1 = batch["x1"].to(device)  # target image (cloud-free)
        loss, blsct = rf.forward(x0, x1)

        # Gradient accumulation: scale the loss and only step the optimizer
        # every `accumulation_steps` batches.
        loss = loss / accumulation_steps
        loss.backward()
        if (i // batch_size + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        wandb.log({"train/loss": loss.item() * accumulation_steps})

        # Accumulate per-sample losses into their timestep bins.
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

    # Flush gradients left over from an incomplete final accumulation group.
    if num_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)

    # Evaluate on the test split every 10 epochs.
    if (epoch + 1) % 10 == 0:
        rf.model.eval()
        psnr_sum = 0
        ssim_sum = 0
        lpips_sum = 0
        count = 0
        with torch.no_grad():
            for i in tqdm(
                range(0, len(test_dataset), batch_size),
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                batch = test_dataset[i : i + batch_size]
                images = rf.sample(batch["x0"].to(device))
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["x1"]).clamp(0, 1)
                psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
                psnr_sum += psnr.sum().item()
                ssim_sum += ssim.sum().item()
                lpips_sum += lpips.sum().item()
                count += image.shape[0]

        avg_psnr = psnr_sum / count
        avg_ssim = ssim_sum / count
        avg_lpips = lpips_sum / count
        wandb.log(
            {
                "eval/psnr": avg_psnr,
                "eval/ssim": avg_ssim,
                "eval/lpips": avg_lpips,
                "epoch": epoch + 1,
            }
        )
        rf.model.train()

    # Save a per-epoch snapshot and refresh the resume checkpoint.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
    )
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        checkpoint_path,
    )

# Final checkpoint once the training loop completes.
torch.save(
    {
        "epoch": epoch,  # type: ignore
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    checkpoint_path,
)
wandb.finish()