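"""A U-shaped, timestep-conditioned transformer built on a pretrained DINOv3 ViT.

Rough layout (as implemented below): frozen DINOv3 encoder blocks wrapped with
zero-initialized conditioning branches, a cross-attention ``ResidualUpscaler``
that lifts encoder skip residuals to finer token grids, adaLN-modulated decoder
blocks at progressively smaller widths, and a NeRF-style per-patch MLP head that
renders the output pixels.
"""
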
import copy
import math
from functools import lru_cache
from typing import Optional

import einops
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.attention import CrossAttention, PlainAttention, RoPE
from src.model.dino import (
    DINOv3ViTAttention,
    DINOv3ViTDropPath,
    DINOv3ViTEmbeddings,
    DINOv3ViTGatedMLP,
    DINOv3ViTLayer,
    DINOv3ViTLayerScale,
    DINOv3ViTMLP,
    DINOv3ViTRopePositionEmbedding,
)
from src.model.dit import modulate
from src.model.resnet import ResBlock


def create_coordinate(h, w, start=0, end=1, device="cuda:1", dtype=torch.float32):
    # Create a grid of coordinates
    x = torch.linspace(start, end, h, device=device, dtype=dtype)
    y = torch.linspace(start, end, w, device=device, dtype=dtype)
    # Create a 2D map using meshgrid
    xx, yy = torch.meshgrid(x, y, indexing="ij")
    # Stack the x and y coordinates to create the final map
    coord_map = torch.stack([xx, yy], dim=-1)[None, ...]
    coords = rearrange(coord_map, "b h w c -> b (h w) c", h=h, w=w)
    return coords


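# Sinusoidal timestep embedding followed by a two-layer SiLU MLP (as in DiT).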
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb


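# Class-label embedding table with an extra "null" class; labels are randomly
# dropped to that null index during training for classifier-free guidance.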
class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        if force_drop_ids is None:
            # Sample the drop mask directly on the labels' device.
            drop_ids = (
                torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            )
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


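# Encoder block: a pretrained DINOv3 ViT layer kept frozen, plus a new attention
# branch over the conditioning input whose layer scale is zero-initialized so
# the block starts out identical to the frozen backbone.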
class DinoEncoderLayer(DINOv3ViTLayer):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__(config)

        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = DINOv3ViTAttention(config)
        self.layer_scale_cond = DINOv3ViTLayerScale(config)

        nn.init.constant_(self.layer_scale_cond.lambda1, 0)
        self.norm1.requires_grad_(False)
        self.norm2.requires_grad_(False)
        self.attention.requires_grad_(False)
        self.mlp.requires_grad_(False)
        self.layer_scale1.requires_grad_(False)
        self.layer_scale2.requires_grad_(False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert position_embeddings is not None

        hidden_states = (
            self.drop_path(
                self.layer_scale1(
                    self.attention(
                        self.norm1(hidden_states),
                        attention_mask=attention_mask,
                        position_embeddings=position_embeddings,
                    )
                )
            )
            + hidden_states
        )

        hidden_states = (
            self.drop_path(
                self.layer_scale_cond(
                    self.cond(
                        hidden_states,
                        self.norm_cond(conditioning_input),
                    )
                )
            )
            + hidden_states
        )

        hidden_states = (
            self.drop_path(self.layer_scale2(self.mlp(self.norm2(hidden_states))))
            + hidden_states
        )

        return hidden_states


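# Decoder block: a narrower copy of the ViT layer (hidden size / 4**depth,
# heads / 2**depth) modulated by DiT-style adaLN shift/scale/gate parameters
# generated from the conditioning vector.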
class DinoDecoderLayer(DINOv3ViTLayer):
    def __init__(self, config: DINOv3ViTConfig, depth: int):
        super().__init__(config)

        hidden_size = config.hidden_size // (4**depth)
        hacky_config = copy.copy(config)
        hacky_config.hidden_size = hidden_size
        hacky_config.intermediate_size = hacky_config.intermediate_size // (3**depth)

        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = PlainAttention(
            hidden_size, config.num_attention_heads // (2**depth)
        )  # head scaling law?
        self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
        self.drop_path = (
            DINOv3ViTDropPath(hacky_config.drop_path_rate)
            if hacky_config.drop_path_rate > 0.0
            else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(hidden_size, eps=hacky_config.layer_norm_eps)

        if config.use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(hacky_config)
        else:
            self.mlp = DINOv3ViTMLP(hacky_config)
        self.layer_scale2 = DINOv3ViTLayerScale(hacky_config)

        # adaln
        self.adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        conditioning_input = conditioning_input.squeeze(1)  # type: ignore
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(
            conditioning_input
        ).chunk(6, dim=-1)

        hidden_states = (
            self.drop_path(
                gate_msa.unsqueeze(1)
                * self.layer_scale1(
                    self.attention(
                        modulate(self.norm1(hidden_states), shift_msa, scale_msa),
                        position_embeddings=position_embeddings,
                    )
                )
            )
            + hidden_states
        )

        hidden_states = (
            self.drop_path(
                gate_mlp.unsqueeze(1)
                * self.layer_scale2(
                    self.mlp(modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
                )
            )
            + hidden_states
        )

        return hidden_states


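# ResidualUpscaler: lifts encoder skip residuals onto finer token grids.
# Queries and keys come from the raw image via small conv encoders plus RoPE,
# pooled to the target/source patch resolutions; keys are modulated by global
# and per-residual shift/scale terms, and the channel-downsampled residual
# tokens are redistributed onto the finer grid by cross-attention.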
class ResidualUpscaler(nn.Module):
    def __init__(
        self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
    ):  # max depth 2 (2**2 = 16 = patch size)
        super().__init__()

        def build_encoder(in_dim, num_layers=2):
            return nn.Sequential(
                nn.Conv2d(
                    in_dim,
                    bottleneck_dim,
                    kernel_size=1,
                    padding=0,
                    padding_mode="reflect",
                    bias=False,
                ),
                *[
                    ResBlock(
                        bottleneck_dim,
                        bottleneck_dim,
                        kernel_size=1,
                        num_groups=8,
                        pad_mode="reflect",
                        norm_fn=nn.GroupNorm,
                        activation_fn=nn.SiLU,
                        use_conv_shortcut=False,
                    )
                    for _ in range(num_layers)
                ],
            )

        self.config = config
        self.depth = depth
        self.global_encode = nn.Linear(
            config.hidden_size * len(depth),
            bottleneck_dim * 2,
        )
        self.local_encode = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, bottleneck_dim * 2)
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.q_norm = nn.ModuleList(
            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
        )
        self.k_norm = nn.ModuleList(
            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
        )
        self.v_downsample = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.hidden_size // (4**d))
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.cross_attn = nn.ModuleList(
            [
                CrossAttention(
                    bottleneck_dim,
                    bottleneck_dim,
                    config.hidden_size // (4**d),
                )
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )

        self.image_encoder = build_encoder(3)
        self.q_encoder = build_encoder(bottleneck_dim)
        self.k_encoder = build_encoder(bottleneck_dim)

        self.rope = RoPE(bottleneck_dim)

        # ok just shuffle it; no, don't
        # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]

    def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
        # residuals[0] => deepest, residuals[-1] => shallowest
        # pixel_values: (b, 3, h, w)
        # residuals: list of (b, 1 + num_register_tokens + seq, d)
        # objective: e.g. given a (1024, 1024, 512) residual, build good
        # multi-head attention queries over a finer grid.
        assert self.config.patch_size is not None

        image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]

        rests = [
            residual[:, : 1 + self.config.num_register_tokens] for residual in residuals
        ]
        residuals = [
            residual[:, 1 + self.config.num_register_tokens :] for residual in residuals
        ]

        global_shift, global_scale = self.global_encode(
            einops.rearrange(
                torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
            )
        ).chunk(2, dim=-1)  # patch-level global btw

        image_residual = self.image_encoder(pixel_values)  # messy; todo: cleanup
        coords = create_coordinate(
            pixel_values.shape[-2], pixel_values.shape[-1], device=pixel_values.device
        )
        image_residual = rearrange(image_residual, "b c h w -> b (h w) c")
        image_residual = self.rope(image_residual, coords)
        image_residual = rearrange(
            image_residual, "b (h w) c -> b c h w", h=pixel_values.shape[-2]
        )
        q = self.q_encoder(image_residual)
        k = self.k_encoder(image_residual)

        reformed_residual = []
        for i, (depth, residual, rest) in enumerate(zip(self.depth, residuals, rests)):
            if depth == 0:
                reformed_residual.append(torch.cat((rest, residual), dim=1))
                continue

            local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
            local_q = self.q_norm[i](
                einops.rearrange(
                    torch.nn.functional.adaptive_avg_pool2d(
                        q,
                        output_size=(
                            image_h // self.config.patch_size * (2**depth),
                            image_w // self.config.patch_size * (2**depth),
                        ),
                    ),
                    "b c h w -> b (h w) c",
                )
            )
            local_k = (1 + local_scale) * (
                (1 + global_scale)
                * self.k_norm[i](
                    einops.rearrange(
                        torch.nn.functional.adaptive_avg_pool2d(
                            k,
                            output_size=(
                                image_h // self.config.patch_size,
                                image_w // self.config.patch_size,
                            ),
                        ),
                        "b c h w -> b (h w) c",
                    )
                )
                + global_shift
            ) + local_shift
            local_v = self.v_downsample[i](residual)

            local_rest = self.v_downsample[i](rest)

            final_residual = torch.concat(
                (local_rest, self.cross_attn[i](local_q, local_k, local_v)), dim=1
            )

            reformed_residual.append(final_residual)

        return reformed_residual


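# Pixel head: a zero-initialized linear projection followed by PixelShuffle(4),
# mapping decoder tokens back to image channels.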
class DinoV3ViTDecoder(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.num_channels_out = config.num_channels
        self.patch_size = config.patch_size

        self.projection = nn.Linear(
            config.hidden_size // 16, config.num_channels * 16, bias=True
        )
        self.upscale = nn.PixelShuffle(4)

        nn.init.zeros_(self.projection.weight)
        if self.projection.bias is not None:
            nn.init.zeros_(self.projection.bias)

    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        return self.upscale(
            self.projection(
                einops.rearrange(
                    x,
                    "b (h w) d -> b h w d",
                    h=image_size[0] // 4,
                    w=image_size[1] // 4,
                )
            ).permute(0, 3, 1, 2)
        )


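# NerfEmbedder: embeds the raw per-pixel values of a patch together with a
# cached 2-D cosine (DCT-like) positional code over the patch grid.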
class NerfEmbedder(nn.Module):
    def __init__(self, in_channels, hidden_size_input, max_freqs):
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        self.embedder = nn.Sequential(
            nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True),
        )

    @lru_cache
    def fetch_pos(self, patch_size, device, dtype):
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        freqs = torch.linspace(
            0, self.max_freqs, self.max_freqs, dtype=dtype, device=device
        )
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
        return dct

    def forward(self, inputs):
        target_dtype = self.embedder[0].weight.dtype
        inputs = inputs.to(dtype=target_dtype)
        B, P2, C = inputs.shape
        patch_size = int(P2**0.5)
        device = inputs.device
        dtype = inputs.dtype
        dct = self.fetch_pos(patch_size, device, dtype)
        dct = dct.repeat(B, 1, 1)
        inputs = torch.cat([inputs, dct], dim=-1)
        inputs = self.embedder(inputs)
        return inputs


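# NerfBlock: a hypernetwork-style MLP. The conditioning vector generates the two
# weight matrices of a per-sample MLP (column-normalized), which is applied to
# the pixel tokens with batched matmuls and a residual connection.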
class NerfBlock(nn.Module):
    def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio: int = 4):
        super().__init__()
        self.param_generator1 = nn.Sequential(
            nn.Linear(hidden_size_s, 2 * hidden_size_x**2 * mlp_ratio, bias=True),
        )
        self.norm = nn.RMSNorm(hidden_size_x, eps=1e-6)
        self.mlp_ratio = mlp_ratio

    def forward(self, x, s):
        batch_size, num_x, hidden_size_x = x.shape
        mlp_params1 = self.param_generator1(s)
        fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1)
        fc1_param1 = fc1_param1.view(
            batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio
        )
        fc2_param1 = fc2_param1.view(
            batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x
        )

        # normalize fc1
        normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2)
        # normalize fc2
        normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2)
        # mlp 1
        res_x = x
        x = self.norm(x)
        x = torch.bmm(x, normalized_fc1_param1)
        x = torch.nn.functional.silu(x)
        x = torch.bmm(x, normalized_fc2_param1)
        x = x + res_x
        return x


class NerfFinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x):
        x = self.norm(x)
        x = self.linear(x)
        return x


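# UTransformer: the full U-shaped model. Frozen DINOv3 embeddings and encoder
# blocks (each with a trainable conditioning branch) collect skip residuals
# every `scale_factor` layers; the ResidualUpscaler lifts them to finer grids;
# adaLN decoder blocks at depths [0, 0, 1, 1, 2, 2] consume them while tokens
# are pixel-shuffled to higher resolution; a NeRF-style per-patch MLP head then
# renders the output pixels.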
class UTransformer(nn.Module):
    def __init__(
        self,
        config: DINOv3ViTConfig,
        num_classes: int,
        nerf_patch=16,
        nerf_hidden=64,
        scale_factor: int = 4,
    ):
        super().__init__()
        self.config = config
        self.scale_factor = scale_factor
        self.nerf_patch_size = nerf_patch

        assert config.num_hidden_layers % scale_factor % 3 == 0

        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)

        def gen_rope(depth: int):
            hidden_size = config.hidden_size // (4**depth)
            hacky_config = copy.copy(config)
            hacky_config.hidden_size = hidden_size
            hacky_config.intermediate_size = hacky_config.intermediate_size // (
                3**depth
            )
            hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
                2**depth
            )
            return DINOv3ViTRopePositionEmbedding(hacky_config)

        self.decode_ropes = nn.ModuleList([gen_rope(i + 1) for i in range(2)])
        self.t_embedder = TimestepEmbedder(config.hidden_size)

        self.encoder_layers = nn.ModuleList(
            [DinoEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
        self.residual_upscaler = ResidualUpscaler(
            config,
            DEPTH_LAYER,  # hardcoded, sorry
            bottleneck_dim=128,
        )
        self.decoder_layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        DinoDecoderLayer(config, depth)
                        for _ in range(DEPTH_LAYER.count(depth))
                    ]
                )
                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
            ]
        )
        self.residual_merger = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(
                                config.hidden_size // (4**depth),
                                2 * config.hidden_size // (4**depth),
                            ),
                        )
                        for _ in range(DEPTH_LAYER.count(depth))
                    ]
                )
                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
            ]
        )
        self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
        self.upsample_latent = nn.ModuleList(
            [
                nn.Linear(
                    config.hidden_size // (4**depth),
                    config.hidden_size // (4 ** (depth + 1)),
                )
                for depth in range(2)
            ]
        )
        self.rest_decoder = nn.ModuleList(
            [DinoDecoderLayer(config, 2) for _ in range(4)]
        )
        self.decoder_norm = nn.LayerNorm(
            (config.hidden_size // (4**2)), eps=config.layer_norm_eps
        )

        # nerf!
        self.nerf_encoder = NerfEmbedder(3, nerf_hidden, 8)  # (rgb, hidden, freq)
        self.nerf_decoder = nn.ModuleList(
            [NerfBlock(self.config.hidden_size, nerf_hidden) for _ in range(12)]
        )
        self.final_layer = NerfFinalLayer(nerf_hidden, 3)

        # freeze pretrained
        self.embeddings.requires_grad_(False)
        self.rope_embeddings.requires_grad_(False)
        self.encoder_norm.requires_grad_(False)

    def forward(
        self,
        pixel_values: torch.Tensor,
        time: torch.Tensor,
        # cond: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        if time.dim() == 0:
            time = time.repeat(pixel_values.shape[0])

        # resolution config
        B = pixel_values.shape[0]
        dino_h = pixel_values.shape[-2] // self.config.patch_size
        dino_w = pixel_values.shape[-1] // self.config.patch_size
        nerf_h = pixel_values.shape[-2] // self.nerf_patch_size
        nerf_w = pixel_values.shape[-1] // self.nerf_patch_size

        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)

        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        t = self.t_embedder(time).unsqueeze(1)
        # y = self.y_embedder(cond, self.training).unsqueeze(1)
        # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
        conditioning_input = t.to(x.dtype)

        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        x = self.encoder_norm(x)
        reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])

        residual_idx = 0
        for depth, layers in enumerate(self.decoder_layers):
            for i, layer_module in enumerate(layers):  # type: ignore
                x = layer_module(
                    x,
                    conditioning_input=conditioning_input,
                    attention_mask=None,
                    position_embeddings=position_embeddings,
                )
                shift, scale = self.residual_merger[depth][i](  # type: ignore
                    reversed_residual[residual_idx]
                ).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
                residual_idx += 1
            x = torch.cat(
                (
                    rearrange(
                        self.upsample[depth](
                            rearrange(
                                x[:, 1 + self.config.num_register_tokens :],
                                "b (h w) d -> b d h w",
                                h=pixel_values.shape[-2]
                                // (self.config.patch_size)
                                * (2**depth),
                            )
                        ),
                        "b d h w -> b (h w) d",
                    )
                    if depth != 2
                    else x[:, 1 + self.config.num_register_tokens :],
                    self.upsample_latent[depth](
                        x[:, : 1 + self.config.num_register_tokens]
                    )
                    if depth != 2
                    else x[:, : 1 + self.config.num_register_tokens],
                ),
                dim=1,
            )
            position_embeddings = (
                self.decode_ropes[depth](
                    torch.zeros(
                        (
                            1,
                            1,
                            pixel_values.shape[-2] * (2 ** (depth + 1)),
                            pixel_values.shape[-1] * (2 ** (depth + 1)),
                        ),
                        device=x.device,
                    ).to(self.embeddings.patch_embeddings.weight.dtype)
                )
                if depth != 2
                else position_embeddings
            )

        for i, layer_module in enumerate(self.rest_decoder):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
            )  # (batch, image // patch^2, 1024)

        x = x[:, 1 + self.config.num_register_tokens :, :]

        # Group the fine-grained tokens back into hidden_size-dim vectors (one per
        # DINO patch) before adding the timestep embedding, which lives at
        # hidden_size.
        x = x.reshape(B, dino_h * dino_w, self.config.hidden_size)
        nerf_cond = nn.functional.silu(t + x)  # (batch, image // patch^2, 1024)
        nerf_cond = nerf_cond.reshape(
            B, dino_h, dino_w, self.config.hidden_size
        ).permute(0, 3, 1, 2)  # (batch, 1024, image // patch, image // patch)
        # nerf_cond = nn.functional.interpolate(
        #     nerf_cond, size=(nerf_h, nerf_w), mode="bilinear", align_corners=False
        # )
        nerf_cond = (
            nerf_cond.permute(0, 2, 3, 1)
            .reshape(-1, nerf_h * nerf_w, self.config.hidden_size)
            .view(-1, self.config.hidden_size)
        )

        # nerf
        x_nerf = nn.functional.unfold(
            pixel_values, self.nerf_patch_size, stride=self.nerf_patch_size
        ).transpose(1, 2)
        x_nerf = x_nerf.reshape(
            B * x_nerf.shape[1], -1, self.nerf_patch_size**2
        ).transpose(1, 2)
        x_nerf = self.nerf_encoder(x_nerf)

        for module in self.nerf_decoder:
            x_nerf = module(x_nerf, nerf_cond)

        x_nerf = self.final_layer(x_nerf)

        num_patches = nerf_h * nerf_w
        x_nerf = x_nerf.reshape(
            B * num_patches, -1
        )  # (B*num_patches, 48): flatten pixels+RGB per patch
        x_nerf = (
            x_nerf.view(B, num_patches, -1).transpose(1, 2).contiguous()
        )  # (B, 48, num_patches)

        res = nn.functional.fold(
            x_nerf,
            (pixel_values.shape[-2], pixel_values.shape[-1]),
            kernel_size=self.nerf_patch_size,
            stride=self.nerf_patch_size,
        )
        return res

    def get_residual(
        self,
        pixel_values: torch.Tensor,
        time: Optional[torch.Tensor],
        do_condition: bool,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)

        x = self.embeddings(pixel_values, bool_masked_pos=None)

        if do_condition:
            t = self.t_embedder(time).unsqueeze(1)
            # y = self.y_embedder(cond, self.training).unsqueeze(1)
            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
            conditioning_input = t.to(x.dtype)
        else:
            conditioning_input = None

        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)

            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=None,
                position_embeddings=position_embeddings,
            )
        x = x[:, 1 + self.config.num_register_tokens :]
        x = self.decoder_norm(x)

        # NOTE: this path appears unfinished: `decoder_norm` expects the decoder
        # width (hidden_size // 16) and `self.decoder` (e.g. a DinoV3ViTDecoder)
        # is never constructed in __init__.
        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    @staticmethod
    def from_pretrained_backbone(name: str):
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = UTransformer(config, 0)

        weight_dict = {}
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
        ) as f:
            for key in f.keys():
                new_key = key.replace("layer.", "encoder_layers.").replace(
                    "norm.", "encoder_norm."
                )

                weight_dict[new_key] = f.get_tensor(key)

        instance.load_state_dict(weight_dict, strict=False)

        return instance
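

# Minimal usage sketch (assumptions: `<dinov3-vit-checkpoint>` is a placeholder
# for whichever DINOv3 ViT repo the backbone weights come from, a CUDA device is
# available, and the backbone has enough layers to yield the six skip residuals
# the DEPTH_LAYER schedule expects, e.g. 24 layers with scale_factor=4).
if __name__ == "__main__":
    model = UTransformer.from_pretrained_backbone("<dinov3-vit-checkpoint>").cuda()
    images = torch.randn(1, 3, 256, 256, device="cuda")
    timesteps = torch.rand(1, device="cuda")
    with torch.no_grad():
        out = model(images, timesteps)
    print(out.shape)  # (1, 3, 256, 256) for a 256x256 input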