import math
import warnings
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
except ImportError:  # pragma: no cover - only taken when transformers is not installed
    # The config class is used purely in type annotations below; every module
    # only reads plain attributes from the config object, so any object that
    # exposes those attributes can be used when transformers is unavailable.
    DINOv3ViTConfig = Any  # type: ignore[misc,assignment]


def get_patches_center_coordinates(
    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return the normalised centre coordinate of every patch in a grid.

    Centres sit at ``(i + 0.5) / num_patches`` along each axis and are then
    mapped from ``(0, 1)`` to ``(-1, 1)``.

    Args:
        num_patches_h: Number of patch rows.
        num_patches_w: Number of patch columns.
        dtype: dtype of the returned tensor.
        device: Device of the returned tensor.

    Returns:
        Tensor of shape ``(num_patches_h * num_patches_w, 2)`` holding
        ``(row, col)`` coordinates, flattened row-major.
    """
    coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) / num_patches_h
    coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) / num_patches_w
    # (H, W, 2) grid -> (H * W, 2), row-major because of indexing="ij".
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1).flatten(0, 1)
    # Map (0, 1) -> (-1, 1).
    return 2.0 * coords - 1.0


def augment_patches_center_coordinates(
    coords: torch.Tensor,
    shift: Optional[float] = None,
    jitter: Optional[float] = None,
    rescale: Optional[float] = None,
) -> torch.Tensor:
    """Randomly perturb patch coordinates (training-time RoPE augmentation).

    Each augmentation is skipped when its argument is ``None``.

    Args:
        coords: Coordinates of shape ``(num_patches, 2)``.
        shift: Adds one offset per axis drawn from ``U(-shift, shift)``.
        jitter: Multiplies each axis by ``exp(U(-log(jitter), log(jitter)))``
            (independent factor per axis).
        rescale: Multiplies both axes by a single factor
            ``exp(U(-log(rescale), log(rescale)))``.

    Returns:
        The augmented coordinates (a new tensor; ``coords`` is not modified).
    """
    if shift is not None:
        shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
        shift_hw = shift_hw.uniform_(-shift, shift)
        coords = coords + shift_hw

    if jitter is not None:
        jitter_range = np.log(jitter)
        jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
        jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
        coords = coords * jitter_hw

    if rescale is not None:
        rescale_range = np.log(rescale)
        rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
        rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
        coords = coords * rescale_hw

    return coords


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension in half and return ``cat(-second, first)``."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to the patch tokens of ``q`` and ``k``.

    The last ``sin.shape[-2]`` tokens along dim ``-2`` are treated as patch
    tokens and rotated. Any leading tokens (here the CLS and register tokens)
    are passed through unchanged.

    Args:
        q: Query tensor; tokens on dim ``-2``, head features on dim ``-1``.
        k: Key tensor with the same layout as ``q``.
        cos: Cosines of shape ``(num_patches, head_dim)``.
        sin: Sines of shape ``(num_patches, head_dim)``.

    Returns:
        The rotated ``(q, k)`` pair, same shapes as the inputs.
    """
    num_tokens = q.shape[-2]
    num_patches = sin.shape[-2]
    num_prefix_tokens = num_tokens - num_patches

    q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
    k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)

    q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
    k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)

    q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
    k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
    return q, k


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Stochastic depth: zero whole samples with probability ``drop_prob``.

    Surviving samples are rescaled by ``1 / (1 - drop_prob)``. Returns
    ``input`` itself when ``drop_prob == 0`` or ``training`` is False.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample (dim 0), broadcast over remaining dims.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()
    return input.div(keep_prob) * random_tensor


def _get_activation(act: Any) -> Callable[[torch.Tensor], torch.Tensor]:
    """Resolve ``config.hidden_act`` to an activation function.

    Accepts ``"gelu"``, ``"relu"``, ``"silu"`` or a callable. Any other value
    falls back to ``F.gelu`` (the historical behaviour of this module) and now
    emits a warning, so a misspelt activation is no longer silently ignored.
    """
    if callable(act):
        return act
    if act == "gelu":
        return F.gelu
    if act == "relu":
        return F.relu
    if act == "silu":
        return F.silu
    warnings.warn(f"Unsupported hidden_act {act!r}; falling back to 'gelu'.", stacklevel=3)
    return F.gelu


