import math
import os
from datetime import timedelta

import lpips
import torch
import torch.optim as optim
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm

import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

# --- Configuration ---
batch_size = 16
accumulation_steps = 2
total_epoch = 500
grad_norm = 1.0
learning_rate = 3e-4

# --- Accelerator Setup ---
# Set a longer timeout for initialization, which can be useful when downloading
# large models or datasets on multiple nodes.
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerator = Accelerator(
    gradient_accumulation_steps=accumulation_steps,
    mixed_precision="bf16",  # Use "bf16" on modern GPUs, or "fp16" otherwise
    log_with="wandb",
    kwargs_handlers=[kwargs],
)

# --- Dataset Loading ---
# Load datasets on the main process. They will be accessible by all processes.
train_dataset, test_dataset = get_dataset()
train_dataset, test_dataset = (
    DataLoader(train_dataset, batch_size=batch_size),
    DataLoader(test_dataset, batch_size=batch_size),
)

# --- LR Scheduler Logic ---
# Calculate total steps based on optimizer updates, not micro-batches.
# `train_dataset` is already a DataLoader here, so its length is the number of
# batches per epoch; math.ceil accounts for the last partial accumulation window.
num_batches_per_epoch = len(train_dataset)
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / accumulation_steps)
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)  # Note: not consumed by the cosine schedule below

# --- Model, Optimizer, and RF Helper Initialization ---
# Initialize on CPU. Accelerator will move them to the correct device.
model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vitl16-pretrain-sat493m"
).bfloat16()
lp = lpips.LPIPS(net="vgg")
rf = RF(model, "icfm", "lpips_mse", lp)  # RF holds a reference to the model
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epoch)

# --- Prepare objects with Accelerator ---
# Prepare the model, LPIPS network, optimizer, scheduler, and both dataloaders
# so Accelerator handles device placement, DDP wrapping, and mixed precision.
model, lp, optimizer, scheduler, train_dataset, test_dataset = accelerator.prepare(
    model, lp, optimizer, scheduler, train_dataset, test_dataset
)

# --- W&B and Checkpoint Setup ---
# Initialize tracker (wandb) on the main process
accelerator.init_trackers(
    project_name="cloud-removal-kmu",
    config={
        "batch_size": batch_size,
        "accumulation_steps": accumulation_steps,
        "total_epoch": total_epoch,
        "learning_rate": learning_rate,
        "grad_norm": grad_norm,
        "total_steps": total_steps,
    },
)

# Use the run name from the tracker for a consistent artifact path.
# The fallback value is needed in case no trackers are configured.
run_name = "nerf-3"
if accelerator.trackers:
    run_name = accelerator.trackers[0].run.name

artifact_dir = f"artifact/{run_name}"
checkpoint_dir = os.path.join(artifact_dir, "checkpoints")
if accelerator.is_main_process:
    os.makedirs(checkpoint_dir, exist_ok=True)
# Ensure the directory is created before any process tries to access it.
accelerator.wait_for_everyone()

# Register scheduler for checkpointing
accelerator.register_for_checkpointing(scheduler)

start_epoch = 0
# Resume only if the checkpoint directory already holds a saved state
# (the directory itself is always created above, so check that it is non-empty).
if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
    try:
        accelerator.print(f"Resuming from checkpoint: {checkpoint_dir}")
        accelerator.load_state(checkpoint_dir)
        # Manually load the epoch from a tracker file
        epoch_tracker_path = os.path.join(checkpoint_dir, "epoch_tracker.pt")
        if os.path.exists(epoch_tracker_path):
            start_epoch = torch.load(epoch_tracker_path)["epoch"] + 1
    except Exception as e:
        accelerator.print(
            f"Could not load checkpoint. Starting from scratch. Error: {e}"
        )

# --- Training Loop ---
for epoch in range(start_epoch, total_epoch):
    model.train()
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}  # Small epsilon avoids division by zero

    progress_bar = tqdm(
        train_dataset,
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    )

    for step, batch in enumerate(progress_bar):
        cloud, gt = batch["cloud"], batch["gt"]

        with accelerator.accumulate(model):
            # Forward pass is automatically handled with mixed precision
            loss, blsct, loss_list = rf.forward(gt, cloud)
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), grad_norm)

            optimizer.step()
            optimizer.zero_grad()

        # Log metrics on optimizer updates only (i.e. when gradients are synced).
        if accelerator.sync_gradients:
            avg_loss = accelerator.gather(loss).mean().item()
            current_step = epoch * optimizer_steps_per_epoch + (
                step // accumulation_steps
            )
            accelerator.log(
                {
                    "train/loss": avg_loss,
                    "train/lr": scheduler.get_last_lr()[0],
                },
                step=current_step,
            )
            accelerator.log(loss_list, step=current_step)

        # This per-process accumulation is an approximation. For perfect accuracy,
        # `blsct` would need to be gathered from all processes.
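        # `blsct` is assumed to be an iterable of (timestep, per-sample loss)
        # pairs returned by rf.forward, with timesteps in [0, 1); the loop below
        # accumulates those losses into ten timestep bins for epoch-level logging.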
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

    # Log epoch-level metrics from the main process
    if accelerator.is_main_process:
        epoch_metrics = {
            f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)
        }
        epoch_metrics["epoch"] = epoch
        accelerator.log(epoch_metrics)

    # --- Evaluation and Checkpointing ---
    if (epoch + 1) % 50 == 0:
        model.eval()
        psnr_sum, ssim_sum, lpips_sum, count = 0.0, 0.0, 0.0, 0

        with torch.no_grad():
            for i, batch in tqdm(
                enumerate(test_dataset),
                disable=not accelerator.is_local_main_process,
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                images = rf.sample(batch["cloud"])
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["gt"]).clamp(0, 1)

                # Gather results from all processes for accurate metrics
                image_gathered = accelerator.gather_for_metrics(image)
                original_gathered = accelerator.gather_for_metrics(original)

                if accelerator.is_main_process:
                    # Log visualization images from the first batch on the main process
                    if i == 0:
                        demo_images = [images[0][:4], images[-1][:4]]
                        for step_idx, demo in enumerate(demo_images):
                            grid = make_grid(
                                denormalize(demo).clamp(0, 1).float().cpu(), nrow=2
                            )
                            wandb_image = wandb.Image(grid, caption=f"step {step_idx}")
                            accelerator.log({"viz/decoded": wandb_image})

                    # Avoid shadowing the `lpips` module with the metric tensor.
                    psnr, ssim, lpips_val, flawed_lpips = benchmark(
                        image_gathered.cpu(), original_gathered.cpu()
                    )
                    psnr_sum += psnr.sum().item()
                    ssim_sum += ssim.sum().item()
                    lpips_sum += lpips_val.sum().item()
                    count += image_gathered.shape[0]

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            avg_psnr = psnr_sum / count if count > 0 else 0
            avg_ssim = ssim_sum / count if count > 0 else 0
            avg_lpips = lpips_sum / count if count > 0 else 0
            accelerator.log(
                {
                    "eval/psnr": avg_psnr,
                    "eval/ssim": avg_ssim,
                    "eval/lpips": avg_lpips,
                    "epoch": epoch + 1,
                }
            )

        # save_state is a collective call and should run on every process.
        accelerator.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch + 1}"))
        accelerator.save_state(checkpoint_dir)  # Overwrite latest
        if accelerator.is_main_process:
            torch.save(
                {"epoch": epoch}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
            )

    # scheduler.step()

# --- Final Save and Cleanup ---
accelerator.wait_for_everyone()
accelerator.print("Saving final model state.")
accelerator.save_state(checkpoint_dir)
if accelerator.is_main_process:
    torch.save(
        {"epoch": total_epoch - 1}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
    )

accelerator.end_training()
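# Launch sketch (the script filename is hypothetical; adjust to this repo's layout):
#   accelerate config          # one-time machine/GPU setup
#   accelerate launch train.py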