This commit is contained in:
neulus
2025-09-30 10:27:41 +09:00
parent 0ccf1ff42d
commit 8966cafb8f
2 changed files with 46 additions and 19 deletions

53
main.py
View File

@@ -11,54 +11,72 @@ from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer from src.model.utransformer import UTransformer
from src.rf import RF from src.rf import RF
device = "cuda:3" device = "cuda:0"
model = UTransformer.from_pretrained_backbone( model = UTransformer.from_pretrained_backbone(
"facebook/dinov3-vits16-pretrain-lvd1689m" "facebook/dinov3-vitl16-pretrain-sat493m"
).to(device) ).to(device)
rf = RF(model) rf = RF(model)
optimizer = optim.AdamW(model.parameters(), lr=5e-4) optimizer = optim.AdamW(model.parameters(), lr=1e-4)
train_dataset, test_dataset = get_dataset() train_dataset, test_dataset = get_dataset()
wandb.init(project="cloud-removal-kmu") wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
if not (wandb.run and wandb.run.name): if not (wandb.run and wandb.run.name):
raise Exception("nope") raise Exception("nope")
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True) os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
batch_size = 32 start_epoch = 0
for epoch in range(100): checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
batch_size = 4
accumulation_steps = 4
total_epoch = 1000
for epoch in range(start_epoch, total_epoch):
lossbin = {i: 0 for i in range(10)} lossbin = {i: 0 for i in range(10)}
losscnt = {i: 1e-6 for i in range(10)} losscnt = {i: 1e-6 for i in range(10)}
train_dataset = train_dataset.shuffle(seed=epoch) train_dataset = train_dataset.shuffle(seed=epoch)
for i in tqdm( for i in tqdm(
range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100" range(0, len(train_dataset), batch_size),
desc=f"Epoch {epoch + 1}/{total_epoch}",
): ):
batch = train_dataset[i : i + batch_size] batch = train_dataset[i : i + batch_size]
x0 = batch["x0"].to(device) x0 = batch["x0"].to(device)
x1 = batch["x1"].to(device) x1 = batch["x1"].to(device)
optimizer.zero_grad()
loss, blsct = rf.forward(x0, x1) loss, blsct = rf.forward(x0, x1)
loss = loss / accumulation_steps
loss.backward() loss.backward()
if (i // batch_size + 1) % accumulation_steps == 0:
optimizer.step() optimizer.step()
wandb.log({"loss": loss.item()}) optimizer.zero_grad()
wandb.log({"train/loss": loss.item() * accumulation_steps})
for t, lss in blsct: for t, lss in blsct:
bin_idx = min(int(t * 10), 9) bin_idx = min(int(t * 10), 9)
lossbin[bin_idx] += lss lossbin[bin_idx] += lss
losscnt[bin_idx] += 1 losscnt[bin_idx] += 1
epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)} if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
optimizer.step()
optimizer.zero_grad()
epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
epoch_metrics["epoch"] = epoch epoch_metrics["epoch"] = epoch
wandb.log(epoch_metrics) wandb.log(epoch_metrics)
if (epoch + 1) % 10 == 0: if (epoch + 1) % 10 == 0:
# bench
rf.model.eval() rf.model.eval()
psnr_sum = 0 psnr_sum = 0
ssim_sum = 0 ssim_sum = 0
@@ -68,7 +86,7 @@ for epoch in range(100):
with torch.no_grad(): with torch.no_grad():
for i in tqdm( for i in tqdm(
range(0, len(test_dataset), batch_size), range(0, len(test_dataset), batch_size),
desc=f"Benchmark {epoch + 1}/100", desc=f"Benchmark {epoch + 1}/{total_epoch}",
): ):
batch = test_dataset[i : i + batch_size] batch = test_dataset[i : i + batch_size]
images = rf.sample(batch["x0"].to(device)) images = rf.sample(batch["x0"].to(device))
@@ -105,7 +123,16 @@ for epoch in range(100):
torch.save( torch.save(
{ {
"epoch": 100, "epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
checkpoint_path,
)
torch.save(
{
"epoch": epoch, # type: ignore
"model_state_dict": model.state_dict(), "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
}, },

View File

@@ -2,19 +2,19 @@ import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
# note that its LVD-1689M (not SAT) # note that its SAT
def make_transform(resize_size: int = 256): def make_transform(resize_size: int = 256):
to_tensor = v2.ToImage() to_tensor = v2.ToImage()
resize = v2.Resize((resize_size, resize_size), antialias=True) resize = v2.Resize((resize_size, resize_size), antialias=True)
to_float = v2.ToDtype(torch.float32, scale=True) to_float = v2.ToDtype(torch.float32, scale=True)
normalize = v2.Normalize( normalize = v2.Normalize(
mean=(0.485, 0.456, 0.406), mean=(0.430, 0.411, 0.296),
std=(0.229, 0.224, 0.225), std=(0.213, 0.156, 0.143),
) )
return v2.Compose([to_tensor, resize, to_float, normalize]) return v2.Compose([to_tensor, resize, to_float, normalize])
def denormalize(tensor): def denormalize(tensor):
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device) mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device) std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
return tensor * std + mean return tensor * std + mean