import copy
import math
from functools import lru_cache
from typing import Optional

import einops
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.attention import CrossAttention, PlainAttention, RoPE
from src.model.dino import (
    DINOv3ViTAttention,
    DINOv3ViTDropPath,
    DINOv3ViTEmbeddings,
    DINOv3ViTGatedMLP,
    DINOv3ViTLayer,
    DINOv3ViTLayerScale,
    DINOv3ViTMLP,
    DINOv3ViTRopePositionEmbedding,
)
from src.model.dit import modulate
from src.model.resnet import ResBlock


def create_coordinate(h, w, start=0, end=1, device="cuda:1", dtype=torch.float32):
    """Return a flattened ``(1, h * w, 2)`` grid of 2-D coordinates.

    Entry ``i * w + j`` holds ``(linspace(start, end, h)[i],
    linspace(start, end, w)[j])``, i.e. the grid is flattened row-major.

    Args:
        h: Number of rows.
        w: Number of columns.
        start: First coordinate value along each axis.
        end: Last coordinate value along each axis.
        device: Device to allocate on. The default is a hard-coded
            ``"cuda:1"``; callers should pass the device of the tensors the
            coordinates are combined with.
        dtype: Dtype of the returned tensor.
    """
    xs = torch.linspace(start, end, h, device=device, dtype=dtype)
    ys = torch.linspace(start, end, w, device=device, dtype=dtype)
    xx, yy = torch.meshgrid(xs, ys, indexing="ij")
    coord_map = torch.stack([xx, yy], dim=-1)[None, ...]  # (1, h, w, 2)
    return rearrange(coord_map, "b h w c -> b (h w) c", h=h, w=w)


class TimestepEmbedder(nn.Module):
    """Embed 1-D diffusion timesteps into ``hidden_size`` vectors (DiT-style).

    A fixed sinusoidal embedding of width ``frequency_embedding_size`` is fed
    through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal embeddings of shape ``(len(t), dim)``.

        The first ``dim // 2`` channels are cosines, the next ``dim // 2`` are
        sines; an odd ``dim`` gets one trailing zero channel.
        """
        half = dim // 2
        # Built directly on t's device rather than on CPU and then copied.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, device=t.device)
            / half
        )
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        # Cast the sinusoidal features to the MLP's parameter dtype.
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        return self.mlp(t_freq)


class LabelEmbedder(nn.Module):
    """Class-label embedding table with label dropout for classifier-free guidance.

    When ``dropout_prob > 0`` one extra row (index ``num_classes``) is
    reserved as the "null" label that dropped labels are mapped to.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace dropped labels with the null label ``num_classes``.

        Labels are dropped at random with probability ``dropout_prob`` unless
        ``force_drop_ids`` is given, in which case entries equal to 1 are
        dropped.
        """
        if force_drop_ids is None:
            # Sample on the labels' device; the former unconditional ``.cuda()``
            # broke CPU-only runs.
            drop_ids = (
                torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            )
        else:
            drop_ids = force_drop_ids == 1
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


