better arti
This commit is contained in:
8
main.py
8
main.py
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -41,6 +42,11 @@ train_dataset.set_format(type="torch", columns=["x0", "x1"])
|
|||||||
|
|
||||||
wandb.init(project="cloud-removal-kmu")
|
wandb.init(project="cloud-removal-kmu")
|
||||||
|
|
||||||
|
if not (wandb.run and wandb.run.name):
|
||||||
|
raise Exception("nope")
|
||||||
|
|
||||||
|
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
|
||||||
|
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
lossbin = {i: 0 for i in range(10)}
|
lossbin = {i: 0 for i in range(10)}
|
||||||
@@ -77,6 +83,6 @@ for epoch in range(100):
|
|||||||
"model_state_dict": model.state_dict(),
|
"model_state_dict": model.state_dict(),
|
||||||
"optimizer_state_dict": optimizer.state_dict(),
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
},
|
},
|
||||||
f"checkpoint_epoch_{epoch + 1}.pt",
|
f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
|
||||||
)
|
)
|
||||||
wandb.finish()
|
wandb.finish()
|
||||||
|
|||||||
Reference in New Issue
Block a user