import math
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.dino import (
    DINOv3ViTEmbeddings,
    DINOv3ViTLayer,
    DINOv3ViTLayerScale,
    DINOv3ViTRopePositionEmbedding,
)


class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps into vector representations."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        # Sinusoidal timestep embedding as in DiT.
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with a zero column when dim is odd.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    """Embeds class labels, with label dropout for classifier-free guidance."""

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        # Reserve one extra embedding row for the "null" (dropped) label.
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        if force_drop_ids is None:
            # Sample on the labels' device instead of hardcoding .cuda().
            drop_ids = (
                torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            )
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


class DinoConditionedLayer(DINOv3ViTLayer):
    """DINOv3 ViT block extended with a cross-attention conditioning branch."""

    def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
        super().__init__(config)
        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.drop_path_rate,
            batch_first=True,
        )
        self.layer_scale_cond = DINOv3ViTLayerScale(config)  # no init zeros!
        if is_encoder:
            # Zero-init the conditioning branch so encoder layers start out
            # identical to the pretrained backbone, and freeze the pretrained
            # sublayers.
            nn.init.constant_(self.layer_scale_cond.lambda1, 0)
            self.norm1.requires_grad_(False)
            self.norm2.requires_grad_(False)
            self.attention.requires_grad_(False)
            self.mlp.requires_grad_(False)
            self.layer_scale1.requires_grad_(False)
            self.layer_scale2.requires_grad_(False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert position_embeddings is not None
        assert conditioning_input is not None

        # Self-attention sublayer (pretrained).
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # Cross-attention over the conditioning tokens (new branch).
        residual = hidden_states
        hidden_states = self.norm_cond(hidden_states)
        hidden_states, _ = self.cond(
            hidden_states, conditioning_input, conditioning_input
        )
        hidden_states = self.layer_scale_cond(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # MLP sublayer (pretrained).
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual
        return hidden_states


# Original linear per-patch decoder, kept for reference:
# class DinoV3ViTDecoder(nn.Module):
#     def __init__(self, config: DINOv3ViTConfig):
#         super().__init__()
#         self.config = config
#         self.num_channels_out = config.num_channels
#         self.projection = nn.Linear(
#             config.hidden_size,
#             self.num_channels_out * config.patch_size * config.patch_size,
#             bias=True,
#         )
#
#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
#         batch_size = x.shape[0]
#         num_special_tokens = 1 + self.config.num_register_tokens
#         patch_tokens = x[:, num_special_tokens:, :]
#         projected_tokens = self.projection(patch_tokens)
#         p = self.config.patch_size
#         c = self.num_channels_out
#         h_grid = image_size[0] // p
#         w_grid = image_size[1] // p
#         assert patch_tokens.shape[1] == h_grid * w_grid, (
#             "Number of patches does not match image size."
#         )
#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
#         return reconstructed_image


# Let's try a conv decoder instead.
class DinoV3ViTDecoder(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.num_channels_out = config.num_channels
        hidden_dim = config.hidden_size
        patch_size = config.patch_size
        self.projection = nn.Linear(hidden_dim, hidden_dim)
        # Total upsampling must recover one patch per token:
        # 2 * final_upsample == patch_size.
        if patch_size == 14:
            final_upsample = 7
        elif patch_size == 16:
            final_upsample = 8
        elif patch_size == 8:
            final_upsample = 4
        else:
            raise ValueError(f"unsupported patch size: {patch_size}")
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(
                scale_factor=final_upsample, mode="bilinear", align_corners=False
            ),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
        )

    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        batch_size = x.shape[0]
        # Drop the CLS token and register tokens; keep only patch tokens.
        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
        projected_tokens = self.projection(patch_tokens)
        p = self.config.patch_size
        h_grid = image_size[0] // p
        w_grid = image_size[1] // p
        assert patch_tokens.shape[1] == h_grid * w_grid, (
            "Number of patches does not match image size."
        )
        x_spatial = projected_tokens.reshape(
            batch_size, h_grid, w_grid, self.config.hidden_size
        )
        x_spatial = x_spatial.permute(0, 3, 1, 2)
        reconstructed_image = self.decoder(x_spatial)
        return reconstructed_image


class UTransformer(nn.Module):
    """U-Net-style transformer: frozen DINOv3 encoder, trainable conditioned decoder."""

    def __init__(self, config: DINOv3ViTConfig, num_classes: int):
        super().__init__()
        self.config = config
        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        # self.y_embedder = LabelEmbedder(
        #     num_classes, config.hidden_size, config.drop_path_rate
        # )  # disable cond for now
        self.encoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, is_encoder=True)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, is_encoder=False)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.decoder = DinoV3ViTDecoder(config)
        # Freeze the pretrained backbone components.
        self.embeddings.requires_grad_(False)
        self.rope_embeddings.requires_grad_(False)
        self.encoder_norm.requires_grad_(False)

    def forward(
        self,
        pixel_values: torch.Tensor,
        time: torch.Tensor,
        # cond: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)
        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        t = self.t_embedder(time).unsqueeze(1)
        # y = self.y_embedder(cond, self.training).unsqueeze(1)
        # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
        conditioning_input = t.to(x.dtype)

        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            residual.append(x)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
        x = self.encoder_norm(x)

        # U-Net-style long skips: residuals pop in LIFO order, pairing each
        # decoder layer with the mirrored encoder layer's input.
        for i, layer_module in enumerate(self.decoder_layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
            x = x + residual.pop()
        return self.decoder(x, image_size=pixel_values.shape[-2:])

    @staticmethod
    def from_pretrained_backbone(name: str, device: str = "cuda:3"):
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = UTransformer(config, 0).to(device)
        weight_dict = {}
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device=device
        ) as f:
            for key in f.keys():
                # Map backbone parameter names onto the encoder half of this model.
                new_key = key.replace("layer.", "encoder_layers.").replace(
                    "norm.", "encoder_norm."
                )
                weight_dict[new_key] = f.get_tensor(key)
        # strict=False: decoder layers and conditioning branches have no
        # pretrained counterpart and keep their fresh initialization.
        instance.load_state_dict(weight_dict, strict=False)
        return instance
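

# Minimal smoke test sketch. Assumptions: the tiny config kwargs below are
# accepted by DINOv3ViTConfig, and the local src.model.dino modules mirror the
# Hugging Face DINOv3 implementation; adjust before relying on this.
if __name__ == "__main__":
    config = DINOv3ViTConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        patch_size=16,
        num_channels=3,
    )
    model = UTransformer(config, num_classes=0)
    pixels = torch.randn(2, 3, 64, 64)
    time = torch.rand(2)  # diffusion timesteps, shape (batch,)
    out = model(pixels, time)
    # The conv decoder upsamples 2 * 8 = 16x, so the output should match the
    # input resolution.
    assert out.shape == pixels.shape, out.shape
    print("forward OK:", tuple(out.shape))

    # Loading real weights (hypothetical repo id, needs network access):
    # model = UTransformer.from_pretrained_backbone(
    #     "facebook/dinov3-vits16-pretrain-lvd1689m", device="cpu"
    # )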