class DINOv3ViTEmbeddings(nn.Module):
    """Patch embedding followed by prepending CLS and register tokens."""

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # zeros (not torch.empty) so the module never exposes uninitialised
        # memory when used outside DINOv3ViTModel, whose _init_weights
        # overwrites this parameter anyway.
        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
        self.patch_embeddings = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
        )

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``pixel_values`` into a token sequence.

        Args:
            pixel_values: Images of shape ``(batch, channels, height, width)``.
            bool_masked_pos: Optional mask of shape ``(batch, num_patches)``;
                truthy positions are replaced by the learned mask token.
                Integer 0/1 masks are accepted as well as bool masks.

        Returns:
            Tensor of shape
            ``(batch, 1 + num_register_tokens + num_patches, hidden_size)``
            ordered as ``[CLS, registers..., patches...]``.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embeddings.weight.dtype

        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, num_patches, hidden)
        patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
        patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)

        if bool_masked_pos is not None:
            # torch.where requires a bool condition.
            mask = bool_masked_pos.to(torch.bool).unsqueeze(-1)
            mask_token = self.mask_token.to(patch_embeddings.dtype)
            patch_embeddings = torch.where(mask, mask_token, patch_embeddings)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        register_tokens = self.register_tokens.expand(batch_size, -1, -1)
        return torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)


class DINOv3ViTRopePositionEmbedding(nn.Module):
    """2-D rotary position embedding computed from patch-centre coordinates."""

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.base = config.rope_theta
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_patches_h = config.image_size // config.patch_size
        self.num_patches_w = config.image_size // config.patch_size

        # One frequency per 4 head features: each frequency is used for both
        # axes (x2) and then as cos/sin halves (x2). This assumes head_dim is
        # divisible by 4; otherwise the shapes will not line up.
        inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)``, each of shape ``(num_patches, head_dim)``.

        The patch grid is derived from the spatial size of ``pixel_values``.
        Coordinates are randomly augmented only while ``self.training`` is
        True. The trigonometry runs in float32 with autocast disabled, and
        the results are then cast to ``pixel_values.dtype``.
        """
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.config.patch_size
        num_patches_w = width // self.config.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=torch.float32, device=device
            )
            if self.training:
                patch_coords = augment_patches_center_coordinates(
                    patch_coords,
                    shift=self.config.pos_embed_shift,
                    jitter=self.config.pos_embed_jitter,
                    rescale=self.config.pos_embed_rescale,
                )

            # (P, 2, F) -> (P, 2F) -> (P, 4F) == (P, head_dim)
            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]  # type: ignore
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        dtype = pixel_values.dtype
        return cos.to(dtype=dtype), sin.to(dtype=dtype)


class DINOv3ViTAttention(nn.Module):
    """Multi-head self-attention with RoPE applied to patch tokens."""

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over ``hidden_states`` of shape ``(batch, tokens, embed_dim)``.

        Args:
            hidden_states: Input token features.
            attention_mask: Optional multiplicative mask applied to the
                post-softmax attention probabilities. DINOv3ViTModel passes
                a per-layer head mask here, already shaped to broadcast over
                ``(batch, heads, tokens, tokens)``.
            position_embeddings: Required ``(cos, sin)`` pair from
                :class:`DINOv3ViTRopePositionEmbedding`.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the input
            shape; ``attn_weights`` has shape ``(batch, heads, tokens, tokens)``.

        Raises:
            ValueError: If ``position_embeddings`` is None.
        """
        if position_embeddings is None:
            raise ValueError("position_embeddings (cos, sin) must be provided")

        batch_size, patches, _ = hidden_states.size()

        def _split_heads(states: torch.Tensor) -> torch.Tensor:
            # (B, N, E) -> (B, heads, N, head_dim)
            return states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = _split_heads(self.q_proj(hidden_states))
        key_states = _split_heads(self.k_proj(hidden_states))
        value_states = _split_heads(self.v_proj(hidden_states))

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
        # Softmax in float32 for stability, then back to the working dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # F.dropout is a no-op when training=False.
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        if attention_mask is not None:
            attn_weights = attn_weights * attention_mask

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


class DINOv3ViTLayerScale(nn.Module):
    """Learnable per-channel scale (initialised to ``config.layerscale_value``)."""

    def __init__(self, config):
        super().__init__()
        self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return hidden_state * self.lambda1


