import math
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.dino import (
    DINOv3ViTEmbeddings,
    DINOv3ViTLayer,
    DINOv3ViTLayerScale,
    DINOv3ViTRopePositionEmbedding,
)


class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps via sinusoidal features followed by an MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        # Sinusoidal features over geometrically spaced frequencies.
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb


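# Usage sketch (illustrative, not called elsewhere in this module): a batch of
# timesteps of shape (B,) becomes sinusoidal features of shape
# (B, frequency_embedding_size) and then an embedding of shape (B, hidden_size).
#
#     t_embedder = TimestepEmbedder(hidden_size=768)
#     t = torch.rand(8)          # (8,)
#     t_emb = t_embedder(t)      # (8, 768)

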
class LabelEmbedder(nn.Module):
    """Embeds class labels, with label dropout for classifier-free guidance."""

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra embedding row for the "null" (dropped) label.
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        # Replace a random subset of labels (or the forced ones) with the null index.
        if force_drop_ids is None:
            drop_ids = (
                torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            )
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


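# Usage sketch (illustrative): with num_classes=1000 and dropout_prob=0.1 the table
# holds 1001 rows; during training roughly 10% of labels are remapped to the extra
# "null" index 1000, the label-dropout mechanism used for classifier-free guidance.
#
#     y_embedder = LabelEmbedder(1000, 768, 0.1)
#     y = torch.randint(0, 1000, (8,))
#     y_emb = y_embedder(y, train=True)    # (8, 768)

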
class DinoConditionedLayer(DINOv3ViTLayer):
    """DINOv3 ViT block extended with a cross-attention branch over a conditioning sequence."""

    def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
        super().__init__(config)

        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.drop_path_rate,
            batch_first=True,
        )
        self.layer_scale_cond = DINOv3ViTLayerScale(config)

        if is_encoder:
            # Zero-init the conditioning branch so encoder layers start out identical
            # to the pretrained DINOv3 blocks, and freeze those pretrained weights.
            # Decoder layers are trained from scratch and keep the default init.
            nn.init.constant_(self.layer_scale_cond.lambda1, 0)
            self.norm1.requires_grad_(False)
            self.norm2.requires_grad_(False)
            self.attention.requires_grad_(False)
            self.mlp.requires_grad_(False)
            self.layer_scale1.requires_grad_(False)
            self.layer_scale2.requires_grad_(False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert position_embeddings is not None
        assert conditioning_input is not None

        # Self-attention with RoPE (pretrained DINOv3 branch).
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # Cross-attention from image tokens to the conditioning sequence.
        residual = hidden_states
        hidden_states = self.norm_cond(hidden_states)
        hidden_states, _ = self.cond(
            hidden_states, conditioning_input, conditioning_input
        )
        hidden_states = self.layer_scale_cond(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # Feed-forward (pretrained DINOv3 branch).
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        return hidden_states


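# Call sketch (illustrative): each layer takes the token sequence plus RoPE position
# embeddings and attends to a conditioning sequence (in this model, the timestep
# embedding produced by TimestepEmbedder):
#
#     x = layer(x, conditioning_input=t_emb, position_embeddings=rope)

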
class DinoV3ViTDecoder(nn.Module):
    """Projects patch tokens back to pixel space and unpatchifies them into an image."""

    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.num_channels_out = config.num_channels

        # Each patch token is mapped to a (patch_size x patch_size x channels) pixel block.
        self.projection = nn.Linear(
            config.hidden_size,
            self.num_channels_out * config.patch_size * config.patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        batch_size = x.shape[0]

        # Drop the CLS token and register tokens; keep only patch tokens.
        num_special_tokens = 1 + self.config.num_register_tokens
        patch_tokens = x[:, num_special_tokens:, :]

        projected_tokens = self.projection(patch_tokens)

        p = self.config.patch_size
        c = self.num_channels_out
        h_grid = image_size[0] // p
        w_grid = image_size[1] // p

        assert patch_tokens.shape[1] == h_grid * w_grid, (
            "Number of patches does not match image size."
        )

        # Unpatchify: (B, H*W, p*p*c) -> (B, C, H*p, W*p).
        x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
        x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
        reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)

        return reconstructed_image


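# Shape sketch (illustrative): for a 224x224 input with patch_size=16, num_channels=3
# and num_register_tokens=4, x carries 1 + 4 + 196 tokens; the decoder keeps the 196
# patch tokens, projects each to 16*16*3 values, and unpatchifies them into a
# (B, 3, 224, 224) image.

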
class UTransformer(nn.Module):
    """U-shaped transformer over a pretrained DINOv3 ViT backbone with timestep
    conditioning and U-Net-style skip connections between encoder and decoder."""

    def __init__(self, config: DINOv3ViTConfig, num_classes: int):
        super().__init__()
        self.config = config

        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        # self.y_embedder = LabelEmbedder(
        #     num_classes, config.hidden_size, config.drop_path_rate
        # )  # class conditioning disabled for now

        self.encoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, is_encoder=True)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.decoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, is_encoder=False)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.decoder = DinoV3ViTDecoder(config)

        # Freeze the pretrained backbone components.
        self.embeddings.requires_grad_(False)
        self.rope_embeddings.requires_grad_(False)
        self.encoder_norm.requires_grad_(False)

    def forward(
        self,
        pixel_values: torch.Tensor,
        time: torch.Tensor,
        # cond: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)

        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        t = self.t_embedder(time).unsqueeze(1)
        # y = self.y_embedder(cond, self.training).unsqueeze(1)
        # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
        conditioning_input = t.to(x.dtype)

        # Encoder: run the conditioned DINOv3 layers, stashing activations for skips.
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            residual.append(x)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
        x = self.encoder_norm(x)

        # Decoder: mirror the encoder, adding the stashed skip connections in reverse.
        for i, layer_module in enumerate(self.decoder_layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
            x = x + residual.pop()

        return self.decoder(x, image_size=pixel_values.shape[-2:])

    @staticmethod
    def from_pretrained_backbone(name: str, device: str = "cuda:3"):
        """Builds a UTransformer and loads pretrained DINOv3 backbone weights into it."""
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = UTransformer(config, 0).to(device)

        # Remap backbone parameter names ("layer.*" -> "encoder_layers.*",
        # "norm.*" -> "encoder_norm.*") to match this module's attribute names.
        weight_dict = {}
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device=device
        ) as f:
            for key in f.keys():
                new_key = key.replace("layer.", "encoder_layers.").replace(
                    "norm.", "encoder_norm."
                )
                weight_dict[new_key] = f.get_tensor(key)

        # strict=False: only the pretrained backbone tensors are loaded; the new
        # conditioning and decoder parameters keep their fresh initialization.
        instance.load_state_dict(weight_dict, strict=False)

        return instance
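

if __name__ == "__main__":
    # Minimal smoke-test sketch: forward pass on random data with a small, randomly
    # initialized config. Assumes the building blocks imported from src.model.dino
    # accept an arbitrary DINOv3ViTConfig; the hyperparameters below are illustrative.
    config = DINOv3ViTConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        patch_size=16,
        image_size=64,
    )
    model = UTransformer(config, num_classes=0)
    pixel_values = torch.randn(2, config.num_channels, 64, 64)
    time = torch.rand(2)
    out = model(pixel_values, time)
    print(out.shape)  # expected: (2, config.num_channels, 64, 64)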