# cloud-removal/train_accelerator.py
import math
import os
from datetime import timedelta
import lpips
import torch
import torch.optim as optim
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import broadcast_object_list
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm
import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
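# Usage sketch (assumes an Accelerate environment configured via `accelerate config`):
#   accelerate launch cloud-removal/train_accelerator.py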
# --- Configuration ---
batch_size = 16
accumulation_steps = 2
total_epoch = 500
grad_norm = 1.0
learning_rate = 3e-4
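# Effective batch size per optimizer update is
# batch_size * accumulation_steps * num_processes.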
# --- Accelerator Setup ---
# Set a longer timeout for initialization, which can be useful when downloading
# large models or datasets on multiple nodes.
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerator = Accelerator(
    gradient_accumulation_steps=accumulation_steps,
    mixed_precision="bf16",  # Use "bf16" for modern GPUs, or "fp16"
    log_with="wandb",
    kwargs_handlers=[kwargs],
)
# --- Dataset Loading ---
# Every process builds its own DataLoader here; accelerator.prepare() below
# shards the batches across processes.
train_dataset, test_dataset = get_dataset()
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size),
    DataLoader(test_dataset, batch_size=batch_size),
)
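# Each batch is assumed (from its use below) to be a dict with "cloud" (cloudy
# input) and "gt" (cloud-free target) image tensors.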
# --- LR Scheduler Logic ---
# Count optimizer updates, not micro-batches: len(train_loader) is already the
# number of batches per epoch, and an update happens every `accumulation_steps` batches.
num_batches_per_epoch = len(train_loader)
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / accumulation_steps)
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
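# Note: warmup_steps is computed above but no warmup schedule is attached; the
# CosineAnnealingLR below is configured per epoch (T_max=total_epoch). A
# per-step warmup+cosine variant could look like this sketch (not active;
# the LambdaLR factor multiplies the base learning rate):
#   scheduler = optim.lr_scheduler.LambdaLR(
#       optimizer,
#       lambda s: (s + 1) / max(1, warmup_steps)
#       if s < warmup_steps
#       else 0.5 * (1 + math.cos(math.pi * (s - warmup_steps) / max(1, total_steps - warmup_steps))),
#   )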
# --- Model, Optimizer, and RF Helper Initialization ---
# Initialize on CPU. Accelerator will move them to the correct device.
model = UTransformer.from_pretrained_backbone(
"facebook/dinov3-vitl16-pretrain-sat493m"
).bfloat16()
lp = lpips.LPIPS(net="vgg")
rf = RF(model, "icfm", "lpips_mse", lp) # RF holds a reference to the model
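# From its use below, rf.forward(gt, cloud) is assumed to return the scalar
# training loss, per-sample (t, loss) pairs, and a dict of named loss terms;
# rf.sample(cloud) is assumed to return intermediate images along the sampling
# trajectory, with the final prediction last.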
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epoch)
# --- Prepare objects with Accelerator ---
# prepare() moves the model and LPIPS criterion to the right device, wraps the
# optimizer and scheduler, and shards both dataloaders across processes.
model, lp, optimizer, scheduler, train_loader, test_loader = accelerator.prepare(
    model, lp, optimizer, scheduler, train_loader, test_loader
)
# --- W&B and Checkpoint Setup ---
# Initialize tracker (wandb) on the main process
accelerator.init_trackers(
    project_name="cloud-removal-kmu",
    config={
        "batch_size": batch_size,
        "accumulation_steps": accumulation_steps,
        "total_epoch": total_epoch,
        "learning_rate": learning_rate,
        "grad_norm": grad_norm,
        "total_steps": total_steps,
    },
)
# Use the run name from the tracker for a consistent artifact path. Trackers
# only exist on the main process, so broadcast the name to guarantee every
# process builds the same checkpoint path.
run_name = "nerf-3"  # fallback in case no trackers are configured
if accelerator.is_main_process and accelerator.trackers:
    run_name = accelerator.trackers[0].run.name
run_name = broadcast_object_list([run_name])[0]
artifact_dir = f"artifact/{run_name}"
checkpoint_dir = os.path.join(artifact_dir, "checkpoints")
if accelerator.is_main_process:
    os.makedirs(checkpoint_dir, exist_ok=True)
accelerator.wait_for_everyone()  # Ensure the directory exists before any process accesses it
# Register scheduler for checkpointing
accelerator.register_for_checkpointing(scheduler)
start_epoch = 0
# Resume only if the checkpoint directory already contains a saved state
# (it was just created above, so a bare existence check is not enough).
if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
    try:
        accelerator.print(f"Resuming from checkpoint: {checkpoint_dir}")
        accelerator.load_state(checkpoint_dir)
        # Manually load the epoch from a tracker file
        epoch_tracker_path = os.path.join(checkpoint_dir, "epoch_tracker.pt")
        if os.path.exists(epoch_tracker_path):
            start_epoch = torch.load(epoch_tracker_path)["epoch"] + 1
    except Exception as e:
        accelerator.print(
            f"Could not load checkpoint. Starting from scratch. Error: {e}"
        )
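# save_state()/load_state() cover the prepared model, optimizer, scheduler and
# RNG states plus registered objects; the epoch counter is tracked separately
# in epoch_tracker.pt.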
# --- Training Loop ---
for epoch in range(start_epoch, total_epoch):
    model.train()
    # Per-epoch loss totals binned by sampled timestep t in [0, 1) (10 bins)
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}  # epsilon avoids division by zero for empty bins
    progress_bar = tqdm(
        train_loader,
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    )
    for step, batch in enumerate(progress_bar):
        cloud, gt = batch["cloud"], batch["gt"]
        with accelerator.accumulate(model):
            # Forward pass runs under Accelerate's mixed-precision autocast;
            # optimizer.step() only takes effect on update (sync) micro-batches.
            loss, blsct, loss_list = rf.forward(gt, cloud)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()
            optimizer.zero_grad()
        # Log metrics only on optimizer-update steps
        if accelerator.sync_gradients:
            avg_loss = accelerator.gather(loss).mean().item()
            current_step = epoch * optimizer_steps_per_epoch + (
                step // accumulation_steps
            )
            accelerator.log(
                {
                    "train/loss": avg_loss,
                    "train/lr": scheduler.get_last_lr()[0],
                    "train/step": current_step,
                },
            )
            accelerator.log(loss_list)
        # This per-process binning is an approximation. For perfect accuracy,
        # `blsct` would need to be gathered from all processes.
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1
    # Log epoch-level metrics from the main process
    if accelerator.is_main_process:
        epoch_metrics = {
            f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)
        }
        epoch_metrics["epoch"] = epoch
        accelerator.log(epoch_metrics)
    # --- Evaluation and Checkpointing ---
    if (epoch + 1) % 50 == 0:
        model.eval()
        psnr_sum, ssim_sum, lpips_sum, count = 0.0, 0.0, 0.0, 0
        with torch.no_grad():
            for i, batch in tqdm(
                enumerate(test_loader),
                disable=not accelerator.is_local_main_process,
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                images = rf.sample(batch["cloud"])
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["gt"]).clamp(0, 1)
                # Gather results from all processes for accurate metrics
                image_gathered = accelerator.gather_for_metrics(image)
                original_gathered = accelerator.gather_for_metrics(original)
                if accelerator.is_main_process:
                    # Log visualization images from the first batch
                    if i == 0:
                        demo_images = [images[0][:4], images[-1][:4]]
                        for step_idx, demo in enumerate(demo_images):
                            grid = make_grid(
                                denormalize(demo).clamp(0, 1).float().cpu(), nrow=2
                            )
                            wandb_image = wandb.Image(grid, caption=f"step {step_idx}")
                            accelerator.log({"viz/decoded": wandb_image})
                    # `lpips_score` avoids shadowing the imported lpips module
                    psnr, ssim, lpips_score, flawed_lpips = benchmark(
                        image_gathered.cpu(), original_gathered.cpu()
                    )
                    psnr_sum += psnr.sum().item()
                    ssim_sum += ssim.sum().item()
                    lpips_sum += lpips_score.sum().item()
                    count += image_gathered.shape[0]
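        # Note: gather_for_metrics() trims samples that were duplicated to pad
        # the last uneven batch, so `count` should match the true test-set size.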
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            avg_psnr = psnr_sum / count if count > 0 else 0
            avg_ssim = ssim_sum / count if count > 0 else 0
            avg_lpips = lpips_sum / count if count > 0 else 0
            accelerator.log(
                {
                    "eval/psnr": avg_psnr,
                    "eval/ssim": avg_ssim,
                    "eval/lpips": avg_lpips,
                    "epoch": epoch + 1,
                }
            )
        # Save checkpoints: save_state() is called from every process, and
        # Accelerate writes the files from the main process internally.
        accelerator.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch + 1}"))
        accelerator.save_state(checkpoint_dir)  # Overwrite latest
        if accelerator.is_main_process:
            torch.save(
                {"epoch": epoch}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
            )
    # scheduler.step()  # NOTE: disabled, so the cosine LR schedule never advances
# --- Final Save and Cleanup ---
accelerator.wait_for_everyone()
accelerator.print("Saving final model state.")
accelerator.save_state(checkpoint_dir)
if accelerator.is_main_process:
    torch.save(
        {"epoch": total_epoch - 1}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
    )
accelerator.end_training()