things
train_accelerator_hdit.py (new file, 265 lines added)
@@ -0,0 +1,265 @@
import math
import os
from datetime import timedelta

import lpips
import torch
import torch.optim as optim
from accelerate import Accelerator, InitProcessGroupKwargs
from hdit import HDiT
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm import tqdm

import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.rf import RF

# --- Configuration ---
batch_size = 4
accumulation_steps = 2
total_epoch = 500
grad_norm = 1.0
learning_rate = 3e-4

# --- Accelerator Setup ---
# Set a longer timeout for initialization, which can be useful when downloading
# large models or datasets on multiple nodes.
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
accelerator = Accelerator(
    gradient_accumulation_steps=accumulation_steps,
    mixed_precision="bf16",  # Use "bf16" on modern GPUs (Ampere or newer), or "fp16"
    log_with="wandb",
    kwargs_handlers=[kwargs],
)

# --- Dataset Loading ---
# Load datasets on the main process. They will be accessible by all processes.
train_dataset, test_dataset = get_dataset()

train_dataset, test_dataset = (
    DataLoader(train_dataset, batch_size=batch_size),  # type: ignore
    DataLoader(test_dataset, batch_size=batch_size),  # type: ignore
)

# len() of a DataLoader is already the number of batches, so do not divide by
# batch_size again.
num_batches_per_epoch = len(train_dataset)
optimizer_steps_per_epoch = math.ceil(num_batches_per_epoch / accumulation_steps)
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)  # currently unused; see the sketch below the scheduler

# --- Model, Optimizer, and RF Helper Initialization ---
# Initialize on CPU. Accelerator will move them to the correct device.
model = HDiT(
    in_channels=4,
    out_channels=4,
    patch_size=[1, 1],  # type: ignore
    widths=[256, 512],
    middle_width=1024,
    depths=[4, 4],
    middle_depth=8,
    mapping_width=512,
    mapping_depth=2,
).bfloat16()

accelerator.print(sum(p.numel() for p in model.parameters() if p.requires_grad), "params")
lp = lpips.LPIPS(net="vgg")
rf = RF(model, "icfm", "lpips_mse", lp)  # RF holds a reference to the model
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# T_max=total_epoch assumes the scheduler is stepped once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epoch)

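# A minimal sketch of how `warmup_steps` could be used, assuming a linear warmup
# into the cosine decay via torch's built-in SequentialLR (kept commented out so
# the script's behaviour is unchanged):
#
#     warmup = optim.lr_scheduler.LinearLR(
#         optimizer, start_factor=1e-3, total_iters=warmup_steps
#     )
#     cosine = optim.lr_scheduler.CosineAnnealingLR(
#         optimizer, T_max=total_steps - warmup_steps
#     )
#     scheduler = optim.lr_scheduler.SequentialLR(
#         optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps]
#     )
#     # ...and step the scheduler once per optimizer step rather than once per epoch.
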
# --- Prepare objects with Accelerator ---
# Prepare the model, LPIPS loss, optimizer, scheduler, and both dataloaders so
# Accelerate handles device placement and per-process sharding.
model, lp, optimizer, scheduler, train_dataset, test_dataset = accelerator.prepare(
    model, lp, optimizer, scheduler, train_dataset, test_dataset
)
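# With Accelerate's default split_batches=False, each process draws its own
# batches of `batch_size`, so the effective global batch size is
# batch_size * accelerator.num_processes * accumulation_steps (here 4 * N * 2).
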
# --- W&B and Checkpoint Setup ---
# Initialize tracker (wandb) on the main process
accelerator.init_trackers(
    project_name="cloud-removal-kmu",
    config={
        "batch_size": batch_size,
        "accumulation_steps": accumulation_steps,
        "total_epoch": total_epoch,
        "learning_rate": learning_rate,
        "grad_norm": grad_norm,
        "total_steps": total_steps,
    },
)

# Use the run name from the tracker for a consistent artifact path.
# The fallback is only used when no trackers are configured.
run_name = "default-run"
if accelerator.trackers:
    run_name = accelerator.trackers[0].run.name

artifact_dir = f"artifact/{run_name}"
checkpoint_dir = os.path.join(artifact_dir, "checkpoints")

if accelerator.is_main_process:
    os.makedirs(checkpoint_dir, exist_ok=True)
# Barrier so the directory exists before any other process tries to use it.
accelerator.wait_for_everyone()

# Register scheduler for checkpointing
accelerator.register_for_checkpointing(scheduler)

start_epoch = 0
# Check if a checkpoint exists to resume training
if os.path.exists(checkpoint_dir):
    try:
        accelerator.print(f"Resuming from checkpoint: {checkpoint_dir}")
        accelerator.load_state(checkpoint_dir)
        # Manually load the epoch from a tracker file
        if os.path.exists(os.path.join(checkpoint_dir, "epoch_tracker.pt")):
            start_epoch = (
                torch.load(os.path.join(checkpoint_dir, "epoch_tracker.pt"))["epoch"]
                + 1
            )
    except Exception as e:
        accelerator.print(
            f"Could not load checkpoint. Starting from scratch. Error: {e}"
        )


# --- Training Loop ---
for epoch in range(start_epoch, total_epoch):
    model.train()
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}

    progress_bar = tqdm(
        train_dataset,
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{total_epoch}",
    )

    for step, batch in enumerate(progress_bar):
        cloud, gt = batch["cloud"], batch["gt"]

        with accelerator.accumulate(model):
            # Forward pass is automatically handled with mixed precision
            loss, blsct, loss_list = rf.forward(
                torch.cat((gt, batch["gt_nir"]), dim=1),
                torch.cat((cloud, batch["cloud_nir"]), dim=1),
            )

            accelerator.backward(loss)

            # Gradients only sync on the last micro-batch of an accumulation window.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), grad_norm)

            optimizer.step()
            optimizer.zero_grad()

        # Log metrics once per optimizer step (i.e. when gradients were synced)
        if accelerator.sync_gradients:
            avg_loss = accelerator.gather(loss).mean().item()  # type: ignore
            current_step = epoch * optimizer_steps_per_epoch + (
                step // accumulation_steps
            )
            accelerator.log(
                {
                    "train/loss": avg_loss,
                    "train/lr": scheduler.get_last_lr()[0],
                },
                # step=current_step,
            )
            accelerator.log(loss_list)

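            # If the wandb x-axis should track optimizer steps exactly, the two calls
            # above could be merged and given an explicit step, e.g. (sketch, assuming
            # `loss_list` is a flat dict of scalars):
            #
            #     accelerator.log({"train/loss": avg_loss, **loss_list}, step=current_step)
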
        # This per-process accumulation is an approximation. For exact statistics,
        # `blsct` would need to be gathered from all processes (see sketch below).
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1

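        # A minimal sketch of exact aggregation, assuming `blsct` is an iterable of
        # per-sample (t, loss) scalar pairs (an assumption about RF's return value);
        # kept commented out so behaviour is unchanged:
        #
        #     t_vals = torch.tensor([t for t, _ in blsct], device=accelerator.device)
        #     l_vals = torch.tensor([l for _, l in blsct], device=accelerator.device)
        #     t_all = accelerator.gather_for_metrics(t_vals)
        #     l_all = accelerator.gather_for_metrics(l_vals)
        #     for t, lss in zip(t_all.tolist(), l_all.tolist()):
        #         bin_idx = min(int(t * 10), 9)
        #         lossbin[bin_idx] += lss
        #         losscnt[bin_idx] += 1
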
    # Log epoch-level metrics from the main process
    if accelerator.is_main_process:
        epoch_metrics = {
            f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)
        }
        epoch_metrics["epoch"] = epoch
        accelerator.log(epoch_metrics)

    # --- Evaluation and Checkpointing ---
    if (epoch + 1) % 50 == 0:
        model.eval()
        psnr_sum, ssim_sum, lpips_sum, flawed_lpips_sum, count = 0.0, 0.0, 0.0, 0.0, 0

        with torch.no_grad():
            for i, batch in tqdm(
                enumerate(test_dataset),
                disable=not accelerator.is_local_main_process,
                desc=f"Benchmark {epoch + 1}/{total_epoch}",
            ):
                images = rf.sample(
                    torch.cat((batch["cloud"], batch["cloud_nir"]), dim=1)
                )
                image = denormalize(images[-1]).clamp(0, 1)
                original = denormalize(batch["gt"]).clamp(0, 1)

                # Gather results from all processes for accurate metrics
                image_gathered = accelerator.gather_for_metrics(image)
                original_gathered = accelerator.gather_for_metrics(original)

                if accelerator.is_main_process:
                    # Log visualization images from the first batch on the main process
                    if i == 0:
                        demo_images = [images[0][:4], images[-1][:4]]
                        for step_idx, demo in enumerate(demo_images):
                            grid = make_grid(
                                denormalize(demo).clamp(0, 1).float().cpu(), nrow=2
                            )
                            wandb_image = wandb.Image(grid, caption=f"step {step_idx}")
                            accelerator.log({"viz/decoded": wandb_image})

                    # `lpips_val` avoids shadowing the imported `lpips` module.
                    psnr, ssim, lpips_val, flawed_lpips = benchmark(
                        image_gathered.cpu(),  # type: ignore
                        original_gathered.cpu(),  # type: ignore
                    )
                    psnr_sum += psnr.sum().item()
                    ssim_sum += ssim.sum().item()
                    lpips_sum += lpips_val.sum().item()
                    flawed_lpips_sum += flawed_lpips.sum().item()
                    count += image_gathered.shape[0]  # type: ignore

        accelerator.wait_for_everyone()

        if accelerator.is_main_process:
            avg_psnr = psnr_sum / count if count > 0 else 0
            avg_ssim = ssim_sum / count if count > 0 else 0
            avg_lpips = lpips_sum / count if count > 0 else 0
            avg_flawed_lpips = flawed_lpips_sum / count if count > 0 else 0
            accelerator.log(
                {
                    "eval/psnr": avg_psnr,
                    "eval/ssim": avg_ssim,
                    "eval/lpips": avg_lpips,
                    "eval/flawed_lpips": avg_flawed_lpips,
                    "epoch": epoch + 1,
                }
            )

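            # Note: gathering full images and computing metrics only on the main
            # process is simple but memory-heavy. If `benchmark` can run per process
            # (an assumption), scalar sums could be reduced instead, e.g. (sketch):
            #
            #     global_psnr_sum = accelerator.reduce(psnr.sum(), reduction="sum")
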
            # Save checkpoint on the main process
            accelerator.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch + 1}"))
            accelerator.save_state(checkpoint_dir)  # Overwrite latest
            torch.save(
                {"epoch": epoch}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
            )

    # scheduler.step()  # NOTE: currently disabled, so the learning rate stays constant.


# --- Final Save and Cleanup ---
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.print("Saving final model state.")
    accelerator.save_state(checkpoint_dir)
    torch.save(
        {"epoch": total_epoch - 1}, os.path.join(checkpoint_dir, "epoch_tracker.pt")
    )

accelerator.end_training()
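
# Typical usage (assuming multi-GPU has been configured once via `accelerate config`):
#
#     accelerate launch train_accelerator_hdit.py
#
# A plain `python train_accelerator_hdit.py` also works for single-device debugging.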