flash attn
main.py
@@ -3,6 +3,7 @@ import os
 
 import torch
 import torch.optim as optim
+from torch.cuda.amp import autocast
 from torchvision.utils import make_grid
 from tqdm import tqdm
 
@@ -17,8 +18,8 @@ train_dataset, test_dataset = get_dataset()
 
 device = "cuda:1"
 
-batch_size = 8 * 4
-accumulation_steps = 4
+batch_size = 8 * 4 * 2
+accumulation_steps = 2
 total_epoch = 500
 
 steps_per_epoch = len(train_dataset) // batch_size
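Note: the effective batch size per optimizer step is unchanged by this hunk: 8 * 4 * 4 = 128 before, 8 * 4 * 2 * 2 = 128 after. The bf16 changes below presumably free enough memory to double the micro-batch, so half as many accumulation steps are needed.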
@@ -28,9 +29,11 @@ warmup_steps = int(0.05 * total_steps)
 grad_norm = 1.0
 
 
-model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vitl16-pretrain-sat493m"
-).to(device)
+model = (
+    UTransformer.from_pretrained_backbone("facebook/dinov3-vitl16-pretrain-sat493m")
+    .to(device)
+    .bfloat16()
+)
 rf = RF(model, "icfm", "lpips_mse")
 optimizer = optim.AdamW(model.parameters(), lr=3e-4)
 
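Note: .bfloat16() casts every parameter and buffer to bfloat16, so this is full-bf16 training (weights, activations, and optimizer state all in half precision) rather than the usual mixed-precision recipe that keeps fp32 master weights. Half-precision tensors are also what allow scaled-dot-product attention to dispatch to the flash-attention kernel, which is presumably the point of the commit title.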
@@ -81,9 +84,10 @@ for epoch in range(start_epoch, total_epoch):
         cloud = batch["cloud"].to(device)
         gt = batch["gt"].to(device)
 
-        loss, blsct, loss_list = rf.forward(gt, cloud)
-        loss = loss / accumulation_steps
-        loss.backward()
+        with autocast(dtype=torch.bfloat16):
+            loss, blsct, loss_list = rf.forward(gt, cloud)
+            loss = loss / accumulation_steps
+            loss.backward()
 
         if (i // batch_size + 1) % accumulation_steps == 0:
             # total_norm = torch.nn.utils.clip_grad_norm_(
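Taken together, the diff converges on a bf16 training step with gradient accumulation. Below is a minimal sketch of that loop, not the file itself: train_one_epoch and train_loader are hypothetical names, rf.forward's extra return values are dropped, and the gradient clipping that main.py leaves commented out is shown enabled.

# Sketch only: model, rf, optimizer, accumulation_steps, grad_norm, and
# device stand in for the objects defined in main.py.
import torch
from torch.cuda.amp import autocast  # older spelling of torch.amp.autocast("cuda")

def train_one_epoch(model, rf, optimizer, train_loader,
                    accumulation_steps, grad_norm, device):
    for step, batch in enumerate(train_loader):
        cloud = batch["cloud"].to(device)
        gt = batch["gt"].to(device)

        # Run the forward pass (and, as in this commit, the backward pass)
        # under bf16 autocast. No GradScaler is needed: bfloat16 keeps the
        # fp32 exponent range, so gradients do not underflow the way fp16 can.
        with autocast(dtype=torch.bfloat16):
            loss, _, _ = rf.forward(gt, cloud)
            loss = loss / accumulation_steps  # keep gradient scale independent of step count
            loss.backward()

        # One optimizer step per accumulation_steps micro-batches.
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

Keeping loss.backward() inside the autocast block mirrors the diff; moving it one level out would be the more conventional placement and should not change the computed gradients.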