v2; a lot of hacks

This commit is contained in:
neulus
2025-10-15 18:02:10 +09:00
parent e51017897d
commit 6ab33ceb83
7 changed files with 586 additions and 199 deletions

View File

@@ -1,6 +1,7 @@
import math
import os
import lovely_tensors as lt
import torch
import torch.optim as optim
from torch.cuda.amp import autocast
@@ -14,11 +15,13 @@ from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
lt.monkey_patch()
train_dataset, test_dataset = get_dataset()
device = "cuda:1"
batch_size = 8 * 4 * 2
batch_size = 16
accumulation_steps = 2
total_epoch = 500