add rf
@@ -1,13 +1,22 @@
-from typing import Optional
-from torch import nn
-import torch
 import math
+from typing import Optional
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from torch import nn
 from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
-from src.model.dino import DINOv3ViTEmbeddings, DINOv3ViTLayerScale, DINOv3ViTRopePositionEmbedding, DINOv3ViTLayer
+
+from src.model.dino import (
+    DINOv3ViTEmbeddings,
+    DINOv3ViTLayer,
+    DINOv3ViTLayerScale,
+    DINOv3ViTRopePositionEmbedding,
+)
 
 
 class TimestepEmbedder(nn.Module):
-    def __init__(self, hidden_size: int, frequency_embedding_size: int=256):
+    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
         super().__init__()
         self.mlp = nn.Sequential(
             nn.Linear(frequency_embedding_size, hidden_size),
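
Note: TimestepEmbedder's MLP consumes a frequency_embedding_size=256 vector, which in DiT-style models is a sinusoidal embedding of the scalar timestep (hence the `import math`). The embedding body sits outside this hunk, so the sketch below follows the standard DiT recipe as an assumption, not the commit's own code.

    import math
    import torch

    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        # Half the dims get cosines, half get sines, with geometrically
        # spaced frequencies from 1 down to 1/max_period.
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t.float()[:, None] * freqs[None, :]                     # [B, half]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, dim]
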
@@ -65,12 +74,18 @@ class LabelEmbedder(nn.Module):
         embeddings = self.embedding_table(labels)
         return embeddings
 
 
 class DinoConditionedLayer(DINOv3ViTLayer):
     def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
         super().__init__(config)
 
         self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.cond = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, config.drop_path_rate, batch_first=True)
+        self.cond = nn.MultiheadAttention(
+            config.hidden_size,
+            config.num_attention_heads,
+            config.drop_path_rate,
+            batch_first=True,
+        )
         self.layer_scale_cond = DINOv3ViTLayerScale(config)
 
         # no init zeros!
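
Note: the third positional parameter of nn.MultiheadAttention is `dropout`, so this call reuses config.drop_path_rate as the attention-dropout probability. Spelled with keywords, the same call reads:

    self.cond = nn.MultiheadAttention(
        embed_dim=config.hidden_size,
        num_heads=config.num_attention_heads,
        dropout=config.drop_path_rate,  # drop-path rate doubling as attention dropout
        batch_first=True,
    )
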
@@ -83,9 +98,15 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         self.layer_scale1.requires_grad_(False)
         self.layer_scale2.requires_grad_(False)
 
 
-    def forward(self, hidden_states: torch.Tensor, *, conditioning_input: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
-                position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        conditioning_input: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
         assert position_embeddings is not None
         assert conditioning_input is not None
 
@@ -101,7 +122,9 @@ class DinoConditionedLayer(DINOv3ViTLayer):
 
         residual = hidden_states
         hidden_states = self.norm_cond(hidden_states)
-        hidden_states, _ = self.cond(hidden_states, conditioning_input, conditioning_input)
+        hidden_states, _ = self.cond(
+            hidden_states, conditioning_input, conditioning_input
+        )
         hidden_states = self.layer_scale_cond(hidden_states)
         hidden_states = self.drop_path(hidden_states) + residual
 
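
Note: the rewrapped call above is cross-attention: the patch tokens act as queries while the conditioning sequence supplies keys and values, and the attention weights returned by nn.MultiheadAttention are discarded. With the (assumed) shapes annotated:

    hidden_states, _ = self.cond(
        hidden_states,       # query: [B, N_patches, hidden]
        conditioning_input,  # key:   [B, N_cond, hidden]
        conditioning_input,  # value: [B, N_cond, hidden]
    )
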
@@ -123,7 +146,7 @@ class DinoV3ViTDecoder(nn.Module):
         self.projection = nn.Linear(
             config.hidden_size,
             self.num_channels_out * config.patch_size * config.patch_size,
-            bias=True
+            bias=True,
         )
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
@@ -139,7 +162,9 @@ class DinoV3ViTDecoder(nn.Module):
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
 
-        assert patch_tokens.shape[1] == h_grid * w_grid, "Number of patches does not match image size."
+        assert patch_tokens.shape[1] == h_grid * w_grid, (
+            "Number of patches does not match image size."
+        )
 
         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
 
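
Note: the reshape above is the first half of a standard ViT unpatchify. The permute that completes it falls outside this hunk, so the following is one conventional way to finish the job, assumed rather than taken from the commit:

    x = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
    x = x.permute(0, 5, 1, 3, 2, 4)               # [B, C, h_grid, p, w_grid, p]
    reconstructed_image = x.reshape(
        batch_size, c, h_grid * p, w_grid * p     # [B, C, H, W]
    )
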
@@ -149,6 +174,7 @@ class DinoV3ViTDecoder(nn.Module):
 
         return reconstructed_image
 
+
 class UTransformer(nn.Module):
     def __init__(self, config: DINOv3ViTConfig, num_classes: int):
         super().__init__()
@@ -157,12 +183,24 @@ class UTransformer(nn.Module):
         self.embeddings = DINOv3ViTEmbeddings(config)
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
         self.t_embedder = TimestepEmbedder(config.hidden_size)
-        self.y_embedder = LabelEmbedder(num_classes, config.hidden_size, config.drop_path_rate)
+        # self.y_embedder = LabelEmbedder(
+        #     num_classes, config.hidden_size, config.drop_path_rate
+        # )  # disable cond for now
 
-        self.encoder_layers = nn.ModuleList([DinoConditionedLayer(config, True) for _ in range(config.num_hidden_layers)])
+        self.encoder_layers = nn.ModuleList(
+            [
+                DinoConditionedLayer(config, True)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder_layers = nn.ModuleList([DinoConditionedLayer(config, False) for _ in range(config.num_hidden_layers)])
+        self.decoder_layers = nn.ModuleList(
+            [
+                DinoConditionedLayer(config, False)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
         self.decoder = DinoV3ViTDecoder(config)
 
         # freeze pretrained
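
Note: with the label embedder disabled and the pretrained DINOv3 pieces frozen (the freezing continues in the next hunk), only the per-layer conditioning modules, the timestep embedder, and the decoder should receive gradients. A quick sanity check along these lines, given a constructed `model = UTransformer(config, 0)` (assumed snippet, not in the commit):

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / {total:,} ({trainable / total:.1%})")
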
@@ -170,15 +208,22 @@ class UTransformer(nn.Module):
         self.rope_embeddings.requires_grad_(False)
         self.encoder_norm.requires_grad_(False)
 
-    def forward(self, pixel_values: torch.Tensor, time: torch.Tensor, cond: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
-                head_mask: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        time: torch.Tensor,
+        # cond: torch.Tensor,
+        bool_masked_pos: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+    ):
         pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
         position_embeddings = self.rope_embeddings(pixel_values)
 
         x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
         t = self.t_embedder(time).unsqueeze(1)
-        y = self.y_embedder(cond, self.training).unsqueeze(1)
-        conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+        # y = self.y_embedder(cond, self.training).unsqueeze(1)
+        # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+        conditioning_input = t.to(x.dtype)
 
         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
@@ -203,3 +248,22 @@ class UTransformer(nn.Module):
         x = x + residual.pop()
 
         return self.decoder(x, image_size=pixel_values.shape[-2:])
+
+    @staticmethod
+    def from_pretrained_backbone(name: str):
+        config = DINOv3ViTConfig.from_pretrained(name)
+        instance = UTransformer(config, 0).to("cuda:3")
+
+        weight_dict = {}
+        with safe_open(
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:3"
+        ) as f:
+            for key in f.keys():
+                new_key = key.replace("layer.", "encoder_layers.").replace(
+                    "norm.", "encoder_norm."
+                )
+                weight_dict[new_key] = f.get_tensor(key)
+
+        instance.load_state_dict(weight_dict, strict=False)
+
+        return instance
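
Note: a possible end-to-end usage sketch. The assumptions here are heavy: "rf" in the commit subject is read as rectified flow, the checkpoint id is a placeholder, the decoder's num_channels_out is taken to be 3, and the hard-coded "cuda:3" device from from_pretrained_backbone is kept. Because load_state_dict is called with strict=False, only the renamed backbone weights load; the conditioning modules, timestep embedder, and decoder start from random init.

    import torch
    import torch.nn.functional as F

    model = UTransformer.from_pretrained_backbone("<dinov3-vit-repo-id>")  # placeholder

    x1 = torch.randn(8, 3, 224, 224, device="cuda:3")  # data batch
    x0 = torch.randn_like(x1)                          # Gaussian noise
    t = torch.rand(8, device="cuda:3")                 # timesteps in [0, 1]

    # Linear interpolation path and velocity target of rectified flow
    # (sign/direction conventions vary between papers).
    t_ = t[:, None, None, None]
    xt = (1 - t_) * x0 + t_ * x1
    v_target = x1 - x0

    v_pred = model(xt, t)
    loss = F.mse_loss(v_pred, v_target)
    loss.backward()
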