class DINOv3ViTDropPath(nn.Module):
    """Module wrapper around :func:`drop_path` that follows ``self.training``."""

    def __init__(self, drop_prob: float):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)


class DINOv3ViTMLP(nn.Module):
    """Two-layer feed-forward block: ``down(act(up(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = _get_activation(config.hidden_act)

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


class DINOv3ViTGatedMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = _get_activation(config.hidden_act)

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class DINOv3ViTLayer(nn.Module):
    """Pre-norm transformer block with layer scale and stochastic depth."""

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = DINOv3ViTAttention(config)
        self.layer_scale1 = DINOv3ViTLayerScale(config)
        self.drop_path = DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = DINOv3ViTGatedMLP(config) if config.use_gated_mlp else DINOv3ViTMLP(config)
        self.layer_scale2 = DINOv3ViTLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Apply ``x + drop_path(ls1(attn(norm1(x))))`` then ``x + drop_path(ls2(mlp(norm2(x))))``.

        Raises:
            ValueError: If ``position_embeddings`` is None.
        """
        if position_embeddings is None:
            raise ValueError("position_embeddings (cos, sin) must be provided")

        # Attention branch.
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # MLP branch.
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        return hidden_states


class DINOv3ViTModel(nn.Module):
    """DINOv3 Vision Transformer backbone with RoPE, register tokens and final LayerNorm."""

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
        self.layers = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self._init_weights()

    def _init_weights(self):
        """Initialise all submodule parameters in place.

        Linear/Conv2d weights, the CLS token and the register tokens are drawn
        from a truncated normal with ``std=config.initializer_range``. The
        draw is done in float32 and cast back to the parameter dtype. Biases
        and the mask token are zeroed, LayerNorm weights are set to 1, and
        layer-scale weights are set to ``config.layerscale_value``.
        """
        std = self.config.initializer_range

        def _trunc_normal(tensor: torch.Tensor) -> torch.Tensor:
            return nn.init.trunc_normal_(tensor.to(torch.float32), mean=0.0, std=std).to(tensor.dtype)

        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                module.weight.data = _trunc_normal(module.weight.data)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            elif isinstance(module, DINOv3ViTEmbeddings):
                module.cls_token.data = _trunc_normal(module.cls_token.data)
                if module.config.num_register_tokens > 0:
                    module.register_tokens.data = _trunc_normal(module.register_tokens.data)
                module.mask_token.data.zero_()
            elif isinstance(module, DINOv3ViTLayerScale):
                module.lambda1.data.fill_(self.config.layerscale_value)

    def _prepare_head_mask(self, head_mask: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
        """Reshape a user head mask so ``head_mask[i]`` broadcasts over ``(batch, heads, tokens, tokens)``.

        * 1-D ``(num_heads,)``: the same head mask is used for every layer.
        * 2-D ``(num_layers, num_heads)``: one head mask per layer.

        Both forms become ``(num_layers, 1, num_heads, 1, 1)``. Masks with
        more dimensions are assumed to be broadcastable already and are only
        cast. The cast to ``dtype`` stops a float32 mask from promoting
        half-precision attention weights, which would then mismatch
        ``value_states`` in the following matmul.
        """
        if head_mask is None:
            return None
        if head_mask.dim() == 1:
            head_mask = head_mask[None, None, :, None, None].expand(len(self.layers), -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask[:, None, :, None, None]
        return head_mask.to(dtype)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        """Encode images.

        Args:
            pixel_values: Images of shape ``(batch, channels, height, width)``;
                cast to the patch-embedding weight dtype.
            bool_masked_pos: Optional ``(batch, num_patches)`` patch mask, see
                :meth:`DINOv3ViTEmbeddings.forward`.
            head_mask: Optional multiplicative attention-head mask of shape
                ``(num_heads,)``, ``(num_layers, num_heads)`` or any shape whose
                ``[i]`` slice already broadcasts over
                ``(batch, heads, tokens, tokens)``.

        Returns:
            Dict with ``"last_hidden_state"`` (normed token features) and
            ``"pooler_output"`` (the CLS token, index 0, of that output).
        """
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)
        head_mask = self._prepare_head_mask(head_mask, hidden_states.dtype)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]

        return {
            "last_hidden_state": sequence_output,
            "pooler_output": pooled_output,
        }