flash attn

neulus
2025-10-15 00:18:34 +09:00
parent 3b03453e5d
commit e51017897d
5 changed files with 69 additions and 52 deletions

main.py (20 changes)

@@ -3,6 +3,7 @@ import os
 import torch
 import torch.optim as optim
+from torch.cuda.amp import autocast
 from torchvision.utils import make_grid
 from tqdm import tqdm
@@ -17,8 +18,8 @@ train_dataset, test_dataset = get_dataset()
 device = "cuda:1"
-batch_size = 8 * 4
-accumulation_steps = 4
+batch_size = 8 * 4 * 2
+accumulation_steps = 2
 total_epoch = 500
 steps_per_epoch = len(train_dataset) // batch_size
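Note on the batch-size change: the per-step micro-batch doubles (8 * 4 = 32 → 8 * 4 * 2 = 64) while accumulation_steps is halved (4 → 2), so the effective batch per optimizer update is unchanged. A quick arithmetic check (plain Python, variable names are illustrative only):

# effective batch per optimizer update = micro-batch size * accumulation steps
old_effective = (8 * 4) * 4      # 32 * 4 = 128
new_effective = (8 * 4 * 2) * 2  # 64 * 2 = 128
assert old_effective == new_effective == 128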
@@ -28,9 +29,11 @@ warmup_steps = int(0.05 * total_steps)
 grad_norm = 1.0
-model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vitl16-pretrain-sat493m"
-).to(device)
+model = (
+    UTransformer.from_pretrained_backbone("facebook/dinov3-vitl16-pretrain-sat493m")
+    .to(device)
+    .bfloat16()
+)
 rf = RF(model, "icfm", "lpips_mse")
 optimizer = optim.AdamW(model.parameters(), lr=3e-4)
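Why the extra .bfloat16() cast: consistent with the commit title, the CUDA FlashAttention path behind torch.nn.functional.scaled_dot_product_attention only dispatches for fp16/bf16 inputs, so keeping the backbone's weights and activations in bfloat16 is what makes the fast kernel eligible. A minimal standalone check that the flash backend is actually picked (assumes a recent PyTorch that exposes torch.nn.attention.sdpa_kernel; the tensors here are toy data, not the repository's model):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 16, 1024, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the FlashAttention backend; this raises at call time
# if the inputs (dtype, head dim, device) are not flash-compatible.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)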
@@ -81,9 +84,10 @@ for epoch in range(start_epoch, total_epoch):
         cloud = batch["cloud"].to(device)
         gt = batch["gt"].to(device)
-        loss, blsct, loss_list = rf.forward(gt, cloud)
-        loss = loss / accumulation_steps
-        loss.backward()
+        with autocast(dtype=torch.bfloat16):
+            loss, blsct, loss_list = rf.forward(gt, cloud)
+            loss = loss / accumulation_steps
+            loss.backward()
         if (i // batch_size + 1) % accumulation_steps == 0:
             # total_norm = torch.nn.utils.clip_grad_norm_(
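Sketch of the pattern introduced here, for reference: the weights are already bfloat16, autocast keeps the forward pass and loss computation in bfloat16, and the loss is divided by accumulation_steps so the accumulated gradients average rather than sum across micro-batches. A self-contained toy version under those assumptions (the model, data, and names are placeholders, not the repository's code; torch.autocast("cuda", ...) is the current spelling of the torch.cuda.amp.autocast import used in the diff):

import torch
from torch import nn, optim

model = nn.Linear(512, 512).cuda().bfloat16()   # stand-in for the UTransformer
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
accumulation_steps = 2

for step, x in enumerate(torch.randn(8, 4, 512)):
    x = x.cuda()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        # forward pass and scaled loss inside autocast, mirroring the diff
        loss = model(x).float().pow(2).mean() / accumulation_steps
        loss.backward()
    # step only every accumulation_steps micro-batches; gradients from the
    # intermediate backward() calls accumulate (and, because of the division
    # above, average) before each optimizer update
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # grad_norm; clipping is commented out upstream
        optimizer.step()
        optimizer.zero_grad()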