flash attn
main.py (20)
@@ -3,6 +3,7 @@ import os

 import torch
 import torch.optim as optim
+from torch.cuda.amp import autocast
 from torchvision.utils import make_grid
 from tqdm import tqdm

@@ -17,8 +18,8 @@ train_dataset, test_dataset = get_dataset()

 device = "cuda:1"

-batch_size = 8 * 4
-accumulation_steps = 4
+batch_size = 8 * 4 * 2
+accumulation_steps = 2
 total_epoch = 500

 steps_per_epoch = len(train_dataset) // batch_size
@@ -28,9 +29,11 @@ warmup_steps = int(0.05 * total_steps)
 grad_norm = 1.0


-model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vitl16-pretrain-sat493m"
-).to(device)
+model = (
+    UTransformer.from_pretrained_backbone("facebook/dinov3-vitl16-pretrain-sat493m")
+    .to(device)
+    .bfloat16()
+)
 rf = RF(model, "icfm", "lpips_mse")
 optimizer = optim.AdamW(model.parameters(), lr=3e-4)

@@ -81,9 +84,10 @@ for epoch in range(start_epoch, total_epoch):
         cloud = batch["cloud"].to(device)
         gt = batch["gt"].to(device)

-        loss, blsct, loss_list = rf.forward(gt, cloud)
-        loss = loss / accumulation_steps
-        loss.backward()
+        with autocast(dtype=torch.bfloat16):
+            loss, blsct, loss_list = rf.forward(gt, cloud)
+            loss = loss / accumulation_steps
+            loss.backward()

         if (i // batch_size + 1) % accumulation_steps == 0:
             # total_norm = torch.nn.utils.clip_grad_norm_(
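Note (not part of the commit, just context for the hunks above): the forward and backward passes now run under bfloat16 autocast on a model that is itself cast to bfloat16, with the per-step batch doubled and the accumulation count halved so the effective batch size stays the same. A minimal standalone sketch of that pattern, using hypothetical stand-ins for the real model and dataloader (needs a CUDA device; torch.cuda.amp.autocast is the older spelling of torch.amp.autocast("cuda", ...)):

    import torch
    from torch.cuda.amp import autocast

    model = torch.nn.Linear(16, 16).cuda().bfloat16()           # stand-in for the real model
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    accumulation_steps = 2

    for step, x in enumerate(torch.randn(8, 4, 16)):            # stand-in for the dataloader
        x = x.cuda()
        with autocast(dtype=torch.bfloat16):
            loss = model(x).pow(2).mean() / accumulation_steps  # scale before backward
        loss.backward()                                         # grads accumulate across micro-steps
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()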

@@ -22,4 +22,11 @@ dependencies = [
     "tqdm>=4.67.1",
     "transformers>=4.56.2",
     "wandb[media]>=0.22.0",
+    "flash-attn"
 ]
+
+[tool.uv.extra-build-dependencies]
+flash-attn = [{ requirement = "torch", match-runtime = true }]
+
+[tool.uv.extra-build-variables]
+flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }
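Note (not part of the commit): the flash-attn kernels run only on supported GPUs and only with fp16/bf16 inputs, so code that might also be imported on machines without a working wheel often guards the import and falls back to PyTorch's built-in scaled_dot_product_attention. A sketch of such a guard; the attention helper below is hypothetical and not something this repository defines:

    import torch
    import torch.nn.functional as F

    try:
        from flash_attn import flash_attn_func

        HAS_FLASH_ATTN = True
    except ImportError:
        HAS_FLASH_ATTN = False


    def attention(q, k, v, dropout_p=0.0):
        """q, k, v: (batch, seqlen, nheads, headdim); fp16/bf16 on CUDA for flash-attn."""
        if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
            return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=False)
        # SDPA expects (batch, nheads, seqlen, headdim); transpose in and out.
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p
        )
        return out.transpose(1, 2)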

@@ -5,6 +5,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from flash_attn import flash_attn_func
 from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig


@@ -194,53 +195,54 @@ class DINOv3ViTAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        other: torch.Tensor | None = None,
+        attention_mask: Optional[torch.Tensor] = None,  # wont work rn
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        assert position_embeddings is not None
+        # assert position_embeddings is not None

         batch_size, patches, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        key_states = self.k_proj(hidden_states) if other is None else self.k_proj(other)
+        value_states = (
+            self.v_proj(hidden_states) if other is None else self.v_proj(other)
+        )

         query_states = query_states.view(
-            batch_size, patches, self.num_heads, self.head_dim
+            batch_size, -1, self.num_heads, self.head_dim
         ).transpose(1, 2)
         key_states = key_states.view(
-            batch_size, patches, self.num_heads, self.head_dim
+            batch_size, -1, self.num_heads, self.head_dim
         ).transpose(1, 2)
         value_states = value_states.view(
-            batch_size, patches, self.num_heads, self.head_dim
+            batch_size, -1, self.num_heads, self.head_dim
         ).transpose(1, 2)

-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin
-        )
-
-        attn_weights = (
-            torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
-        )
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
-            query_states.dtype
-        )
-
-        if self.training:
-            attn_weights = F.dropout(
-                attn_weights, p=self.dropout, training=self.training
-            )
-
-        if attention_mask is not None:
-            attn_weights = attn_weights * attention_mask
-
-        attn_output = torch.matmul(attn_weights, value_states)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin
+            )
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        dropout_p = self.dropout if self.training else 0.0
+
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=dropout_p,
+            softmax_scale=None,  # Will use default 1/sqrt(headdim)
+            causal=False,
+        )
+
+        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()  # type: ignore

         attn_output = self.o_proj(attn_output)

-        return attn_output, attn_weights
+        return attn_output


 class DINOv3ViTLayerScale(nn.Module):
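For reference (not in the commit): the projections are still reshaped to (batch, nheads, seqlen, headdim) so the rotary embedding can be applied, then transposed back because flash_attn_func expects (batch, seqlen, nheads, headdim) and returns the same layout, which is why the output can be reshaped straight to (batch, patches, hidden) without a final transpose. flash-attn also requires fp16/bf16 inputs, which is what the .bfloat16() cast in main.py provides. A small consistency check against PyTorch's SDPA, assuming a CUDA device with flash-attn installed:

    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    B, H, S, D = 2, 4, 16, 32  # batch, heads, seqlen, headdim (assumed sizes)
    q, k, v = (
        torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) for _ in range(3)
    )

    out_flash = flash_attn_func(q, k, v, causal=False)    # (B, S, H, D)
    out_sdpa = F.scaled_dot_product_attention(            # expects (B, H, S, D)
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)

    print(out_flash.shape)                  # torch.Size([2, 16, 4, 32])
    print((out_flash - out_sdpa).abs().max())  # small, bf16-level difference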
@@ -350,7 +352,7 @@ class DINOv3ViTLayer(nn.Module):

         residual = hidden_states
         hidden_states = self.norm1(hidden_states)
-        hidden_states, _ = self.attention(
+        hidden_states = self.attention(
             hidden_states,
             attention_mask=attention_mask,
             position_embeddings=position_embeddings,

@@ -8,6 +8,7 @@ from torch import nn
 from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

 from src.model.dino import (
+    DINOv3ViTAttention,
     DINOv3ViTEmbeddings,
     DINOv3ViTLayer,
     DINOv3ViTLayerScale,
@@ -81,12 +82,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         self.is_encoder = is_encoder

         self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.cond = nn.MultiheadAttention(
-            config.hidden_size,
-            config.num_attention_heads,
-            config.drop_path_rate,
-            batch_first=True,
-        )
+        self.cond = DINOv3ViTAttention(config)
         self.layer_scale_cond = DINOv3ViTLayerScale(config)

         # no init zeros!
@@ -114,7 +110,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):

         residual = hidden_states
         hidden_states = self.norm1(hidden_states)
-        hidden_states, _ = self.attention(
+        hidden_states = self.attention(
             hidden_states,
             attention_mask=attention_mask,
             position_embeddings=position_embeddings,
@@ -125,8 +121,9 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         if do_condition:
             residual = hidden_states
             hidden_states = self.norm_cond(hidden_states)
-            hidden_states, _ = self.cond(
-                hidden_states, conditioning_input, conditioning_input
+            hidden_states = self.cond(
+                hidden_states,
+                conditioning_input,
             )
             hidden_states = self.layer_scale_cond(hidden_states)
             hidden_states = self.drop_path(hidden_states) + residual
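For reference (not in the commit): the conditioning block now routes cross-attention through the patched attention module, passing the conditioning tokens as `other` so that queries come from the decoder tokens while keys and values come from the conditioning sequence; the two sequence lengths may differ (hence the `-1` instead of `patches` in the view calls), and no rotary embedding is applied because position_embeddings is left at None. The data flow, as a self-contained sketch with assumed sizes and plain SDPA instead of the repository's module:

    import torch
    import torch.nn.functional as F

    B, H, D = 2, 4, 32                      # batch, heads, head dim (assumed sizes)
    x = torch.randn(B, 196, H * D)          # decoder tokens: the queries
    cond = torch.randn(B, 256, H * D)       # conditioning tokens: keys and values

    q_proj = torch.nn.Linear(H * D, H * D)
    k_proj = torch.nn.Linear(H * D, H * D)
    v_proj = torch.nn.Linear(H * D, H * D)

    q = q_proj(x).view(B, -1, H, D).transpose(1, 2)     # (B, H, 196, D)
    k = k_proj(cond).view(B, -1, H, D).transpose(1, 2)  # (B, H, 256, D)
    v = v_proj(cond).view(B, -1, H, D).transpose(1, 2)  # (B, H, 256, D)

    out = F.scaled_dot_product_attention(q, k, v)       # (B, H, 196, D)
    out = out.transpose(1, 2).reshape(B, 196, H * D)    # back to (B, tokens, hidden)
    print(out.shape)                                    # torch.Size([2, 196, 128])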
@@ -188,9 +185,7 @@ class DinoV3ViTDecoder(nn.Module):
         self.patch_size = config.patch_size

         self.projection = nn.Linear(
-            config.hidden_size,
-            config.num_channels * (self.patch_size**2),
-            bias=True,
+            config.hidden_size, config.num_channels * (self.patch_size**2), bias=True
         )
         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)

@@ -209,11 +204,8 @@ class DinoV3ViTDecoder(nn.Module):
         w_grid = image_size[1] // p

         assert x.shape[1] == h_grid * w_grid
-
         x = self.projection(x)
-
         x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
-
         x = self.pixel_shuffle(x)

         return x
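For reference (not in the commit): the decoder projects each of the h_grid x w_grid patch tokens to num_channels * patch_size**2 values, reshapes them into a (B, C*p**2, h_grid, w_grid) feature map, and nn.PixelShuffle(p) rearranges that into a (B, C, h_grid*p, w_grid*p) image. A standalone shape check with assumed sizes:

    import torch
    import torch.nn as nn

    B, hidden, C, p = 2, 1024, 3, 16      # batch, token width, image channels, patch size
    h_grid = w_grid = 14                  # 224 // 16 patches per side

    projection = nn.Linear(hidden, C * p**2, bias=True)
    pixel_shuffle = nn.PixelShuffle(p)

    x = torch.randn(B, h_grid * w_grid, hidden)               # patch tokens
    x = projection(x)                                         # (B, 196, 768)
    x = x.reshape(B, h_grid, w_grid, -1).permute(0, 3, 1, 2)  # (B, 768, 14, 14)
    x = pixel_shuffle(x)                                      # (B, 3, 224, 224)
    print(x.shape)                                            # torch.Size([2, 3, 224, 224])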

uv.lock (12, generated)
@@ -297,6 +297,7 @@ source = { virtual = "." }
 dependencies = [
     { name = "datasets" },
     { name = "einops" },
+    { name = "flash-attn" },
     { name = "lpips" },
     { name = "pyright" },
     { name = "python-lsp-server" },
@@ -318,6 +319,7 @@ dependencies = [
 requires-dist = [
     { name = "datasets", specifier = ">=4.1.1" },
     { name = "einops", specifier = ">=0.8.1" },
+    { name = "flash-attn" },
     { name = "lpips", specifier = ">=0.1.4" },
     { name = "pyright", specifier = ">=1.1.405" },
     { name = "python-lsp-server", specifier = ">=1.13.1" },
@@ -531,6 +533,16 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/42/14/42b2651a2f46b022ccd948bca9f2d5af0fd8929c4eec235b8d6d844fbe67/filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d", size = 15988, upload-time = "2025-08-14T16:56:01.633Z" },
 ]

+[[package]]
+name = "flash-attn"
+version = "2.8.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "einops" },
+    { name = "torch" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3b/b2/8d76c41ad7974ee264754709c22963447f7f8134613fd9ce80984ed0dab7/flash_attn-2.8.3.tar.gz", hash = "sha256:1e71dd64a9e0280e0447b8a0c2541bad4bf6ac65bdeaa2f90e51a9e57de0370d", size = 8447812, upload-time = "2025-08-15T08:28:12.911Z" }
+
 [[package]]
 name = "fonttools"
 version = "4.60.1"