This commit is contained in:
neulus
2025-11-02 22:53:04 +09:00
parent e51017897d
commit ca589fab7c
13 changed files with 888 additions and 96 deletions

train_accelerator.py (new file, 247 lines)

@@ -0,0 +1,247 @@
import math
import os
from datetime import timedelta
import lpips
import torch
import torch.optim as optim
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm
import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
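# The src.* modules are project-local. As used below, get_dataset yields batches with
# "cloud"/"gt" image pairs, UTransformer is the DINOv3-backed restoration model,
# RF wraps the flow-matching style training/sampling loop, and benchmark returns
# per-sample PSNR/SSIM/LPIPS tensors; torch, accelerate, lpips, and wandb are the
# only external dependencies.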
# --- Configuration ---
batch_size = 16
accumulation_steps = 2
total_epoch = 500
grad_norm = 1.0
learning_rate = 3e-4
# --- Accelerator Setup ---
# Give the process group a generous timeout so long-running setup work (e.g.
# downloading pretrained weights or datasets on multiple nodes) does not trip
# the default NCCL timeout.
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerator = Accelerator(
gradient_accumulation_steps=accumulation_steps,
mixed_precision="bf16",  # bf16 requires Ampere-or-newer GPUs; use "fp16" on older hardware
log_with="wandb",
kwargs_handlers=[kwargs],
)
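# The Accelerator handles device placement, DDP wrapping, bf16 autocast,
# gradient-accumulation bookkeeping, and the wandb tracker selected via log_with.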
# --- Dataset Loading ---
# Every process builds its own dataset and DataLoader; accelerator.prepare() below
# shards the resulting batches across processes.
train_dataset, test_dataset = get_dataset()
train_dataset, test_dataset = (
DataLoader(train_dataset, batch_size=batch_size),
DataLoader(test_dataset, batch_size=batch_size),
)
# --- LR Scheduler Logic ---
# Count optimizer updates, not micro-batches. train_dataset is already a DataLoader
# here, so its length is the number of micro-batches per epoch and must not be
# divided by batch_size again; math.ceil covers a trailing partial accumulation window.
num_batches_per_epoch = len(train_dataset)
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / accumulation_steps)
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
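# Example of the arithmetic: with accumulation_steps = 2, an epoch of N micro-batches
# produces ceil(N / 2) optimizer updates, so total_steps = ceil(N / 2) * 500 and
# warmup_steps covers the first 5% of those updates (warmup_steps is computed for
# reference but not consumed by the scheduler below).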
# --- Model, Optimizer, and RF Helper Initialization ---
# Initialize on CPU. Accelerator will move them to the correct device.
model = UTransformer.from_pretrained_backbone(
"facebook/dinov3-vitl16-pretrain-sat493m"
).bfloat16()
lp = lpips.LPIPS(net="vgg")
rf = RF(model, "icfm", "lpips_mse", lp) # RF holds a reference to the model
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epoch)
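# CosineAnnealingLR with T_max=total_epoch is sized for one scheduler step per epoch,
# not per optimizer update; the matching epoch-end scheduler.step() call near the end
# of the training loop is currently commented out.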
# --- Prepare objects with Accelerator ---
# prepare() moves the model and the LPIPS net to the right device, wraps the model
# for DDP, wraps the optimizer/scheduler, and shards both DataLoaders across processes.
model, lp, optimizer, scheduler, train_dataset, test_dataset = accelerator.prepare(
model, lp, optimizer, scheduler, train_dataset, test_dataset
)
# --- W&B and Checkpoint Setup ---
# Initialize tracker (wandb) on the main process
accelerator.init_trackers(
project_name="cloud-removal-kmu",
config={
"batch_size": batch_size,
"accumulation_steps": accumulation_steps,
"total_epoch": total_epoch,
"learning_rate": learning_rate,
"grad_norm": grad_norm,
"total_steps": total_steps,
},
)
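# init_trackers starts the wandb run on the main process only and records the
# config dict above as the run's hyperparameters.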
# Use the run name from the tracker for a consistent artifact path
# This check is needed in case there are no trackers configured.
run_name = "nerf-3"
if accelerator.trackers:
run_name = accelerator.trackers[0].run.name
artifact_dir = f"artifact/{run_name}"
checkpoint_dir = os.path.join(artifact_dir, "checkpoints")
if accelerator.is_main_process:
os.makedirs(checkpoint_dir, exist_ok=True)
accelerator.wait_for_everyone() # Ensure directory is created before any process tries to access it
# Register scheduler for checkpointing
accelerator.register_for_checkpointing(scheduler)
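# The scheduler was also passed through prepare(), so it is normally saved by
# save_state already; explicit registration simply makes sure it is part of the
# checkpoint bundle either way.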
start_epoch = 0
# Resume only if the checkpoint directory already holds a saved state; it is
# created above even on a fresh run, so a bare existence check is not enough.
if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
try:
accelerator.print(f"Resuming from checkpoint: {checkpoint_dir}")
accelerator.load_state(checkpoint_dir)
# Manually load the epoch from a tracker file
if os.path.exists(os.path.join(checkpoint_dir, "epoch_tracker.pt")):
start_epoch = (
torch.load(os.path.join(checkpoint_dir, "epoch_tracker.pt"))["epoch"]
+ 1
)
except Exception as e:
accelerator.print(
f"Could not load checkpoint. Starting from scratch. Error: {e}"
)
# --- Training Loop ---
for epoch in range(start_epoch, total_epoch):
model.train()
lossbin = {i: 0 for i in range(10)}
losscnt = {i: 1e-6 for i in range(10)}
progress_bar = tqdm(
train_dataset,
disable=not accelerator.is_local_main_process,
desc=f"Epoch {epoch + 1}/{total_epoch}",
)
for step, batch in enumerate(progress_bar):
cloud, gt = batch["cloud"], batch["gt"]
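# accumulate() suppresses gradient synchronization until accumulation_steps
# micro-batches have been processed; sync_gradients turns True on the last
# micro-batch of each window and gates the clipping/step/zero_grad below.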
with accelerator.accumulate(model):
# Forward pass is automatically handled with mixed precision
loss, blsct, loss_list = rf.forward(gt, cloud)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), grad_norm)
optimizer.step()
optimizer.zero_grad()
# Log metrics once per optimizer update
if accelerator.sync_gradients:
avg_loss = accelerator.gather(loss).mean().item()
current_step = epoch * optimizer_steps_per_epoch + (
step // accumulation_steps
)
accelerator.log(
{
"train/loss": avg_loss,
"train/lr": scheduler.get_last_lr()[0],
"train/step": current_step,
},
)
accelerator.log(loss_list)
# This per-process logging is an approximation. For perfect accuracy,
# `blsct` would need to be gathered from all processes.
for t, lss in blsct:
bin_idx = min(int(t * 10), 9)
lossbin[bin_idx] += lss
losscnt[bin_idx] += 1
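# blsct is assumed to hold (timestep, loss) pairs from rf.forward; timesteps in
# [0, 1) are bucketed into ten equal-width bins to show where along the flow the
# loss concentrates.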
# Log epoch-level metrics from the main process
if accelerator.is_main_process:
epoch_metrics = {
f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)
}
epoch_metrics["epoch"] = epoch
accelerator.log(epoch_metrics)
# --- Evaluation and Checkpointing ---
if (epoch + 1) % 50 == 0:
model.eval()
psnr_sum, ssim_sum, lpips_sum, count = 0.0, 0.0, 0.0, 0
with torch.no_grad():
for i, batch in tqdm(
enumerate(test_dataset),
disable=not accelerator.is_local_main_process,
desc=f"Benchmark {epoch + 1}/{total_epoch}",
):
images = rf.sample(batch["cloud"])
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
# Gather results from all processes for accurate metrics
image_gathered = accelerator.gather_for_metrics(image)
original_gathered = accelerator.gather_for_metrics(original)
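# gather_for_metrics concatenates results from all processes and drops the duplicate
# samples that padding adds to the final uneven batch, keeping the metric sums exact.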
if accelerator.is_main_process:
# Log visualization images from the first batch on the main process
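# images is assumed to be the list of intermediate states returned by rf.sample;
# the grids show the first and last sampling step for four samples.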
if i == 0:
demo_images = [images[0][:4], images[-1][:4]]
for step_idx, demo in enumerate(demo_images):
grid = make_grid(
denormalize(demo).clamp(0, 1).float().cpu(), nrow=2
)
wandb_image = wandb.Image(grid, caption=f"step {step_idx}")
accelerator.log({"viz/decoded": wandb_image})
psnr, ssim, lpips_val, flawed_lpips = benchmark(
image_gathered.cpu(), original_gathered.cpu()
)
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips_val.sum().item()
count += image_gathered.shape[0]
accelerator.wait_for_everyone()
if accelerator.is_main_process:
avg_psnr = psnr_sum / count if count > 0 else 0
avg_ssim = ssim_sum / count if count > 0 else 0
avg_lpips = lpips_sum / count if count > 0 else 0
accelerator.log(
{
"eval/psnr": avg_psnr,
"eval/ssim": avg_ssim,
"eval/lpips": avg_lpips,
"epoch": epoch + 1,
}
)
# Save checkpoint on the main process
accelerator.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch + 1}"))
accelerator.save_state(checkpoint_dir) # Overwrite latest
torch.save(
{"epoch": epoch}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
)
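# Checkpoint layout: a per-epoch snapshot under epoch_{N}/, the latest state written
# directly into checkpoint_dir (what load_state picks up on resume), and
# epoch_tracker.pt recording the last completed epoch. Calling save_state only from
# the main process is fine for plain DDP, but sharded setups (FSDP/DeepSpeed) would
# need it on every rank.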
# scheduler.step()  # cosine LR decay is currently disabled, so train/lr stays at the initial learning_rate
# --- Final Save and Cleanup ---
accelerator.wait_for_everyone()
if accelerator.is_main_process:
accelerator.print("Saving final model state.")
accelerator.save_state(checkpoint_dir)
torch.save(
{"epoch": total_epoch - 1}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
)
accelerator.end_training()