diff --git a/main.py b/main.py index 62ba1cf..8c88887 100644 --- a/main.py +++ b/main.py @@ -85,4 +85,13 @@ for epoch in range(100): }, f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt", ) + +torch.save( + { + "epoch": 100, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + }, + f"artifact/{wandb.run.name}/checkpoint_final.pt", +) wandb.finish()