better arti

This commit is contained in:
neulus
2025-09-29 16:18:28 +09:00
parent 95cc43d69e
commit f86076a314

View File

@@ -1,3 +1,4 @@
import os
from typing import cast from typing import cast
import torch import torch
@@ -41,6 +42,11 @@ train_dataset.set_format(type="torch", columns=["x0", "x1"])
wandb.init(project="cloud-removal-kmu") wandb.init(project="cloud-removal-kmu")
if not (wandb.run and wandb.run.name):
raise Exception("nope")
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
batch_size = 16 batch_size = 16
for epoch in range(100): for epoch in range(100):
lossbin = {i: 0 for i in range(10)} lossbin = {i: 0 for i in range(10)}
@@ -77,6 +83,6 @@ for epoch in range(100):
"model_state_dict": model.state_dict(), "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
}, },
f"checkpoint_epoch_{epoch + 1}.pt", f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
) )
wandb.finish() wandb.finish()