flash attn

This commit is contained in:
neulus
2025-10-15 00:18:34 +09:00
parent 3b03453e5d
commit e51017897d
5 changed files with 69 additions and 52 deletions

View File

@@ -8,6 +8,7 @@ from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
from src.model.dino import (
DINOv3ViTAttention,
DINOv3ViTEmbeddings,
DINOv3ViTLayer,
DINOv3ViTLayerScale,
@@ -81,12 +82,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):
self.is_encoder = is_encoder
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cond = nn.MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
config.drop_path_rate,
batch_first=True,
)
self.cond = DINOv3ViTAttention(config)
self.layer_scale_cond = DINOv3ViTLayerScale(config)
# no init zeros!
@@ -114,7 +110,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states, _ = self.attention(
hidden_states = self.attention(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
@@ -125,8 +121,9 @@ class DinoConditionedLayer(DINOv3ViTLayer):
if do_condition:
residual = hidden_states
hidden_states = self.norm_cond(hidden_states)
hidden_states, _ = self.cond(
hidden_states, conditioning_input, conditioning_input
hidden_states = self.cond(
hidden_states,
conditioning_input,
)
hidden_states = self.layer_scale_cond(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
@@ -188,9 +185,7 @@ class DinoV3ViTDecoder(nn.Module):
self.patch_size = config.patch_size
self.projection = nn.Linear(
config.hidden_size,
config.num_channels * (self.patch_size**2),
bias=True,
config.hidden_size, config.num_channels * (self.patch_size**2), bias=True
)
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
@@ -209,11 +204,8 @@ class DinoV3ViTDecoder(nn.Module):
w_grid = image_size[1] // p
assert x.shape[1] == h_grid * w_grid
x = self.projection(x)
x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
x = self.pixel_shuffle(x)
return x