"""Train a UTransformer rectified-flow model for cloud removal.

Loads paired (cloudy, clear) images from a Hugging Face dataset, trains
with AdamW, logs per-step loss and per-timestep-bin loss averages to
Weights & Biases, and checkpoints every CHECKPOINT_EVERY epochs.
"""

import torch
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset.preprocess import make_transform
from src.model.utransformer import UTransformer
from src.rf import RF

# Single source of truth for the training device (was hard-coded in 3 places).
DEVICE = "cuda:3"
EPOCHS = 100
NUM_LOSS_BINS = 10          # loss histogram over t in [0, 1)
CHECKPOINT_EVERY = 10       # epochs between checkpoint saves
LEARNING_RATE = 5e-4
BATCH_SIZE = 16
MAP_BATCH_SIZE = 32         # batch size for datasets.map preprocessing


def _build_dataloader(transform):
    """Load the train split, preprocess image pairs, and return a DataLoader.

    Each example is mapped to tensors ``x0`` (cloudy) and ``x1`` (clear)
    via *transform*; original columns are dropped.
    """
    # TODO(review): "your-dataset-name" is a placeholder — replace with the
    # real dataset identifier before running.
    dataset = load_dataset("your-dataset-name")
    train_dataset = dataset["train"]

    def preprocess_function(examples):
        # Transform each (cloudy, clear) image pair into model-ready tensors.
        x0_list = [transform(img) for img in examples["cloudy_image"]]
        x1_list = [transform(img) for img in examples["clear_image"]]
        return {"x0": x0_list, "x1": x1_list}

    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        batch_size=MAP_BATCH_SIZE,
        remove_columns=train_dataset.column_names,
    )
    train_dataset.set_format(type="torch", columns=["x0", "x1"])

    return DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )


def _train_one_epoch(epoch, rf, optimizer, dataloader):
    """Run one training epoch and log per-step and per-bin losses to wandb."""
    lossbin = {i: 0 for i in range(NUM_LOSS_BINS)}
    # Tiny initial count avoids division by zero for empty bins.
    losscnt = {i: 1e-6 for i in range(NUM_LOSS_BINS)}

    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        x0 = batch["x0"].to(DEVICE)
        x1 = batch["x1"].to(DEVICE)

        optimizer.zero_grad()
        loss, blsct = rf.forward(x0, x1)
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss.item()})

        # NOTE(review): blsct is assumed to be an iterable of (t, loss)
        # pairs with t in [0, 1] — confirm against RF.forward.
        for t, l in blsct:
            bin_idx = min(int(t * NUM_LOSS_BINS), NUM_LOSS_BINS - 1)
            lossbin[bin_idx] += l
            losscnt[bin_idx] += 1

    epoch_metrics = {
        f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(NUM_LOSS_BINS)
    }
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)


def main():
    """Set up model, data, and logging, then run the training loop."""
    transform = make_transform()
    model = UTransformer.from_pretrained_backbone(
        "facebook/dinov3-vits16-pretrain-lvd1689m"
    ).to(DEVICE)
    rf = RF(model)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    dataloader = _build_dataloader(transform)

    wandb.init(project="cloud-removal-kmu")

    for epoch in range(EPOCHS):
        _train_one_epoch(epoch, rf, optimizer, dataloader)

        if (epoch + 1) % CHECKPOINT_EVERY == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f"checkpoint_epoch_{epoch + 1}.pt",
            )

    wandb.finish()


# Guard is required: with num_workers=4, DataLoader worker processes
# re-import this module under the spawn start method; without the guard
# each worker would re-run the whole training script.
if __name__ == "__main__":
    main()