import math
import os
from datetime import timedelta

import lpips
import torch
import torch.optim as optim
from accelerate import Accelerator, InitProcessGroupKwargs
from hdit import HDiT
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm

import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.rf import RF

# --- Configuration ---
batch_size = 4
accumulation_steps = 2
total_epoch = 500
grad_norm = 1.0
learning_rate = 3e-4

# --- Accelerator Setup ---
# Set a longer timeout for initialization, which can be useful when downloading
# large models or datasets on multiple nodes.
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerator = Accelerator(
    gradient_accumulation_steps=accumulation_steps,
    mixed_precision="bf16",  # Use "bf16" on modern GPUs, or "fp16" otherwise
    log_with="wandb",
    kwargs_handlers=[kwargs],
)

# --- Dataset Loading ---
# Load datasets on the main process. They will be accessible by all processes.
train_dataset, test_dataset = get_dataset()
train_dataset, test_dataset = (
    DataLoader(train_dataset, batch_size=batch_size),  # type: ignore
    DataLoader(test_dataset, batch_size=batch_size),  # type: ignore
)

# len(DataLoader) already gives the number of batches per epoch.
num_batches_per_epoch = len(train_dataset)
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / accumulation_steps)
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)  # currently unused

# --- Model, Optimizer, and RF Helper Initialization ---
# Initialize on CPU. Accelerator will move them to the correct device.
model = HDiT(
    in_channels=4,
    out_channels=4,
    patch_size=[1, 1],  # type: ignore
    widths=[256, 512],
    middle_width=1024,
    depths=[4, 4],
    middle_depth=8,
    mapping_width=512,
    mapping_depth=2,
).bfloat16()
accelerator.print(
    sum(p.numel() for p in model.parameters() if p.requires_grad), "params"
)

lp = lpips.LPIPS(net="vgg")
rf = RF(model, "icfm", "lpips_mse", lp)  # RF holds a reference to the model
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epoch)

# --- Prepare objects with Accelerator ---
# Wrap the model, LPIPS network, optimizer, scheduler, and both dataloaders so
# they are moved to the correct devices and sharded across processes.
model, lp, optimizer, scheduler, train_dataset, test_dataset = accelerator.prepare(
    model, lp, optimizer, scheduler, train_dataset, test_dataset
)

# --- W&B and Checkpoint Setup ---
# Initialize the tracker (wandb) on the main process.
accelerator.init_trackers(
    project_name="cloud-removal-kmu",
    config={
        "batch_size": batch_size,
        "accumulation_steps": accumulation_steps,
        "total_epoch": total_epoch,
        "learning_rate": learning_rate,
        "grad_norm": grad_norm,
        "total_steps": total_steps,
    },
)

# Use the run name from the tracker for a consistent artifact path.
# The fallback is needed in case no trackers are configured.
run_name = "testest"
if accelerator.trackers:
    run_name = accelerator.trackers[0].run.name
artifact_dir = f"artifact/{run_name}"
checkpoint_dir = os.path.join(artifact_dir, "checkpoints")

if accelerator.is_main_process:
    os.makedirs(checkpoint_dir, exist_ok=True)
# Ensure the directory exists before any process tries to access it.
accelerator.wait_for_everyone()

# Register the scheduler for checkpointing.
accelerator.register_for_checkpointing(scheduler)

start_epoch = 0
# Resume only if the checkpoint directory already contains a saved state
# (the directory itself is created above, so its existence alone is not enough).
if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
    try:
        accelerator.print(f"Resuming from checkpoint: {checkpoint_dir}")
        accelerator.load_state(checkpoint_dir)
        # Manually load the epoch from a tracker file.
        if os.path.exists(os.path.join(checkpoint_dir, "epoch_tracker.pt")):
            start_epoch = (
                torch.load(os.path.join(checkpoint_dir, "epoch_tracker.pt"))["epoch"]
                + 1
            )
    except Exception as e:
        accelerator.print(
            f"Could not load checkpoint. Starting from scratch. Error: {e}"
        )

# --- Training Loop ---
for epoch in range(start_epoch, total_epoch):
    model.train()
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}

    progress_bar = tqdm(
        train_dataset,
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    )

    for step, batch in enumerate(progress_bar):
        cloud, gt = batch["cloud"], batch["gt"]

        with accelerator.accumulate(model):
            # Forward pass runs under Accelerate's mixed-precision handling.
            loss, blsct, loss_list = rf.forward(
                torch.cat((gt, batch["gt_nir"]), dim=1),
                torch.cat((cloud, batch["cloud_nir"]), dim=1),
            )
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), grad_norm)

            optimizer.step()
            optimizer.zero_grad()

        # Log metrics only on optimizer steps (when gradients are synced).
        if accelerator.sync_gradients:
            avg_loss = accelerator.gather(loss).mean().item()  # type: ignore
            current_step = epoch * optimizer_steps_per_epoch + (
                step // accumulation_steps
            )
            accelerator.log(
                {
                    "train/loss": avg_loss,
                    "train/lr": scheduler.get_last_lr()[0],
                },
                # step=current_step,
            )
            accelerator.log(loss_list)

        # This per-process binning is an approximation; for exact statistics,
        # `blsct` would need to be gathered from all processes.
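        # A sketch of what an exact, cross-process version could look like
        # (it assumes `blsct` is a list of (timestep, loss) scalar-tensor pairs
        # with equal per-process counts, which this script does not guarantee,
        # so it is left as a commented illustration):
        #
        #     ts = accelerator.gather(torch.stack([t.detach() for t, _ in blsct]))
        #     ls = accelerator.gather(torch.stack([l.detach() for _, l in blsct]))
        #     blsct = list(zip(ts.tolist(), ls.tolist()))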
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

    # Log epoch-level loss-bin metrics from the main process.
    if accelerator.is_main_process:
        epoch_metrics = {
            f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)
        }
        epoch_metrics["epoch"] = epoch
        accelerator.log(epoch_metrics)

    # --- Evaluation and Checkpointing ---
    if (epoch + 1) % 50 == 0:
        model.eval()
        psnr_sum, ssim_sum, lpips_sum, flawed_lpips_sum, count = 0.0, 0.0, 0.0, 0.0, 0

        with torch.no_grad():
            for i, batch in tqdm(
                enumerate(test_dataset),
                disable=not accelerator.is_local_main_process,
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                images = rf.sample(
                    torch.cat((batch["cloud"], batch["cloud_nir"]), dim=1)
                )
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["gt"]).clamp(0, 1)

                # Gather results from all processes for accurate metrics.
                image_gathered = accelerator.gather_for_metrics(image)
                original_gathered = accelerator.gather_for_metrics(original)

                if accelerator.is_main_process:
                    # Log visualization images from the first batch.
                    if i == 0:
                        demo_images = [images[0][:4], images[-1][:4]]
                        for step_idx, demo in enumerate(demo_images):
                            grid = make_grid(
                                denormalize(demo).clamp(0, 1).float().cpu(), nrow=2
                            )
                            wandb_image = wandb.Image(grid, caption=f"step {step_idx}")
                            accelerator.log({"viz/decoded": wandb_image})

                    # Avoid shadowing the imported `lpips` module.
                    psnr, ssim, lpips_val, flawed_lpips = benchmark(
                        image_gathered.cpu(),  # type: ignore
                        original_gathered.cpu(),  # type: ignore
                    )
                    psnr_sum += psnr.sum().item()
                    ssim_sum += ssim.sum().item()
                    lpips_sum += lpips_val.sum().item()
                    flawed_lpips_sum += flawed_lpips.sum().item()
                    count += image_gathered.shape[0]  # type: ignore

        accelerator.wait_for_everyone()

        if accelerator.is_main_process:
            avg_psnr = psnr_sum / count if count > 0 else 0
            avg_ssim = ssim_sum / count if count > 0 else 0
            avg_lpips = lpips_sum / count if count > 0 else 0
            avg_flawed_lpips = flawed_lpips_sum / count if count > 0 else 0
            accelerator.log(
                {
                    "eval/psnr": avg_psnr,
                    "eval/ssim": avg_ssim,
                    "eval/lpips": avg_lpips,
                    "eval/flawed_lpips": avg_flawed_lpips,
                    "epoch": epoch + 1,
                }
            )

        # Save checkpoints. `save_state` is collective and must run on every
        # process; only the epoch tracker file is written by the main process.
        accelerator.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch + 1}"))
        accelerator.save_state(checkpoint_dir)  # Overwrite latest
        if accelerator.is_main_process:
            torch.save(
                {"epoch": epoch}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
            )

    # scheduler.step()

# --- Final Save and Cleanup ---
accelerator.wait_for_everyone()
accelerator.print("Saving final model state.")
accelerator.save_state(checkpoint_dir)  # collective: call on every process
if accelerator.is_main_process:
    torch.save(
        {"epoch": total_epoch - 1}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
    )

accelerator.end_training()
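# Usage note (a sketch; the filename train.py is an assumption, and it presumes
# `accelerate config` has already been run on this machine):
#
#     accelerate launch train.py
#
# For single-process debugging, plain `python train.py` also works.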