import math
import os
from datetime import timedelta

import lpips
import torch
import torch.optim as optim
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm

import wandb

from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF

# --- Configuration ---
batch_size = 16
accumulation_steps = 2
total_epoch = 500
grad_norm = 1.0
learning_rate = 3e-4
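# With gradient accumulation, each optimizer update sees batch_size * accumulation_steps
# samples per process (16 * 2 = 32 here), times the number of processes launched.
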
# --- Accelerator Setup ---
# Set a longer timeout for initialization, which can be useful when downloading
# large models or datasets on multiple nodes.
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerator = Accelerator(
    gradient_accumulation_steps=accumulation_steps,
    mixed_precision="bf16",  # Use "bf16" for modern GPUs, or "fp16"
    log_with="wandb",
    kwargs_handlers=[kwargs],
)

# --- Dataset Loading ---
# Every process loads the datasets; Accelerator shards the prepared dataloaders across processes.
train_dataset, test_dataset = get_dataset()

# Keep the loaders under distinct names so the raw datasets stay available for step counting.
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size),
    DataLoader(test_dataset, batch_size=batch_size),
)

# --- LR Scheduler Logic ---
# Correctly calculate total steps based on optimizer updates, not micro-batches.
# Use math.ceil to account for the last partial batch.
num_batches_per_epoch = math.ceil(len(train_dataset) / batch_size)
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / accumulation_steps)
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
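# NOTE: warmup_steps is computed for reference only; it is not wired into the
# CosineAnnealingLR schedule created below, which starts without a warmup phase.
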
# --- Model, Optimizer, and RF Helper Initialization ---
# Initialize on CPU. Accelerator will move them to the correct device.
model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vitl16-pretrain-sat493m"
).bfloat16()
lp = lpips.LPIPS(net="vgg")
rf = RF(model, "icfm", "lpips_mse", lp)  # RF holds a reference to the model
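# Judging from its use below, RF pairs the model with an "icfm" coupling and an
# LPIPS+MSE loss: rf.forward(gt, cloud) returns the training loss plus per-timestep
# statistics, and rf.sample(cloud) returns a list of images ending with the final prediction.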

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epoch)

# --- Prepare objects with Accelerator ---
# Prepare the model, loss module, optimizer, scheduler, and both dataloaders so that
# Accelerator handles device placement and shards the dataloaders across processes.
model, lp, optimizer, scheduler, train_loader, test_loader = accelerator.prepare(
    model, lp, optimizer, scheduler, train_loader, test_loader
)

# --- W&B and Checkpoint Setup ---
# Initialize the tracker (wandb) on the main process
accelerator.init_trackers(
    project_name="cloud-removal-kmu",
    config={
        "batch_size": batch_size,
        "accumulation_steps": accumulation_steps,
        "total_epoch": total_epoch,
        "learning_rate": learning_rate,
        "grad_norm": grad_norm,
        "total_steps": total_steps,
    },
)

# Use the run name from the tracker for a consistent artifact path,
# falling back to a fixed name in case no trackers are configured.
run_name = "nerf-3"
if accelerator.trackers:
    run_name = accelerator.trackers[0].run.name

artifact_dir = f"artifact/{run_name}"
checkpoint_dir = os.path.join(artifact_dir, "checkpoints")

if accelerator.is_main_process:
    os.makedirs(checkpoint_dir, exist_ok=True)
accelerator.wait_for_everyone()  # Ensure the directory exists before any process tries to access it

# Register the scheduler so its state is included in checkpoints
accelerator.register_for_checkpointing(scheduler)

start_epoch = 0
# Check if a checkpoint exists to resume training
if os.path.exists(checkpoint_dir):
    try:
        accelerator.print(f"Resuming from checkpoint: {checkpoint_dir}")
        accelerator.load_state(checkpoint_dir)
        # Manually load the epoch from a tracker file
        epoch_tracker_path = os.path.join(checkpoint_dir, "epoch_tracker.pt")
        if os.path.exists(epoch_tracker_path):
            start_epoch = torch.load(epoch_tracker_path)["epoch"] + 1
    except Exception as e:
        accelerator.print(
            f"Could not load checkpoint. Starting from scratch. Error: {e}"
        )
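
# Note: accelerator.load_state restores the model, optimizer, scheduler, and RNG states;
# epoch_tracker.pt only carries the epoch counter, which Accelerate does not track itself.
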
# --- Training Loop ---
for epoch in range(start_epoch, total_epoch):
    model.train()
    # Per-timestep loss statistics: t in [0, 1] is split into ten bins of width 0.1
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}  # epsilon avoids division by zero for empty bins

    progress_bar = tqdm(
        train_loader,
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    )

    for step, batch in enumerate(progress_bar):
        cloud, gt = batch["cloud"], batch["gt"]

        with accelerator.accumulate(model):
            # Forward pass is automatically handled with mixed precision
            loss, blsct, loss_list = rf.forward(gt, cloud)

            accelerator.backward(loss)

            # Clip only when gradients are actually synchronized, i.e. on a real optimizer step
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), grad_norm)

            optimizer.step()
            optimizer.zero_grad()

        # Log metrics once per optimizer update
        if accelerator.sync_gradients:
            avg_loss = accelerator.gather(loss).mean().item()
            current_step = epoch * optimizer_steps_per_epoch + (
                step // accumulation_steps
            )
            accelerator.log(
                {
                    "train/loss": avg_loss,
                    "train/lr": scheduler.get_last_lr()[0],
                },
                step=current_step,
            )
            accelerator.log(loss_list, step=current_step)

        # This per-process accumulation is an approximation. For perfect accuracy,
        # `blsct` would need to be gathered from all processes.
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

    # Log epoch-level metrics from the main process
    if accelerator.is_main_process:
        epoch_metrics = {
            f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)
        }
        epoch_metrics["epoch"] = epoch
        accelerator.log(epoch_metrics)

    # --- Evaluation and Checkpointing ---
    if (epoch + 1) % 50 == 0:
        model.eval()
        psnr_sum, ssim_sum, lpips_sum, count = 0.0, 0.0, 0.0, 0

        with torch.no_grad():
            for i, batch in tqdm(
                enumerate(test_loader),
                disable=not accelerator.is_local_main_process,
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                images = rf.sample(batch["cloud"])
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["gt"]).clamp(0, 1)

                # Gather results from all processes for accurate metrics;
                # gather_for_metrics also drops samples duplicated for even sharding.
                image_gathered = accelerator.gather_for_metrics(image)
                original_gathered = accelerator.gather_for_metrics(original)

                if accelerator.is_main_process:
                    # Log visualization images from the first batch on the main process
                    if i == 0:
                        demo_images = [images[0][:4], images[-1][:4]]
                        for step_idx, demo in enumerate(demo_images):
                            grid = make_grid(
                                denormalize(demo).clamp(0, 1).float().cpu(), nrow=2
                            )
                            wandb_image = wandb.Image(grid, caption=f"step {step_idx}")
                            accelerator.log({"viz/decoded": wandb_image})

                    # `lpips_val` avoids shadowing the imported `lpips` module
                    psnr, ssim, lpips_val, flawed_lpips = benchmark(
                        image_gathered.cpu(), original_gathered.cpu()
                    )
                    psnr_sum += psnr.sum().item()
                    ssim_sum += ssim.sum().item()
                    lpips_sum += lpips_val.sum().item()
                    count += image_gathered.shape[0]

        accelerator.wait_for_everyone()

        if accelerator.is_main_process:
            avg_psnr = psnr_sum / count if count > 0 else 0
            avg_ssim = ssim_sum / count if count > 0 else 0
            avg_lpips = lpips_sum / count if count > 0 else 0
            accelerator.log(
                {
                    "eval/psnr": avg_psnr,
                    "eval/ssim": avg_ssim,
                    "eval/lpips": avg_lpips,
                    "epoch": epoch + 1,
                }
            )

            # Save a per-epoch snapshot plus the "latest" state on the main process
            accelerator.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch + 1}"))
            accelerator.save_state(checkpoint_dir)  # Overwrite latest
            torch.save(
                {"epoch": epoch}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
            )

    scheduler.step()  # Step the cosine schedule once per epoch (matches T_max=total_epoch)

# --- Final Save and Cleanup ---
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.print("Saving final model state.")
    accelerator.save_state(checkpoint_dir)
    torch.save(
        {"epoch": total_epoch - 1}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
    )

accelerator.end_training()