v2; a lot of hacks
This commit is contained in:
5
main.py
5
main.py
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import lovely_tensors as lt
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.cuda.amp import autocast
|
||||
@@ -14,11 +15,13 @@ from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
lt.monkey_patch()
|
||||
|
||||
train_dataset, test_dataset = get_dataset()
|
||||
|
||||
device = "cuda:1"
|
||||
|
||||
batch_size = 8 * 4 * 2
|
||||
batch_size = 16
|
||||
accumulation_steps = 2
|
||||
total_epoch = 500
|
||||
|
||||
|
||||
Reference in New Issue
Block a user