better arti
This commit is contained in:
8
main.py
8
main.py
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
@@ -41,6 +42,11 @@ train_dataset.set_format(type="torch", columns=["x0", "x1"])
|
||||
|
||||
wandb.init(project="cloud-removal-kmu")
|
||||
|
||||
if not (wandb.run and wandb.run.name):
|
||||
raise Exception("nope")
|
||||
|
||||
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
|
||||
|
||||
batch_size = 16
|
||||
for epoch in range(100):
|
||||
lossbin = {i: 0 for i in range(10)}
|
||||
@@ -77,6 +83,6 @@ for epoch in range(100):
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
},
|
||||
f"checkpoint_epoch_{epoch + 1}.pt",
|
||||
f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
|
||||
)
|
||||
wandb.finish()
|
||||
|
||||
Reference in New Issue
Block a user