class DinoEncoderLayer(DINOv3ViTLayer):
    """Frozen DINOv3 ViT layer plus a trainable conditioning branch.

    The pretrained norms, self-attention, MLP and layer scales are frozen. An
    extra ``DINOv3ViTAttention`` block (``self.cond``) is called with the
    hidden states and the layer-normed conditioning input (presumably as
    cross-attention -- see ``src.model.dino``) and added residually through
    ``layer_scale_cond``, whose ``lambda1`` is initialized to 0 so the branch
    contributes nothing at initialization (assuming the layer scale is a
    per-channel multiply).
    """

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__(config)
        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = DINOv3ViTAttention(config)
        self.layer_scale_cond = DINOv3ViTLayerScale(config)
        nn.init.constant_(self.layer_scale_cond.lambda1, 0)

        # Freeze the pretrained sub-modules; only the conditioning branch trains.
        self.norm1.requires_grad_(False)
        self.norm2.requires_grad_(False)
        self.attention.requires_grad_(False)
        self.mlp.requires_grad_(False)
        self.layer_scale1.requires_grad_(False)
        self.layer_scale2.requires_grad_(False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Self-attention -> (optional) conditioning -> MLP, each residual.

        When ``conditioning_input`` is ``None`` the conditioning branch is
        skipped (previously this crashed inside ``norm_cond``).
        """
        assert position_embeddings is not None
        attn_out = self.attention(
            self.norm1(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.drop_path(self.layer_scale1(attn_out)) + hidden_states

        if conditioning_input is not None:
            cond_out = self.cond(hidden_states, self.norm_cond(conditioning_input))
            hidden_states = (
                self.drop_path(self.layer_scale_cond(cond_out)) + hidden_states
            )

        mlp_out = self.mlp(self.norm2(hidden_states))
        return self.drop_path(self.layer_scale2(mlp_out)) + hidden_states


class DinoDecoderLayer(DINOv3ViTLayer):
    """Reduced-width transformer layer with DiT-style adaLN modulation.

    At ``depth`` the width is ``hidden_size // 4**depth``, the head count is
    ``num_attention_heads // 2**depth`` and the MLP intermediate size is
    ``intermediate_size // 3**depth``. Shift/scale/gate for the attention and
    MLP sub-blocks come from a SiLU + Linear on the (full-width) conditioning
    input.
    """

    def __init__(self, config: DINOv3ViTConfig, depth: int):
        super().__init__(config)
        hidden_size = config.hidden_size // (4**depth)
        # Shallow-copied config describing this reduced-width stage.
        stage_config = copy.copy(config)
        stage_config.hidden_size = hidden_size
        stage_config.intermediate_size = stage_config.intermediate_size // (3**depth)

        # Every sub-module built by the parent constructor is replaced below.
        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = PlainAttention(
            hidden_size, config.num_attention_heads // (2**depth)
        )  # head scaling law?
        self.layer_scale1 = DINOv3ViTLayerScale(stage_config)
        self.drop_path = (
            DINOv3ViTDropPath(stage_config.drop_path_rate)
            if stage_config.drop_path_rate > 0.0
            else nn.Identity()
        )
        self.norm2 = nn.LayerNorm(hidden_size, eps=stage_config.layer_norm_eps)
        if config.use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(stage_config)
        else:
            self.mlp = DINOv3ViTMLP(stage_config)
        self.layer_scale2 = DINOv3ViTLayerScale(stage_config)

        # adaLN: conditioning (full hidden_size) -> 6 modulation vectors.
        self.adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply modulated attention and MLP sub-blocks.

        ``conditioning_input`` is required; it is squeezed on dim 1 (the
        caller passes a ``(B, 1, hidden_size)`` tensor). ``attention_mask`` and
        extra keyword arguments are accepted for interface compatibility but
        ignored.
        """
        conditioning_input = conditioning_input.squeeze(1)  # type: ignore
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(
            conditioning_input
        ).chunk(6, dim=-1)

        attn_out = self.attention(
            modulate(self.norm1(hidden_states), shift_msa, scale_msa),
            position_embeddings=position_embeddings,
        )
        hidden_states = (
            self.drop_path(gate_msa.unsqueeze(1) * self.layer_scale1(attn_out))
            + hidden_states
        )

        mlp_out = self.mlp(modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
        hidden_states = (
            self.drop_path(gate_mlp.unsqueeze(1) * self.layer_scale2(mlp_out))
            + hidden_states
        )
        return hidden_states


class ResidualUpscaler(nn.Module):
    """Turn encoder skip residuals into per-depth decoder skip features.

    For each entry ``d`` of ``depth``: ``d == 0`` passes the residual through
    unchanged. Otherwise, pixel-derived queries pooled to a grid ``2**d``
    times finer than the patch grid cross-attend to keys modulated by the
    residuals, with values projected to width ``hidden_size // 4**d``.
    Prefix (CLS + register) tokens are projected to the same width and kept
    in front.
    """

    def __init__(
        self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
    ):  # max depth 2 (2**2 = 16 = patch size)
        super().__init__()

        def build_encoder(in_dim, num_layers=2):
            # 1x1 conv projection to bottleneck_dim followed by 1x1 ResBlocks.
            return nn.Sequential(
                nn.Conv2d(
                    in_dim,
                    bottleneck_dim,
                    kernel_size=1,
                    padding=0,
                    padding_mode="reflect",
                    bias=False,
                ),
                *[
                    ResBlock(
                        bottleneck_dim,
                        bottleneck_dim,
                        kernel_size=1,
                        num_groups=8,
                        pad_mode="reflect",
                        norm_fn=nn.GroupNorm,
                        activation_fn=nn.SiLU,
                        use_conv_shortcut=False,
                    )
                    for _ in range(num_layers)
                ],
            )

        self.config = config
        self.depth = depth
        # All residuals concatenated channel-wise -> shared key shift/scale.
        self.global_encode = nn.Linear(
            config.hidden_size * len(depth),
            bottleneck_dim * 2,
        )
        self.local_encode = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, bottleneck_dim * 2)
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.q_norm = nn.ModuleList(
            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
        )
        self.k_norm = nn.ModuleList(
            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
        )
        self.v_downsample = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.hidden_size // (4**d))
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.cross_attn = nn.ModuleList(
            [
                CrossAttention(
                    bottleneck_dim,
                    bottleneck_dim,
                    config.hidden_size // (4**d),
                )
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.image_encoder = build_encoder(3)
        self.q_encoder = build_encoder(bottleneck_dim)
        self.k_encoder = build_encoder(bottleneck_dim)
        self.rope = RoPE(bottleneck_dim)

    def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
        """Build one skip feature per entry of ``self.depth``.

        Args:
            pixel_values: ``(b, 3, h, w)`` images.
            residuals: encoder hidden states, index 0 = deepest, each
                ``(b, 1 + num_register_tokens + seq, hidden_size)``.

        Returns:
            List aligned with ``self.depth``; entry ``i`` has the prefix tokens
            first, followed by ``(h // patch * 2**d) * (w // patch * 2**d)``
            patch tokens of width ``hidden_size // 4**d`` (for ``d != 0``; this
            assumes ``CrossAttention`` returns one token per query).
        """
        assert self.config.patch_size is not None
        patch = self.config.patch_size
        image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
        num_prefix = 1 + self.config.num_register_tokens

        prefixes = [residual[:, :num_prefix] for residual in residuals]
        residuals = [residual[:, num_prefix:] for residual in residuals]

        # Patch-level modulation shared by all depths.
        global_shift, global_scale = self.global_encode(
            einops.rearrange(
                torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
            )
        ).chunk(2, dim=-1)

        # Pixel features with 2-D RoPE; coordinates are created on the input's
        # device (previously they always landed on the default "cuda:1").
        image_residual = self.image_encoder(pixel_values)
        coords = create_coordinate(image_h, image_w, device=pixel_values.device)
        image_residual = rearrange(image_residual, "b c h w -> b (h w) c")
        image_residual = self.rope(image_residual, coords)
        image_residual = rearrange(image_residual, "b (h w) c -> b c h w", h=image_h)
        q = self.q_encoder(image_residual)
        k = self.k_encoder(image_residual)

        # Keys always live on the encoder patch grid; pool them once.
        k_tokens = einops.rearrange(
            torch.nn.functional.adaptive_avg_pool2d(
                k, output_size=(image_h // patch, image_w // patch)
            ),
            "b c h w -> b (h w) c",
        )

        reformed_residual = []
        for i, (d, residual, prefix) in enumerate(
            zip(self.depth, residuals, prefixes)
        ):
            if d == 0:
                reformed_residual.append(torch.cat((prefix, residual), dim=1))
                continue

            local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
            # Queries on a grid 2**d times finer (per side) than the patch grid.
            local_q = self.q_norm[i](
                einops.rearrange(
                    torch.nn.functional.adaptive_avg_pool2d(
                        q,
                        output_size=(
                            image_h // patch * (2**d),
                            image_w // patch * (2**d),
                        ),
                    ),
                    "b c h w -> b (h w) c",
                )
            )
            local_k = (1 + local_scale) * (
                (1 + global_scale) * self.k_norm[i](k_tokens) + global_shift
            ) + local_shift
            local_v = self.v_downsample[i](residual)
            local_prefix = self.v_downsample[i](prefix)
            reformed_residual.append(
                torch.concat(
                    (local_prefix, self.cross_attn[i](local_q, local_k, local_v)),
                    dim=1,
                )
            )
        return reformed_residual


class DinoV3ViTDecoder(nn.Module):
    """Zero-initialized linear + PixelShuffle(4) head from tokens to pixels.

    Input tokens ``(b, (H // 4) * (W // 4), hidden_size // 16)`` are projected
    to ``num_channels * 16`` channels and pixel-shuffled to
    ``(b, num_channels, H, W)``. The projection starts at zero, so the head
    outputs zeros at initialization.
    """

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.num_channels_out = config.num_channels
        self.patch_size = config.patch_size
        self.projection = nn.Linear(
            config.hidden_size // 16, config.num_channels * 16, bias=True
        )
        self.upscale = nn.PixelShuffle(4)
        nn.init.zeros_(self.projection.weight)
        if self.projection.bias is not None:
            nn.init.zeros_(self.projection.bias)

    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        grid = einops.rearrange(
            x,
            "b (h w) d -> b h w d",
            h=image_size[0] // 4,
            w=image_size[1] // 4,
        )
        return self.upscale(self.projection(grid).permute(0, 3, 1, 2))


class NerfEmbedder(nn.Module):
    """Embed per-pixel inputs concatenated with a fixed 2-D DCT position code.

    Inputs of shape ``(N, P*P, in_channels)`` (``P*P`` pixels per patch)
    receive ``max_freqs**2`` cosine features of their in-patch position and
    are projected to ``hidden_size_input``.
    """

    def __init__(self, in_channels, hidden_size_input, max_freqs):
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        self.embedder = nn.Sequential(
            nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True),
        )
        # Per-instance cache of DCT features keyed by (patch_size, device,
        # dtype). Replaces a method-level lru_cache, which keyed on ``self``
        # and kept every instance alive for the life of the process.
        self._pos_cache: dict = {}

    def fetch_pos(self, patch_size, device, dtype):
        """Return (cached) DCT features of shape ``(1, patch_size**2, max_freqs**2)``."""
        key = (patch_size, device, dtype)
        cached = self._pos_cache.get(key)
        if cached is not None:
            return cached

        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        freqs = torch.linspace(
            0, self.max_freqs, self.max_freqs, dtype=dtype, device=device
        )
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Damp high-frequency products.
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)

        self._pos_cache[key] = dct
        return dct

    def forward(self, inputs):
        target_dtype = self.embedder[0].weight.dtype
        inputs = inputs.to(dtype=target_dtype)
        batch, pixels_per_patch, _ = inputs.shape
        patch_size = int(pixels_per_patch**0.5)
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype)
        dct = dct.repeat(batch, 1, 1)
        return self.embedder(torch.cat([inputs, dct], dim=-1))


class NerfBlock(nn.Module):
    """Residual pixel MLP whose weights are generated from a conditioning vector.

    ``s`` (one vector per row of ``x``) is mapped to fc1 and fc2 weight
    matrices (no biases). Their columns are L2-normalized along the input
    dimension before use.
    """

    def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio: int = 4):
        super().__init__()
        self.param_generator1 = nn.Sequential(
            nn.Linear(hidden_size_s, 2 * hidden_size_x**2 * mlp_ratio, bias=True),
        )
        self.norm = nn.RMSNorm(hidden_size_x, eps=1e-6)
        self.mlp_ratio = mlp_ratio

    def forward(self, x, s):
        """``x``: ``(N, pixels, hidden_x)``; ``s``: ``(N, hidden_s)``."""
        batch_size, _, hidden_size_x = x.shape
        fc1_param1, fc2_param1 = self.param_generator1(s).chunk(2, dim=-1)
        fc1_param1 = fc1_param1.view(
            batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio
        )
        fc2_param1 = fc2_param1.view(
            batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x
        )
        normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2)
        normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2)

        res_x = x
        x = self.norm(x)
        x = torch.bmm(x, normalized_fc1_param1)
        x = torch.nn.functional.silu(x)
        x = torch.bmm(x, normalized_fc2_param1)
        return x + res_x


