# cloud-removal/src/model/dino.py
import math
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_func
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
def get_patches_center_coordinates(
num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
coords_h = coords_h / num_patches_h
coords_w = coords_w / num_patches_w
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = coords.flatten(0, 1)
coords = 2.0 * coords - 1.0
return coords
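# Worked example for get_patches_center_coordinates: with a 2x2 patch grid the
# per-axis centers are [0.5, 1.5] / 2 = [0.25, 0.75], and the 2*x - 1 step maps
# them to [-0.5, 0.5], i.e. patch centers normalized to the open square (-1, 1)^2.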
def augment_patches_center_coordinates(
coords: torch.Tensor,
shift: Optional[float] = None,
jitter: Optional[float] = None,
rescale: Optional[float] = None,
) -> torch.Tensor:
if shift is not None:
shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
shift_hw = shift_hw.uniform_(-shift, shift)
coords = coords + shift_hw
if jitter is not None:
jitter_range = np.log(jitter)
jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
coords = coords * jitter_hw
if rescale is not None:
rescale_range = np.log(rescale)
rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
coords = coords * rescale_hw
return coords
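# augment_patches_center_coordinates applies train-time coordinate augmentation:
# `shift` adds an independent offset in [-shift, shift] per axis, `jitter` scales
# each axis by an independent log-uniform factor in [1/jitter, jitter], and
# `rescale` scales both axes by one shared log-uniform factor in [1/rescale, rescale].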
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate half the channels: (x1, x2) -> (-x2, x1) along the last dim.
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
num_tokens = q.shape[-2]
num_patches = sin.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
return q, k
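# apply_rotary_pos_emb rotates only the patch tokens: the leading
# num_tokens - num_patches prefix tokens (CLS + register tokens) are split off,
# left untouched, and concatenated back in front of the rotated patch tokens.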
def drop_path(
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=input.dtype, device=input.device
)
random_tensor.floor_()
output = input.div(keep_prob) * random_tensor
return output
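# drop_path implements stochastic depth per sample: during training, each
# sample's residual branch is zeroed with probability drop_prob, and surviving
# samples are scaled by 1 / keep_prob so the expected value is unchanged.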
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.register_tokens = nn.Parameter(
torch.empty(1, config.num_register_tokens, config.hidden_size)
)
self.patch_embeddings = nn.Conv2d(
config.num_channels,
config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
)
def forward(
self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None
) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embeddings.weight.dtype
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
if bool_masked_pos is not None:
mask_token = self.mask_token.to(patch_embeddings.dtype)
patch_embeddings = torch.where(
bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings
)
cls_token = self.cls_token.expand(batch_size, -1, -1)
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
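# DINOv3ViTEmbeddings.forward returns [CLS, register tokens..., patch tokens...]
# with sequence length 1 + num_register_tokens + (H // patch_size) * (W // patch_size);
# patches flagged by bool_masked_pos are replaced by the learned mask_token
# before the prefix tokens are prepended.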
class DINOv3ViTRopePositionEmbedding(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.base = config.rope_theta
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_patches_h = config.image_size // config.patch_size
self.num_patches_w = config.image_size // config.patch_size
        # RoPE inverse frequencies: head_dim // 4 values per spatial axis.
        inv_freq = 1.0 / (
            self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32)
        )
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
_, _, height, width = pixel_values.shape
num_patches_h = height // self.config.patch_size
num_patches_w = width // self.config.patch_size
device = pixel_values.device
device_type = (
device.type
if isinstance(device.type, str) and device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
patch_coords = get_patches_center_coordinates(
num_patches_h, num_patches_w, dtype=torch.float32, device=device
)
if self.training:
patch_coords = augment_patches_center_coordinates(
patch_coords,
shift=self.config.pos_embed_shift,
jitter=self.config.pos_embed_jitter,
rescale=self.config.pos_embed_rescale,
)
angles = (
2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] # type: ignore
)
angles = angles.flatten(1, 2)
angles = angles.tile(2)
cos = torch.cos(angles)
sin = torch.sin(angles)
dtype = pixel_values.dtype
return cos.to(dtype=dtype), sin.to(dtype=dtype)
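# DINOv3ViTRopePositionEmbedding.forward returns cos/sin tables of shape
# (num_patches, head_dim): angles of shape (num_patches, 2, head_dim // 4) are
# flattened over the last two dims and tiled along the final dim. They are
# computed in float32 with autocast disabled, then cast back to the pixel dtype;
# coordinate augmentation is applied only in training mode.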
class DINOv3ViTAttention(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scaling = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)
def forward(
self,
hidden_states: torch.Tensor,
        other: Optional[torch.Tensor] = None,
        # NOTE: attention_mask is accepted for interface parity but is currently
        # ignored; flash_attn_func below does not take a padding mask.
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states) if other is None else self.k_proj(other)
value_states = (
self.v_proj(hidden_states) if other is None else self.v_proj(other)
)
query_states = query_states.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_p = self.dropout if self.training else 0.0
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_p,
softmax_scale=None,
causal=False,
)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() # type: ignore
attn_output = self.o_proj(attn_output)
return attn_output
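# DINOv3ViTAttention notes: flash_attn_func consumes (batch, seq_len, num_heads,
# head_dim) tensors, hence the transposes back after RoPE; softmax_scale=None
# defaults to 1 / sqrt(head_dim), matching self.scaling. The flash-attn kernels
# expect fp16/bf16 CUDA inputs. Passing `other` turns the block into
# cross-attention: queries come from hidden_states, keys/values from `other`.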
class DINOv3ViTLayerScale(nn.Module):
def __init__(self, config):
super().__init__()
self.lambda1 = nn.Parameter(
config.layerscale_value * torch.ones(config.hidden_size)
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
return hidden_state * self.lambda1
class DINOv3ViTDropPath(nn.Module):
def __init__(self, drop_prob: float):
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
class DINOv3ViTMLP(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=config.mlp_bias
)
        # Map the configured activation name to its functional op; unknown names
        # fall back to GELU.
        self.act_fn = {"gelu": F.gelu, "relu": F.relu, "silu": F.silu}.get(
            config.hidden_act, F.gelu
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.up_proj(x)))
class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=config.mlp_bias
)
        # Same activation lookup as DINOv3ViTMLP, defaulting to GELU.
        self.act_fn = {"gelu": F.gelu, "relu": F.relu, "silu": F.silu}.get(
            config.hidden_act, F.gelu
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
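# DINOv3ViTGatedMLP computes down_proj(act(gate_proj(x)) * up_proj(x)); with
# hidden_act="silu" this is the SwiGLU-style gated feed-forward.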
class DINOv3ViTLayer(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = DINOv3ViTAttention(config)
self.layer_scale1 = DINOv3ViTLayerScale(config)
self.drop_path = (
DINOv3ViTDropPath(config.drop_path_rate)
if config.drop_path_rate > 0.0
else nn.Identity()
)
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if config.use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(config)
else:
self.mlp = DINOv3ViTMLP(config)
self.layer_scale2 = DINOv3ViTLayerScale(config)
def forward(
self,
hidden_states: torch.Tensor,
*,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
assert position_embeddings is not None
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
return hidden_states
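# DINOv3ViTLayer is a pre-norm transformer block: LayerNorm -> attention ->
# LayerScale -> DropPath + residual, then LayerNorm -> MLP -> LayerScale ->
# DropPath + residual.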
class DINOv3ViTModel(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.embeddings = DINOv3ViTEmbeddings(config)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
self.layer = nn.ModuleList(
[DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)]
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._init_weights()
def _init_weights(self):
for module in self.modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DINOv3ViTEmbeddings):
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.cls_token.dtype)
if module.config.num_register_tokens > 0:
module.register_tokens.data = nn.init.trunc_normal_(
module.register_tokens.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.register_tokens.dtype)
module.mask_token.data.zero_()
elif isinstance(module, DINOv3ViTLayerScale):
module.lambda1.data.fill_(self.config.layerscale_value)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
latents = []
for i, layer_module in enumerate(self.layer):
            # head_mask is threaded through as attention_mask for signature
            # parity; the flash-attention path in DINOv3ViTAttention ignores it.
            layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
latents.append(hidden_states)
sequence_output = self.norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return {
"last_hidden_state": sequence_output,
"latents": latents,
"pooler_output": pooled_output,
}
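

# Minimal smoke-test sketch, assuming the default DINOv3ViTConfig() constructor
# provides the fields used above and that a CUDA device with flash-attn is
# available (flash_attn_func requires fp16/bf16 CUDA tensors).
if __name__ == "__main__":
    config = DINOv3ViTConfig()
    model = DINOv3ViTModel(config).cuda().to(torch.bfloat16).eval()
    pixel_values = torch.randn(
        2,
        config.num_channels,
        config.image_size,
        config.image_size,
        device="cuda",
        dtype=torch.bfloat16,
    )
    with torch.no_grad():
        outputs = model(pixel_values)
    # Expected: (2, 1 + num_register_tokens + num_patches, hidden_size)
    print(outputs["last_hidden_state"].shape, len(outputs["latents"]))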