import torch
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset.preprocess import make_transform
from src.model.utransformer import UTransformer
from src.rf import RF
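
# Shared image transform, applied to both the cloudy and clear frames.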
transform = make_transform()

# UTransformer image model initialized from a pretrained DINOv3 ViT-S/16 backbone.
device = "cuda:3"  # single training GPU used throughout
model = UTransformer.from_pretrained_backbone(
    "facebook/dinov3-vits16-pretrain-lvd1689m"
).to(device)
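
# RF training wrapper around the model, plus its optimizer.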
rf = RF(model)
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
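
# Load the paired cloudy/clear dataset ("your-dataset-name" is a placeholder).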
dataset = load_dataset("your-dataset-name")
train_dataset = dataset["train"]
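

# Build (x0, x1) tensor pairs: x0 = cloudy input, x1 = clear target.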
def preprocess_function(examples):
    x0_list = []
    x1_list = []

    for x0_img, x1_img in zip(examples["cloudy_image"], examples["clear_image"]):
        x0_transformed = transform(x0_img)
        x1_transformed = transform(x1_img)

        x0_list.append(x0_transformed)
        x1_list.append(x1_transformed)

    return {"x0": x0_list, "x1": x1_list}
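

# Preprocess once up front, dropping the raw image columns.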
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    batch_size=32,
    remove_columns=train_dataset.column_names,
)

train_dataset.set_format(type="torch", columns=["x0", "x1"])

dataloader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True
)
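
# Experiment tracking.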
wandb.init(project="cloud-removal-kmu")
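
# Main training loop.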
for epoch in range(100):
    # Per-decile loss tracking over t in [0, 1]; the tiny initial count
    # avoids division by zero for bins that receive no samples.
    lossbin = {i: 0 for i in range(10)}
    losscnt = {i: 1e-6 for i in range(10)}

    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/100"):
        x0 = batch["x0"].to(device)  # cloudy input
        x1 = batch["x1"].to(device)  # clear target

        optimizer.zero_grad()
        loss, blsct = rf.forward(x0, x1)
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss.item()})

        # blsct holds per-sample (t, loss) pairs; bucket them by timestep decile.
        for t, l in blsct:
            bin_idx = min(int(t * 10), 9)  # clamp t == 1.0 into the last bin
            lossbin[bin_idx] += l
            losscnt[bin_idx] += 1

    # Log the mean loss per timestep bin once per epoch.
    epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
    epoch_metrics["epoch"] = epoch
    wandb.log(epoch_metrics)

    # Save a checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"checkpoint_epoch_{epoch + 1}.pt",
        )

wandb.finish()