Use flash attention (flash_attn_func) in DINOv3ViTAttention
This commit is contained in:
@@ -5,6 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from flash_attn import flash_attn_func
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
|
||||
@@ -194,53 +195,54 @@ class DINOv3ViTAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
other: torch.Tensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None, # won't work right now
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert position_embeddings is not None
|
||||
# assert position_embeddings is not None
|
||||
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states) if other is None else self.k_proj(other)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states) if other is None else self.v_proj(other)
|
||||
)
|
||||
|
||||
query_states = query_states.view(
|
||||
batch_size, patches, self.num_heads, self.head_dim
|
||||
batch_size, -1, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
batch_size, patches, self.num_heads, self.head_dim
|
||||
batch_size, -1, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
batch_size, patches, self.num_heads, self.head_dim
|
||||
batch_size, -1, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
attn_weights = (
|
||||
torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
|
||||
)
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
query_states.dtype
|
||||
)
|
||||
|
||||
if self.training:
|
||||
attn_weights = F.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
dropout_p = self.dropout if self.training else 0.0
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = flash_attn_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=None, # Will use default 1/sqrt(headdim)
|
||||
causal=False,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() # type: ignore
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
return attn_output
|
||||
|
||||
|
||||
class DINOv3ViTLayerScale(nn.Module):
|
||||
@@ -350,7 +352,7 @@ class DINOv3ViTLayer(nn.Module):
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states, _ = self.attention(
|
||||
hidden_states = self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
|
||||
@@ -8,6 +8,7 @@ from torch import nn
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
from src.model.dino import (
|
||||
DINOv3ViTAttention,
|
||||
DINOv3ViTEmbeddings,
|
||||
DINOv3ViTLayer,
|
||||
DINOv3ViTLayerScale,
|
||||
@@ -81,12 +82,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):
|
||||
self.is_encoder = is_encoder
|
||||
|
||||
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.cond = nn.MultiheadAttention(
|
||||
config.hidden_size,
|
||||
config.num_attention_heads,
|
||||
config.drop_path_rate,
|
||||
batch_first=True,
|
||||
)
|
||||
self.cond = DINOv3ViTAttention(config)
|
||||
self.layer_scale_cond = DINOv3ViTLayerScale(config)
|
||||
|
||||
# no init zeros!
|
||||
@@ -114,7 +110,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states, _ = self.attention(
|
||||
hidden_states = self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
@@ -125,8 +121,9 @@ class DinoConditionedLayer(DINOv3ViTLayer):
|
||||
if do_condition:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_cond(hidden_states)
|
||||
hidden_states, _ = self.cond(
|
||||
hidden_states, conditioning_input, conditioning_input
|
||||
hidden_states = self.cond(
|
||||
hidden_states,
|
||||
conditioning_input,
|
||||
)
|
||||
hidden_states = self.layer_scale_cond(hidden_states)
|
||||
hidden_states = self.drop_path(hidden_states) + residual
|
||||
@@ -188,9 +185,7 @@ class DinoV3ViTDecoder(nn.Module):
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.projection = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.num_channels * (self.patch_size**2),
|
||||
bias=True,
|
||||
config.hidden_size, config.num_channels * (self.patch_size**2), bias=True
|
||||
)
|
||||
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
|
||||
|
||||
@@ -209,11 +204,8 @@ class DinoV3ViTDecoder(nn.Module):
|
||||
w_grid = image_size[1] // p
|
||||
|
||||
assert x.shape[1] == h_grid * w_grid
|
||||
|
||||
x = self.projection(x)
|
||||
|
||||
x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
|
||||
|
||||
x = self.pixel_shuffle(x)
|
||||
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user