class NerfFinalLayer(nn.Module):
    """RMSNorm followed by a linear projection to ``out_channels``."""

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x):
        return self.linear(self.norm(x))


class UTransformer(nn.Module):
    """U-shaped DINOv3 diffusion backbone with a per-patch NeRF-style pixel head.

    Pipeline as wired in ``forward`` (D = ``config.hidden_size``):
      1. Frozen DINOv3 embeddings and ``DinoEncoderLayer`` stack conditioned on
         the timestep; the hidden states entering every ``scale_factor``-th
         layer are kept as skip residuals.
      2. ``ResidualUpscaler`` converts them to per-depth skip features.
      3. Decoder stages at depths 0, 1, 2 (width D / 4**depth, token grid
         2**depth times finer per side), each but the last followed by a 2x
         PixelShuffle upsample, then the ``rest_decoder`` layers.
      4. The finest tokens are folded back into one D-wide conditioning vector
         per DINO patch, which drives ``NerfBlock`` MLPs predicting RGB pixels.

    Args:
        config: DINOv3 ViT configuration of the pretrained backbone.
        num_classes: Currently unused (class conditioning is commented out).
        nerf_patch: Pixel patch size of the NeRF head. ``forward`` requires
            the NeRF patch grid to contain as many patches as the DINO grid.
        nerf_hidden: Width of the NeRF head.
        scale_factor: Interval, in encoder layers, between saved residuals.
    """

    def __init__(
        self,
        config: DINOv3ViTConfig,
        num_classes: int,
        nerf_patch=16,
        nerf_hidden=64,
        scale_factor: int = 4,
    ):
        super().__init__()
        self.config = config
        self.scale_factor = scale_factor
        self.nerf_patch_size = nerf_patch
        # NOTE(review): this evaluates (num_hidden_layers % scale_factor) % 3.
        # If the intent is "number of residual taps is a multiple of 3" it
        # would be (num_hidden_layers // scale_factor) % 3 -- confirm.
        assert config.num_hidden_layers % scale_factor % 3 == 0

        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)

        def gen_rope(depth: int):
            # RoPE for a reduced-width stage (width / 4**depth, heads / 2**depth).
            stage_config = copy.copy(config)
            stage_config.hidden_size = config.hidden_size // (4**depth)
            stage_config.intermediate_size = stage_config.intermediate_size // (
                3**depth
            )
            stage_config.num_attention_heads = stage_config.num_attention_heads // (
                2**depth
            )
            return DINOv3ViTRopePositionEmbedding(stage_config)

        # decode_ropes[i] serves the stage at depth i + 1.
        self.decode_ropes = nn.ModuleList([gen_rope(i + 1) for i in range(2)])

        self.t_embedder = TimestepEmbedder(config.hidden_size)
        self.encoder_layers = nn.ModuleList(
            [DinoEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
        self.residual_upscaler = ResidualUpscaler(
            config,
            DEPTH_LAYER,  # hardcoded, sorry
            bottleneck_dim=128,
        )
        stage_depths = sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
        self.decoder_layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        DinoDecoderLayer(config, depth)
                        for _ in range(DEPTH_LAYER.count(depth))
                    ]
                )
                for depth in stage_depths
            ]
        )
        # One skip-residual -> (shift, scale) projection per decoder layer.
        self.residual_merger = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(
                                config.hidden_size // (4**depth),
                                2 * config.hidden_size // (4**depth),
                            ),
                        )
                        for _ in range(DEPTH_LAYER.count(depth))
                    ]
                )
                for depth in stage_depths
            ]
        )
        self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
        # Prefix (CLS + register) tokens are projected to the next stage width.
        self.upsample_latent = nn.ModuleList(
            [
                nn.Linear(
                    config.hidden_size // (4**depth),
                    config.hidden_size // (4 ** (depth + 1)),
                )
                for depth in range(2)
            ]
        )
        self.rest_decoder = nn.ModuleList(
            [DinoDecoderLayer(config, 2) for _ in range(4)]
        )
        # NOTE(review): only used by ``get_residual``; ``forward`` never applies
        # it, so its parameters receive no gradient there.
        self.decoder_norm = nn.LayerNorm(
            (config.hidden_size // (4**2)), eps=config.layer_norm_eps
        )

        # nerf!
        self.nerf_encoder = NerfEmbedder(3, nerf_hidden, 8)  # (rgb, hidden, freq)
        self.nerf_decoder = nn.ModuleList(
            [NerfBlock(self.config.hidden_size, nerf_hidden) for _ in range(12)]
        )
        self.final_layer = NerfFinalLayer(nerf_hidden, 3)

        # freeze pretrained
        self.embeddings.requires_grad_(False)
        self.rope_embeddings.requires_grad_(False)
        self.encoder_norm.requires_grad_(False)

    def forward(
        self,
        pixel_values: torch.Tensor,
        time: torch.Tensor,
        # cond: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        """Predict a ``(B, 3, H, W)`` image from ``pixel_values`` and ``time``.

        ``time`` may be a scalar (broadcast over the batch) or shape ``(B,)``.
        """
        if time.dim() == 0:
            time = time.repeat(pixel_values.shape[0])

        # resolution config
        B = pixel_values.shape[0]
        image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
        dino_h = image_h // self.config.patch_size
        nerf_h = image_h // self.nerf_patch_size
        nerf_w = image_w // self.nerf_patch_size
        num_prefix = 1 + self.config.num_register_tokens
        weight_dtype = self.embeddings.patch_embeddings.weight.dtype

        pixel_values = pixel_values.to(weight_dtype)
        position_embeddings = self.rope_embeddings(pixel_values)
        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        t = self.t_embedder(time).unsqueeze(1)  # (B, 1, D)
        # y = self.y_embedder(cond, self.training).unsqueeze(1)
        # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
        conditioning_input = t.to(x.dtype)

        # ---- encoder: keep the input of every scale_factor-th layer ----
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
        x = self.encoder_norm(x)

        reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])

        # ---- decoder stages ----
        num_upsamples = len(self.upsample)
        residual_idx = 0
        for depth, layers in enumerate(self.decoder_layers):
            for i, layer_module in enumerate(layers):  # type: ignore
                x = layer_module(
                    x,
                    conditioning_input=conditioning_input,
                    attention_mask=None,
                    position_embeddings=position_embeddings,
                )
                shift, scale = self.residual_merger[depth][i](  # type: ignore
                    reversed_residual[residual_idx]
                ).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
                residual_idx += 1

            if depth >= num_upsamples:
                continue  # finest stage: no further upsampling

            # 2x upsample patch tokens (width / 4) and project prefix tokens.
            # Prefix tokens stay FIRST, as every later slice of
            # x[:, :num_prefix] and the skip residuals assume; they were
            # previously appended at the end.
            prefix = self.upsample_latent[depth](x[:, :num_prefix])
            patches = rearrange(
                x[:, num_prefix:], "b (h w) d -> b d h w", h=dino_h * (2**depth)
            )
            patches = rearrange(self.upsample[depth](patches), "b d h w -> b (h w) d")
            x = torch.cat((prefix, patches), dim=1)

            # RoPE for the next, 2x finer token grid: a dummy image 2**(depth+1)
            # times larger yields that many more patches per side.
            scale_up = 2 ** (depth + 1)
            position_embeddings = self.decode_ropes[depth](
                torch.zeros(
                    (1, 1, image_h * scale_up, image_w * scale_up),
                    device=x.device,
                ).to(weight_dtype)
            )

        for i, layer_module in enumerate(self.rest_decoder):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
            )

        # ---- per-patch NeRF conditioning ----
        # (B, (up*dino_h) * (up*dino_w), D / up**2) patch tokens.
        x = x[:, num_prefix:, :]
        # Fold each up x up group of sub-patch tokens back into the DINO patch
        # it came from -> (B, dino_h * dino_w, D), so the width matches t.
        # (Adding t (B, 1, D) to the D / 16-wide tokens previously crashed.)
        up = 2**num_upsamples
        x = rearrange(x, "b (h w) d -> b d h w", h=dino_h * up)
        x = nn.functional.pixel_unshuffle(x, up)
        x = rearrange(x, "b d h w -> b (h w) d")
        nerf_cond = nn.functional.silu(t + x)
        # (B * nerf_h * nerf_w, D), raster order matching `unfold` below.
        nerf_cond = rearrange(nerf_cond, "b (h w) d -> (b h w) d", h=nerf_h, w=nerf_w)

        # ---- NeRF pixel head ----
        # unfold: (B, 3 * P*P, L) with channel-major columns -> (B*L, P*P, 3).
        x_nerf = nn.functional.unfold(
            pixel_values, self.nerf_patch_size, stride=self.nerf_patch_size
        ).transpose(1, 2)
        x_nerf = x_nerf.reshape(
            B * x_nerf.shape[1], -1, self.nerf_patch_size**2
        ).transpose(1, 2)
        x_nerf = self.nerf_encoder(x_nerf)
        for module in self.nerf_decoder:
            x_nerf = module(x_nerf, nerf_cond)
        x_nerf = self.final_layer(x_nerf)  # (B*L, P*P, 3)

        # fold expects channel-major columns (c, ky, kx), so move RGB ahead of
        # the pixel axis before flattening (the previous pixel-major flatten
        # scrambled channels and pixels).
        num_patches = nerf_h * nerf_w
        x_nerf = x_nerf.transpose(1, 2).reshape(B, num_patches, -1)  # (B, L, 3*P*P)
        x_nerf = x_nerf.transpose(1, 2).contiguous()  # (B, 3*P*P, L)
        return nn.functional.fold(
            x_nerf,
            (image_h, image_w),
            kernel_size=self.nerf_patch_size,
            stride=self.nerf_patch_size,
        )

    def get_residual(
        self,
        pixel_values: torch.Tensor,
        time: Optional[torch.Tensor],
        do_condition: bool,
    ):
        """Run the encoder and return ``(decoded, residuals)``.

        NOTE(review): this method looks stale relative to ``__init__``. It
        calls ``self.decoder``, which this class never defines, and applies
        ``self.decoder_norm`` (width ``hidden_size // 16``) to encoder tokens
        of width ``hidden_size``. As written it cannot complete; confirm the
        intended decoder before relying on it.
        """
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)
        x = self.embeddings(pixel_values, bool_masked_pos=None)

        if do_condition:
            t = self.t_embedder(time).unsqueeze(1)
            # y = self.y_embedder(cond, self.training).unsqueeze(1)
            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
            conditioning_input = t.to(x.dtype)
        else:
            conditioning_input = None

        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=None,
                position_embeddings=position_embeddings,
            )

        x = x[:, 1 + self.config.num_register_tokens :]
        x = self.decoder_norm(x)
        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    @staticmethod
    def from_pretrained_backbone(name: str):
        """Build a ``UTransformer`` and load pretrained DINOv3 backbone weights.

        Checkpoint keys are renamed to this module's layout (``layer.`` ->
        ``encoder_layers.``, ``norm.`` -> ``encoder_norm.``) and loaded with
        ``strict=False``, so newly added modules keep their initialization.
        """
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = UTransformer(config, 0)
        weight_dict = {}
        # Load to CPU: the freshly built module lives on CPU and
        # load_state_dict copies into its parameters, so the former hard-coded
        # "cuda:1" only made loading fail on machines without that GPU.
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device="cpu"
        ) as f:
            for key in f.keys():
                new_key = key.replace("layer.", "encoder_layers.").replace(
                    "norm.", "encoder_norm."
                )
                weight_dict[new_key] = f.get_tensor(key)
        instance.load_state_dict(weight_dict, strict=False)
        return instance