utransformer

This commit is contained in:
Senstella
2025-09-28 19:22:54 +09:00
parent b9cc48bd25
commit 2761171fe3
2 changed files with 209 additions and 4 deletions

View File

@@ -214,7 +214,7 @@ class DINOv3ViTDropPath(nn.Module):
class DINOv3ViTMLP(nn.Module): class DINOv3ViTMLP(nn.Module):
def __init__(self, config): def __init__(self, config: DINOv3ViTConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@@ -236,7 +236,7 @@ class DINOv3ViTMLP(nn.Module):
class DINOv3ViTGatedMLP(nn.Module): class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, config): def __init__(self, config: DINOv3ViTConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@@ -274,8 +274,8 @@ class DINOv3ViTLayer(nn.Module):
self.mlp = DINOv3ViTMLP(config) self.mlp = DINOv3ViTMLP(config)
self.layer_scale2 = DINOv3ViTLayerScale(config) self.layer_scale2 = DINOv3ViTLayerScale(config)
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, def forward(self, hidden_states: torch.Tensor, *, attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor: position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
assert position_embeddings is not None assert position_embeddings is not None
residual = hidden_states residual = hidden_states

205
src/model/utransformer.py Normal file
View File

@@ -0,0 +1,205 @@
from typing import Optional
from torch import nn
import torch
import math
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
from src.model.dino import DINOv3ViTEmbeddings, DINOv3ViTLayerScale, DINOv3ViTRopePositionEmbedding, DINOv3ViTLayer
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size: int, frequency_embedding_size: int=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
).to(t.device)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
dtype=next(self.parameters()).dtype
)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = int(dropout_prob > 0)
self.embedding_table = nn.Embedding(
num_classes + use_cfg_embedding, hidden_size
)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
drop_ids = drop_ids.cuda()
drop_ids = drop_ids.to(labels.device)
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class DinoConditionedLayer(DINOv3ViTLayer):
def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
super().__init__(config)
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cond = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, config.drop_path_rate, batch_first=True)
self.layer_scale_cond = DINOv3ViTLayerScale(config)
# no init zeros!
if is_encoder:
nn.init.constant_(self.layer_scale_cond.lambda1, 0)
self.norm1.requires_grad_(False)
self.norm2.requires_grad_(False)
self.attention.requires_grad_(False)
self.mlp.requires_grad_(False)
self.layer_scale1.requires_grad_(False)
self.layer_scale2.requires_grad_(False)
def forward(self, hidden_states: torch.Tensor, *, conditioning_input: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
assert position_embeddings is not None
assert conditioning_input is not None
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states, _ = self.attention(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
residual = hidden_states
hidden_states = self.norm_cond(hidden_states)
hidden_states, _ = self.cond(hidden_states, conditioning_input, conditioning_input)
hidden_states = self.layer_scale_cond(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
return hidden_states
class DinoV3ViTDecoder(nn.Module):
def __init__(self, config: DINOv3ViTConfig):
super().__init__()
self.config = config
self.num_channels_out = config.num_channels
self.projection = nn.Linear(
config.hidden_size,
self.num_channels_out * config.patch_size * config.patch_size,
bias=True
)
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
batch_size = x.shape[0]
num_special_tokens = 1 + self.config.num_register_tokens
patch_tokens = x[:, num_special_tokens:, :]
projected_tokens = self.projection(patch_tokens)
p = self.config.patch_size
c = self.num_channels_out
h_grid = image_size[0] // p
w_grid = image_size[1] // p
assert patch_tokens.shape[1] == h_grid * w_grid, "Number of patches does not match image size."
x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
return reconstructed_image
class UTransformer(nn.Module):
def __init__(self, config: DINOv3ViTConfig, num_classes: int):
super().__init__()
self.config = config
self.embeddings = DINOv3ViTEmbeddings(config)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
self.t_embedder = TimestepEmbedder(config.hidden_size)
self.y_embedder = LabelEmbedder(num_classes, config.hidden_size, config.drop_path_rate)
self.encoder_layers = nn.ModuleList([DinoConditionedLayer(config, True) for _ in range(config.num_hidden_layers)])
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder_layers = nn.ModuleList([DinoConditionedLayer(config, False) for _ in range(config.num_hidden_layers)])
self.decoder = DinoV3ViTDecoder(config)
# freeze pretrained
self.embeddings.requires_grad_(False)
self.rope_embeddings.requires_grad_(False)
self.encoder_norm.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, time: torch.Tensor, cond: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
position_embeddings = self.rope_embeddings(pixel_values)
x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
t = self.t_embedder(time).unsqueeze(1)
y = self.y_embedder(cond, self.training).unsqueeze(1)
conditioning_input = t.to(x.dtype) + y.to(x.dtype)
residual = []
for i, layer_module in enumerate(self.encoder_layers):
residual.append(x)
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
x = self.encoder_norm(x)
for i, layer_module in enumerate(self.decoder_layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
x = x + residual.pop()
return self.decoder(x, image_size=pixel_values.shape[-2:])