cloud-removal/main.py
import os
from typing import cast
import torch
import torch.optim as optim
import wandb
from datasets import DatasetDict, load_dataset
from tqdm import tqdm
from src.dataset.preprocess import make_transform
from src.model.utransformer import UTransformer
from src.rf import RF
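
# Build the preprocessing transform, the DINOv3-backed UTransformer, the RF
# training wrapper, and the optimizer; the device is hardcoded to cuda:3.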
transform = make_transform()
model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vits16-pretrain-lvd1689m"
).to("cuda:3")
rf = RF(model)
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
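
# Load the paired cloudy/clear dataset ("your-dataset-name" is a placeholder).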
dataset = cast(DatasetDict, load_dataset("your-dataset-name"))
train_dataset = dataset["train"]
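
# Turn each cloudy/clear pair into transformed tensors: x0 (cloudy input), x1 (clear target).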
def preprocess_function(examples):
    x0_list = []
    x1_list = []
    for x0_img, x1_img in zip(examples["cloudy"], examples["clear"]):
        x0_transformed = transform(x0_img)
        x1_transformed = transform(x1_img)
        x0_list.append(x0_transformed)
        x1_list.append(x1_transformed)
    return {"x0": x0_list, "x1": x1_list}
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    batch_size=32,
    remove_columns=train_dataset.column_names,
)
train_dataset.set_format(type="torch", columns=["x0", "x1"])
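
# Initialize experiment tracking; the wandb run name keys the checkpoint directory.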
wandb.init(project="cloud-removal-kmu")
if not (wandb.run and wandb.run.name):
    raise RuntimeError("wandb did not assign a run name; it is required for the artifact directory")
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
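
# Training loop: per-step loss goes to wandb, and per-epoch losses are
# binned by timestep t for diagnostics.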
batch_size = 16
num_epochs = 100

for epoch in range(num_epochs):
    # Track mean loss per timestep bucket (10 bins over t in [0, 1)).
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}  # epsilon avoids division by zero for empty bins
    train_dataset = train_dataset.shuffle(seed=epoch)
    for i in tqdm(
        range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/{num_epochs}"
    ):
        batch = train_dataset[i : i + batch_size]
        x0 = torch.stack(batch["x0"]).to("cuda:3")
        x1 = torch.stack(batch["x1"]).to("cuda:3")
        optimizer.zero_grad()
        loss, blsct = rf.forward(x0, x1)
        loss.backward()
        optimizer.step()
        wandb.log({"loss": loss.item()})
        # blsct carries (timestep, loss) pairs used for the per-bin statistics.
        for t, lss in blsct:
            bin_idx = min(int(t * 10), 9)
            lossbin[bin_idx] += lss
            losscnt[bin_idx] += 1
    epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)
    # Periodic checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
        )
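
# Save a final checkpoint after the last epoch.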
torch.save(
    {
        "epoch": num_epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    f"artifact/{wandb.run.name}/checkpoint_final.pt",
)
wandb.finish()