things
This commit is contained in:
93
main.py
93
main.py
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
@@ -12,17 +13,43 @@ from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
train_dataset, test_dataset = get_dataset()
|
||||
|
||||
# Training hyper-parameters.
device = "cuda:1"

batch_size = 8 * 4
accumulation_steps = 4
total_epoch = 500

# Number of full mini-batches per epoch (floor division drops a trailing
# partial batch). Kept under its original name for any code that reads it.
steps_per_epoch = len(train_dataset) // batch_size

# The LR scheduler is stepped once per *optimizer* step, not once per
# mini-batch: the training loop steps every `accumulation_steps` mini-batches
# and flushes once more at the end of an epoch when a partial accumulation
# window is left over. The schedule length must therefore be counted in
# optimizer steps; counting mini-batches would make warmup and cosine decay
# roughly `accumulation_steps` times longer than the run actually is.
batches_per_epoch = len(range(0, len(train_dataset), batch_size))
optimizer_steps_per_epoch = (
    batches_per_epoch + accumulation_steps - 1
) // accumulation_steps  # ceiling division
total_steps = optimizer_steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)  # first 5% of the schedule is warmup

# Max-norm intended for gradient clipping (the clipping calls in the visible
# training loop are currently commented out).
grad_norm = 1.0
|
||||
|
||||
|
||||
model = UTransformer.from_pretrained_backbone(
|
||||
"facebook/dinov3-vitl16-pretrain-sat493m"
|
||||
).to(device)
|
||||
rf = RF(model)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
||||
rf = RF(model, "icfm", "lpips_mse")
|
||||
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
||||
|
||||
train_dataset, test_dataset = get_dataset()
|
||||
|
||||
wandb.init(project="cloud-removal-kmu", id="dashing-moon-31", resume="allow")
|
||||
# scheduler
|
||||
def get_lr(step: int) -> float:
    """Return the learning-rate multiplier for scheduler step ``step``.

    The multiplier ramps linearly from 0 to 1 over the first
    ``warmup_steps`` steps, then follows a half-cosine from 1 down to 0,
    reaching 0 at ``total_steps``. Both limits are read from module scope.

    Args:
        step: Zero-based scheduler step count.

    Returns:
        A factor in ``[0, 1]`` to be multiplied with the optimizer's base
        learning rate.
    """
    if step < warmup_steps:
        # warmup_steps > step >= 0 here, so the divisor is non-zero.
        return step / warmup_steps

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # Degenerate schedule (e.g. total_steps == 0): there is no decay
        # span to interpolate over, so keep the base learning rate instead
        # of dividing by zero.
        return 1.0

    # Clamp so that steps beyond total_steps stay at the end of the decay
    # instead of riding the cosine back up.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)
|
||||
|
||||
wandb.init(project="cloud-removal-kmu", resume="allow")
|
||||
|
||||
# phase 2
|
||||
# model.requires_grad_(True)
|
||||
|
||||
if not (wandb.run and wandb.run.name):
|
||||
raise Exception("nope")
|
||||
@@ -35,11 +62,11 @@ if os.path.exists(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
if "scheduler_state_dict" in checkpoint:
|
||||
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
start_epoch = checkpoint["epoch"] + 1
|
||||
|
||||
batch_size = 8
|
||||
accumulation_steps = 8
|
||||
total_epoch = 1000
|
||||
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
lossbin = {i: 0 for i in range(10)}
|
||||
losscnt = {i: 1e-6 for i in range(10)}
|
||||
@@ -54,15 +81,31 @@ for epoch in range(start_epoch, total_epoch):
|
||||
cloud = batch["cloud"].to(device)
|
||||
gt = batch["gt"].to(device)
|
||||
|
||||
loss, blsct = rf.forward(gt, cloud)
|
||||
loss, blsct, loss_list = rf.forward(gt, cloud)
|
||||
loss = loss / accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
if (i // batch_size + 1) % accumulation_steps == 0:
|
||||
# total_norm = torch.nn.utils.clip_grad_norm_(
|
||||
# model.parameters(), max_norm=grad_norm
|
||||
# )
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
wandb.log({"train/loss": loss.item() * accumulation_steps})
|
||||
# wandb.log(
|
||||
# {
|
||||
# "train/grad_norm": total_norm.item(),
|
||||
# }
|
||||
# )
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": loss.item() * accumulation_steps,
|
||||
"train/lr": scheduler.get_last_lr()[0],
|
||||
}
|
||||
)
|
||||
wandb.log(loss_list)
|
||||
|
||||
for t, lss in blsct:
|
||||
bin_idx = min(int(t * 10), 9)
|
||||
@@ -70,14 +113,23 @@ for epoch in range(start_epoch, total_epoch):
|
||||
losscnt[bin_idx] += 1
|
||||
|
||||
if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
|
||||
# total_norm = torch.nn.utils.clip_grad_norm_(
|
||||
# model.parameters(), max_norm=grad_norm
|
||||
# )
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
# wandb.log(
|
||||
# {
|
||||
# "train/grad_norm": total_norm.item(),
|
||||
# }
|
||||
# )
|
||||
|
||||
epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
|
||||
epoch_metrics["epoch"] = epoch
|
||||
wandb.log(epoch_metrics)
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
if (epoch + 1) % 50 == 0:
|
||||
rf.model.eval()
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
@@ -90,7 +142,7 @@ for epoch in range(start_epoch, total_epoch):
|
||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample_heun(batch["cloud"].to(device))
|
||||
images = rf.sample(batch["cloud"].to(device))
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
|
||||
@@ -128,24 +180,27 @@ for epoch in range(start_epoch, total_epoch):
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
},
|
||||
f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
|
||||
)
|
||||
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch, # type: ignore
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
},
|
||||
f"artifact/{wandb.run.name}/checkpoint_final.pt",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user