diff --git a/src/model/dino.py b/src/model/dino.py
index 18504b1..bff3f0d 100644
--- a/src/model/dino.py
+++ b/src/model/dino.py
@@ -214,7 +214,7 @@ class DINOv3ViTDropPath(nn.Module):
 
 
 class DINOv3ViTMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: DINOv3ViTConfig):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -236,7 +236,7 @@ class DINOv3ViTMLP(nn.Module):
 
 
 class DINOv3ViTGatedMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: DINOv3ViTConfig):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -274,8 +274,8 @@ class DINOv3ViTLayer(nn.Module):
         self.mlp = DINOv3ViTMLP(config)
         self.layer_scale2 = DINOv3ViTLayerScale(config)
 
-    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
-                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, *, attention_mask: Optional[torch.Tensor] = None,
+                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
         assert position_embeddings is not None
 
         residual = hidden_states
diff --git a/src/model/utransformer.py b/src/model/utransformer.py
new file mode 100644
index 0000000..7feee1b
--- /dev/null
+++ b/src/model/utransformer.py
@@ -0,0 +1,205 @@
+import math
+from typing import Optional
+
+import torch
+from torch import nn
+from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
+
+from src.model.dino import DINOv3ViTEmbeddings, DINOv3ViTLayer, DINOv3ViTLayerScale, DINOv3ViTRopePositionEmbedding
+
+
+class TimestepEmbedder(nn.Module):
+    """Embeds scalar diffusion timesteps with a sinusoidal encoding followed by an MLP (DiT-style)."""
+
+    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half) / half
+        ).to(t.device)
+        args = t[:, None] * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
+            dtype=next(self.parameters()).dtype
+        )
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class LabelEmbedder(nn.Module):
+    """Embeds class labels, with optional label dropout for classifier-free guidance (DiT-style)."""
+
+    def __init__(self, num_classes, hidden_size, dropout_prob):
+        super().__init__()
+        # Reserve one extra embedding for the "null" label used when a label is dropped.
+        use_cfg_embedding = int(dropout_prob > 0)
+        self.embedding_table = nn.Embedding(
+            num_classes + use_cfg_embedding, hidden_size
+        )
+        self.num_classes = num_classes
+        self.dropout_prob = dropout_prob
+
+    def token_drop(self, labels, force_drop_ids=None):
+        if force_drop_ids is None:
+            # Sample the drop mask directly on the labels' device (no hard CUDA dependency).
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+        else:
+            drop_ids = force_drop_ids == 1
+        labels = torch.where(drop_ids, self.num_classes, labels)
+        return labels
+
+    def forward(self, labels, train, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (train and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        embeddings = self.embedding_table(labels)
+        return embeddings
+
+
+class DinoConditionedLayer(DINOv3ViTLayer):
+    """A DINOv3 ViT block with an extra cross-attention branch over the conditioning tokens."""
+
+    def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
+        super().__init__(config)
+
+        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.cond = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, dropout=config.drop_path_rate, batch_first=True)
+        self.layer_scale_cond = DINOv3ViTLayerScale(config)
+
+        # Do not zero-init the conditioning branch in general; only frozen encoder layers start with a
+        # zeroed layer scale so that they initially behave like the pretrained DINOv3 block.
+        if is_encoder:
+            nn.init.constant_(self.layer_scale_cond.lambda1, 0)
+            self.norm1.requires_grad_(False)
+            self.norm2.requires_grad_(False)
+            self.attention.requires_grad_(False)
+            self.mlp.requires_grad_(False)
+            self.layer_scale1.requires_grad_(False)
+            self.layer_scale2.requires_grad_(False)
+
+    def forward(self, hidden_states: torch.Tensor, *, conditioning_input: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
+                position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
+        assert position_embeddings is not None
+        assert conditioning_input is not None
+
+        # Self-attention block (inherited from DINOv3ViTLayer).
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        hidden_states, _ = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = self.layer_scale1(hidden_states)
+        hidden_states = self.drop_path(hidden_states) + residual
+
+        # Cross-attention block: queries are the image tokens, keys/values are the conditioning tokens.
+        residual = hidden_states
+        hidden_states = self.norm_cond(hidden_states)
+        hidden_states, _ = self.cond(hidden_states, conditioning_input, conditioning_input)
+        hidden_states = self.layer_scale_cond(hidden_states)
+        hidden_states = self.drop_path(hidden_states) + residual
+
+        # MLP block (inherited from DINOv3ViTLayer).
+        residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.layer_scale2(hidden_states)
+        hidden_states = self.drop_path(hidden_states) + residual
+
+        return hidden_states
+
+
+class DinoV3ViTDecoder(nn.Module):
+    """Projects patch tokens back to pixel space (inverse of the patch embedding)."""
+
+    def __init__(self, config: DINOv3ViTConfig):
+        super().__init__()
+        self.config = config
+        self.num_channels_out = config.num_channels
+
+        self.projection = nn.Linear(
+            config.hidden_size,
+            self.num_channels_out * config.patch_size * config.patch_size,
+            bias=True,
+        )
+
+    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+        batch_size = x.shape[0]
+
+        # Drop the CLS token and register tokens; keep only the patch tokens.
+        num_special_tokens = 1 + self.config.num_register_tokens
+        patch_tokens = x[:, num_special_tokens:, :]
+
+        projected_tokens = self.projection(patch_tokens)
+
+        p = self.config.patch_size
+        c = self.num_channels_out
+        h_grid = image_size[0] // p
+        w_grid = image_size[1] // p
+
+        assert patch_tokens.shape[1] == h_grid * w_grid, "Number of patches does not match image size."
+
+        # Unpatchify: (B, h*w, p*p*C) -> (B, C, h*p, w*p).
+        x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
+
+        x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
+
+        reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+
+        return reconstructed_image
+
+
+class UTransformer(nn.Module):
+    """U-Net style conditional ViT: DINOv3 encoder layers with frozen pretrained sub-blocks, mirrored decoder layers, and skip connections between them."""
+
+    def __init__(self, config: DINOv3ViTConfig, num_classes: int):
+        super().__init__()
+        self.config = config
+
+        self.embeddings = DINOv3ViTEmbeddings(config)
+        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
+        self.t_embedder = TimestepEmbedder(config.hidden_size)
+        self.y_embedder = LabelEmbedder(num_classes, config.hidden_size, config.drop_path_rate)
+
+        self.encoder_layers = nn.ModuleList([DinoConditionedLayer(config, True) for _ in range(config.num_hidden_layers)])
+        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.decoder_layers = nn.ModuleList([DinoConditionedLayer(config, False) for _ in range(config.num_hidden_layers)])
+        self.decoder = DinoV3ViTDecoder(config)
+
+        # Freeze the pretrained DINOv3 parts (encoder layers freeze their own sub-blocks via is_encoder=True).
+        self.embeddings.requires_grad_(False)
+        self.rope_embeddings.requires_grad_(False)
+        self.encoder_norm.requires_grad_(False)
+
+    def forward(self, pixel_values: torch.Tensor, time: torch.Tensor, cond: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
+                head_mask: Optional[torch.Tensor] = None):
+        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+        position_embeddings = self.rope_embeddings(pixel_values)
+
+        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        t = self.t_embedder(time).unsqueeze(1)
+        y = self.y_embedder(cond, self.training).unsqueeze(1)
+        conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+
+        # Encoder pass: store each layer's input for U-Net style skip connections.
+        residual = []
+        for i, layer_module in enumerate(self.encoder_layers):
+            residual.append(x)
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            x = layer_module(
+                x,
+                conditioning_input=conditioning_input,
+                attention_mask=layer_head_mask,
+                position_embeddings=position_embeddings,
+            )
+        x = self.encoder_norm(x)
+
+        # Decoder pass: add the stored encoder activations back in reverse order.
+        for i, layer_module in enumerate(self.decoder_layers):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            x = layer_module(
+                x,
+                conditioning_input=conditioning_input,
+                attention_mask=layer_head_mask,
+                position_embeddings=position_embeddings,
+            )
+            x = x + residual.pop()
+
+        return self.decoder(x, image_size=pixel_values.shape[-2:])
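
For context, a minimal shape check of the new model could look like the sketch below. It is not part of the diff: the tiny config values, the 64x64 input, and the class count are placeholder assumptions, the DINOv3ViTConfig keyword names are assumed to match the standard ViT fields referenced in the code above, and a real run would first load pretrained DINOv3 weights into the frozen encoder modules.

import torch
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.utransformer import UTransformer

# Placeholder config: small hidden size and few layers, just to exercise the forward pass.
config = DINOv3ViTConfig(
    hidden_size=192,
    num_hidden_layers=2,
    num_attention_heads=3,
    patch_size=16,
    num_channels=3,
)
model = UTransformer(config, num_classes=10).eval()

pixel_values = torch.randn(2, 3, 64, 64)   # (B, C, H, W), H and W divisible by patch_size
time = torch.randint(0, 1000, (2,))        # one diffusion timestep per sample
cond = torch.randint(0, 10, (2,))          # one class label per sample

with torch.no_grad():
    out = model(pixel_values, time, cond)

print(out.shape)  # expected: torch.Size([2, 3, 64, 64])

Because DinoV3ViTDecoder unpatchifies the final tokens, the output should have the same (B, C, H, W) shape as the input image.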