Merge commit 'ca589fa'

This commit is contained in:
neulus
2025-11-02 22:55:07 +09:00
13 changed files with 888 additions and 98 deletions

View File

@@ -1,5 +1,6 @@
import copy
import math
from functools import lru_cache
from typing import Optional
import einops
@@ -420,13 +421,107 @@ class DinoV3ViTDecoder(nn.Module):
)
class NerfEmbedder(nn.Module):
def __init__(self, in_channels, hidden_size_input, max_freqs):
super().__init__()
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
self.embedder = nn.Sequential(
nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True),
)
@lru_cache
def fetch_pos(self, patch_size, device, dtype):
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
pos_x = pos_x.reshape(-1, 1, 1)
pos_y = pos_y.reshape(-1, 1, 1)
freqs = torch.linspace(
0, self.max_freqs, self.max_freqs, dtype=dtype, device=device
)
freqs_x = freqs[None, :, None]
freqs_y = freqs[None, None, :]
coeffs = (1 + freqs_x * freqs_y) ** -1
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
return dct
def forward(self, inputs):
target_dtype = self.embedder[0].weight.dtype
inputs = inputs.to(dtype=target_dtype)
B, P2, C = inputs.shape
patch_size = int(P2**0.5)
device = inputs.device
dtype = inputs.dtype
dct = self.fetch_pos(patch_size, device, dtype)
dct = dct.repeat(B, 1, 1)
inputs = torch.cat([inputs, dct], dim=-1)
inputs = self.embedder(inputs)
return inputs
class NerfBlock(nn.Module):
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio: int = 4):
super().__init__()
self.param_generator1 = nn.Sequential(
nn.Linear(hidden_size_s, 2 * hidden_size_x**2 * mlp_ratio, bias=True),
)
self.norm = nn.RMSNorm(hidden_size_x, eps=1e-6)
self.mlp_ratio = mlp_ratio
def forward(self, x, s):
batch_size, num_x, hidden_size_x = x.shape
mlp_params1 = self.param_generator1(s)
fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1)
fc1_param1 = fc1_param1.view(
batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio
)
fc2_param1 = fc2_param1.view(
batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x
)
# normalize fc1
normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2)
# normalize fc2
normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2)
# mlp 1
res_x = x
x = self.norm(x)
x = torch.bmm(x, normalized_fc1_param1)
x = torch.nn.functional.silu(x)
x = torch.bmm(x, normalized_fc2_param1)
x = x + res_x
return x
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm = nn.RMSNorm(hidden_size, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
def forward(self, x):
x = self.norm(x)
x = self.linear(x)
return x
class UTransformer(nn.Module):
def __init__(
self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
self,
config: DINOv3ViTConfig,
num_classes: int,
nerf_patch=16,
nerf_hidden=64,
scale_factor: int = 4,
):
super().__init__()
self.config = config
self.scale_factor = scale_factor
self.nerf_patch_size = nerf_patch
assert config.num_hidden_layers % scale_factor % 3 == 0
@@ -503,7 +598,13 @@ class UTransformer(nn.Module):
self.decoder_norm = nn.LayerNorm(
(config.hidden_size // (4**2)), eps=config.layer_norm_eps
)
self.decoder = DinoV3ViTDecoder(config)
# nerf!
self.nerf_encoder = NerfEmbedder(3, nerf_hidden, 8) # (rgb, hidden, freq)
self.nerf_decoder = nn.ModuleList(
[NerfBlock(self.config.hidden_size, nerf_hidden) for _ in range(12)]
)
self.final_layer = NerfFinalLayer(nerf_hidden, 3)
# freeze pretrained
self.embeddings.requires_grad_(False)
@@ -521,6 +622,13 @@ class UTransformer(nn.Module):
if time.dim() == 0:
time = time.repeat(pixel_values.shape[0])
# resolution config
B = pixel_values.shape[0]
dino_h = pixel_values.shape[-2] // self.config.patch_size
dino_w = pixel_values.shape[-1] // self.config.patch_size
nerf_h = pixel_values.shape[-2] // self.nerf_patch_size
nerf_w = pixel_values.shape[-1] // self.nerf_patch_size
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
position_embeddings = self.rope_embeddings(pixel_values)
@@ -600,6 +708,84 @@ class UTransformer(nn.Module):
)
for i, layer_module in enumerate(self.rest_decoder):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
do_condition=False,
) # (batch, image // patch^2, 1024)
x = x[:, 1 + self.config.num_register_tokens :, :]
nerf_cond = nn.functional.silu(t + x) # (batch, image // patch^2, 1024)
nerf_cond = nerf_cond.reshape(
B, dino_h, dino_w, self.config.hidden_size
).permute(0, 3, 1, 2) # (batch, 1024, image // patch, image // patch)
# nerf_cond = nn.functional.interpolate(
# nerf_cond, size=(nerf_h, nerf_w), mode="bilinear", align_corners=False
# )
nerf_cond = (
nerf_cond.permute(0, 2, 3, 1)
.reshape(-1, nerf_h * nerf_w, self.config.hidden_size)
.view(-1, self.config.hidden_size)
)
# nerf
x_nerf = nn.functional.unfold(
pixel_values, self.nerf_patch_size, stride=self.nerf_patch_size
).transpose(1, 2)
x_nerf = x_nerf.reshape(
B * x_nerf.shape[1], -1, self.nerf_patch_size**2
).transpose(1, 2)
x_nerf = self.nerf_encoder(x_nerf)
for module in self.nerf_decoder:
x_nerf = module(x_nerf, nerf_cond)
x_nerf = self.final_layer(x_nerf)
num_patches = nerf_h * nerf_w
x_nerf = x_nerf.reshape(
B * num_patches, -1
) # (B*num_patches, 48): flatten pixels+RGB per patch
x_nerf = (
x_nerf.view(B, num_patches, -1).transpose(1, 2).contiguous()
) # (B, 48, num_patches)
res = nn.functional.fold(
x_nerf,
(pixel_values.shape[-2], pixel_values.shape[-1]),
kernel_size=self.nerf_patch_size,
stride=self.nerf_patch_size,
)
return res
def get_residual(
self,
pixel_values: torch.Tensor,
time: Optional[torch.Tensor],
do_condition: bool,
):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
position_embeddings = self.rope_embeddings(pixel_values)
x = self.embeddings(pixel_values, bool_masked_pos=None)
if do_condition:
t = self.t_embedder(time).unsqueeze(1)
# y = self.y_embedder(cond, self.training).unsqueeze(1)
# conditioning_input = t.to(x.dtype) + y.to(x.dtype)
conditioning_input = t.to(x.dtype)
else:
conditioning_input = None
residual = []
for i, layer_module in enumerate(self.encoder_layers):
if i % self.scale_factor == 0:
residual.append(x)
x = layer_module(
x,
conditioning_input=conditioning_input,
@@ -614,7 +800,7 @@ class UTransformer(nn.Module):
@staticmethod
def from_pretrained_backbone(name: str):
config = DINOv3ViTConfig.from_pretrained(name)
instance = UTransformer(config, 0).to("cuda:1")
instance = UTransformer(config, 0)
weight_dict = {}
with safe_open(