flash attn

This commit is contained in:
neulus
2025-10-15 00:18:34 +09:00
parent 3b03453e5d
commit e51017897d
5 changed files with 69 additions and 52 deletions

View File

@@ -5,6 +5,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_func
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
@@ -194,53 +195,54 @@ class DINOv3ViTAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
other: torch.Tensor | None = None,
attention_mask: Optional[torch.Tensor] = None, # wont work rn
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert position_embeddings is not None
# assert position_embeddings is not None
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = self.k_proj(hidden_states) if other is None else self.k_proj(other)
value_states = (
self.v_proj(hidden_states) if other is None else self.v_proj(other)
)
query_states = query_states.view(
batch_size, patches, self.num_heads, self.head_dim
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
batch_size, patches, self.num_heads, self.head_dim
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
batch_size, patches, self.num_heads, self.head_dim
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
attn_weights = (
torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
if self.training:
attn_weights = F.dropout(
attn_weights, p=self.dropout, training=self.training
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if attention_mask is not None:
attn_weights = attn_weights * attention_mask
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_p = self.dropout if self.training else 0.0
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_p,
softmax_scale=None, # Will use default 1/sqrt(headdim)
causal=False,
)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() # type: ignore
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
return attn_output
class DINOv3ViTLayerScale(nn.Module):
@@ -350,7 +352,7 @@ class DINOv3ViTLayer(nn.Module):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states, _ = self.attention(
hidden_states = self.attention(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,