things
This commit is contained in:
@@ -1,15 +1,22 @@
|
||||
import lpips
|
||||
from pytorch_msssim import ssim
|
||||
from torchmetrics.image import (
|
||||
LearnedPerceptualImagePatchSimilarity,
|
||||
PeakSignalNoiseRatio,
|
||||
StructuralSimilarityIndexMeasure,
|
||||
)
|
||||
|
||||
psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
|
||||
ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(
|
||||
net_type="alex", reduction="none", normalize=True
|
||||
)
|
||||
lp = lpips.LPIPS(net="alex")
|
||||
|
||||
|
||||
def benchmark(image1, image2):
|
||||
return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
|
||||
return (
|
||||
psnr(image1, image2),
|
||||
ssim(
|
||||
image1,
|
||||
image2,
|
||||
data_range=1.0,
|
||||
size_average=False,
|
||||
),
|
||||
lp(image1 * 2 - 1, image2 * 2 - 1),
|
||||
)
|
||||
|
||||
@@ -373,7 +373,7 @@ class DINOv3ViTModel(nn.Module):
|
||||
self.config = config
|
||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
self.layer = nn.ModuleList(
|
||||
[DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -420,7 +420,7 @@ class DINOv3ViTModel(nn.Module):
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
latents = []
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
|
||||
387
src/model/dit.py
Normal file
387
src/model/dit.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory
|
||||
# this is modeling code for DiT-LLaMA model
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
from src.model.dino import DINOv3ViTModel
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half) / half
|
||||
).to(t.device)
|
||||
args = t[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob):
|
||||
super().__init__()
|
||||
use_cfg_embedding = int(dropout_prob > 0)
|
||||
self.embedding_table = nn.Embedding(
|
||||
num_classes + use_cfg_embedding, hidden_size
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
|
||||
drop_ids = drop_ids.cuda()
|
||||
drop_ids = drop_ids.to(labels.device)
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, n_heads):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.n_rep = 1
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
|
||||
self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)
|
||||
|
||||
@staticmethod
|
||||
def reshape_for_broadcast(freqs_cis, x):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
# assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
_freqs_cis = freqs_cis[: x.shape[1]]
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return _freqs_cis.view(*shape)
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(xq, xk, freqs_cis):
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis_xq = Attention.reshape_for_broadcast(freqs_cis, xq_)
|
||||
freqs_cis_xk = Attention.reshape_for_broadcast(freqs_cis, xk_)
|
||||
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis_xq).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis_xk).flatten(3)
|
||||
return xq_out, xk_out
|
||||
|
||||
def forward(self, x, freqs_cis):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
dtype = xq.dtype
|
||||
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
|
||||
xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
xq, xk = xq.to(dtype), xk.to(dtype)
|
||||
|
||||
output = F.scaled_dot_product_attention(
|
||||
xq.permute(0, 2, 1, 3),
|
||||
xk.permute(0, 2, 1, 3),
|
||||
xv.permute(0, 2, 1, 3),
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
).permute(0, 2, 1, 3)
|
||||
output = output.flatten(-2)
|
||||
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
if ffn_dim_multiplier:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = Attention(dim, n_heads)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(dim, 1024), 6 * dim, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, freqs_cis, adaln_input=None):
|
||||
if adaln_input is not None:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.adaLN_modulation(adaln_input).chunk(6, dim=1)
|
||||
)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1) * self.attention(
|
||||
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
|
||||
modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
else:
|
||||
x = x + self.attention(self.attention_norm(x), freqs_cis)
|
||||
x = x + self.feed_forward(self.ffn_norm(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, bias=True
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True),
|
||||
)
|
||||
# # init zero
|
||||
nn.init.constant_(self.linear.weight, 0)
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class DiT_Llama(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dino_cfg: DINOv3ViTConfig,
|
||||
in_channels=3,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
dim=512,
|
||||
n_layers=5,
|
||||
n_heads=16,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.input_size = input_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.init_conv_seq = nn.Sequential(
|
||||
nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
|
||||
nn.SiLU(),
|
||||
nn.GroupNorm(32, dim // 2),
|
||||
nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
|
||||
nn.SiLU(),
|
||||
nn.GroupNorm(32, dim // 2),
|
||||
)
|
||||
|
||||
self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
|
||||
nn.init.constant_(self.x_embedder.bias, 0)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024))
|
||||
self.y_embedder = DINOv3ViTModel(dino_cfg)
|
||||
self.thing = nn.Linear(dino_cfg.hidden_size, min(dim, 1024))
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
|
||||
|
||||
self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 16384)
|
||||
|
||||
# freeze
|
||||
self.y_embedder.requires_grad_(False)
|
||||
|
||||
def unpatchify(self, x):
|
||||
c = self.out_channels
|
||||
p = self.patch_size
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
||||
return imgs
|
||||
|
||||
def patchify(self, x):
|
||||
B, C, H, W = x.size()
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
H // self.patch_size,
|
||||
self.patch_size,
|
||||
W // self.patch_size,
|
||||
self.patch_size,
|
||||
)
|
||||
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
return x
|
||||
|
||||
def forward(self, x, t, y):
|
||||
self.freqs_cis = self.freqs_cis.to(x.device)
|
||||
|
||||
x = self.init_conv_seq(x)
|
||||
|
||||
x = self.patchify(x)
|
||||
x = self.x_embedder(x)
|
||||
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
y = self.thing(self.y_embedder(y)["pooler_output"]) # (N, D)
|
||||
adaln_input = t.to(x.dtype) + y.to(x.dtype)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_cfg(self, x, t, y, cfg_scale):
|
||||
half = x[: len(x) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.forward(combined, t, y)
|
||||
eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim, end, theta=10000.0):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end)
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained_backbone(
|
||||
name: str,
|
||||
in_channels=3,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
dim=512,
|
||||
n_layers=5,
|
||||
n_heads=16,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
config = DINOv3ViTConfig.from_pretrained(name)
|
||||
instance = DiT_Llama(
|
||||
config,
|
||||
in_channels,
|
||||
input_size,
|
||||
patch_size,
|
||||
dim,
|
||||
n_layers,
|
||||
n_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
).to("cuda:1")
|
||||
|
||||
weight_dict = {}
|
||||
with safe_open(
|
||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
weight_dict[key] = f.get_tensor(key)
|
||||
|
||||
instance.y_embedder.load_state_dict(weight_dict, strict=True)
|
||||
|
||||
return instance
|
||||
@@ -1,144 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
from src.model.dino import (
|
||||
DINOv3ViTEmbeddings,
|
||||
DINOv3ViTRopePositionEmbedding,
|
||||
)
|
||||
from src.model.utransformer import DinoConditionedLayer, TimestepEmbedder
|
||||
|
||||
|
||||
class Hourgrass(nn.Module):
|
||||
def __init__(
|
||||
self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
assert config.num_hidden_layers % scale_factor == 0
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||
self.t_embedder = TimestepEmbedder(config.hidden_size)
|
||||
|
||||
self.encoder_layers = nn.ModuleList(
|
||||
[
|
||||
DinoConditionedLayer(config, True)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
# freeze pretrained
|
||||
self.embeddings.requires_grad_(False)
|
||||
self.rope_embeddings.requires_grad_(False)
|
||||
self.encoder_norm.requires_grad_(False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
time: torch.Tensor,
|
||||
# cond: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if time.dim() == 0:
|
||||
time = time.repeat(pixel_values.shape[0])
|
||||
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
t = self.t_embedder(time).unsqueeze(1)
|
||||
|
||||
conditioning_input = t.to(x.dtype)
|
||||
|
||||
residual = []
|
||||
for i, layer_module in enumerate(self.encoder_layers):
|
||||
if i % self.scale_factor == 0:
|
||||
residual.append(x)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
x = self.encoder_norm(x)
|
||||
|
||||
reversed_residual = residual[::-1]
|
||||
for i, layer_module in enumerate(self.decoder_layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
x = x + reversed_residual[i]
|
||||
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
|
||||
|
||||
def get_residual(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
time: Optional[torch.Tensor],
|
||||
do_condition: bool,
|
||||
):
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
x = self.embeddings(pixel_values, bool_masked_pos=None)
|
||||
|
||||
if do_condition:
|
||||
t = self.t_embedder(time).unsqueeze(1)
|
||||
# y = self.y_embedder(cond, self.training).unsqueeze(1)
|
||||
# conditioning_input = t.to(x.dtype) + y.to(x.dtype)
|
||||
conditioning_input = t.to(x.dtype)
|
||||
else:
|
||||
conditioning_input = None
|
||||
|
||||
residual = []
|
||||
for i, layer_module in enumerate(self.encoder_layers):
|
||||
if i % self.scale_factor == 0:
|
||||
residual.append(x)
|
||||
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=None,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=do_condition,
|
||||
)
|
||||
|
||||
return residual
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained_backbone(name: str):
|
||||
config = DINOv3ViTConfig.from_pretrained(name)
|
||||
instance = UTransformer(config, 0).to("cuda:1")
|
||||
|
||||
weight_dict = {}
|
||||
with safe_open(
|
||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
new_key = key.replace("layer.", "encoder_layers.").replace(
|
||||
"norm.", "encoder_norm."
|
||||
)
|
||||
|
||||
weight_dict[new_key] = f.get_tensor(key)
|
||||
|
||||
instance.load_state_dict(weight_dict, strict=False)
|
||||
|
||||
return instance
|
||||
@@ -78,6 +78,7 @@ class LabelEmbedder(nn.Module):
|
||||
class DinoConditionedLayer(DINOv3ViTLayer):
|
||||
def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
|
||||
super().__init__(config)
|
||||
self.is_encoder = is_encoder
|
||||
|
||||
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.cond = nn.MultiheadAttention(
|
||||
@@ -298,6 +299,17 @@ class UTransformer(nn.Module):
|
||||
for _ in range(config.num_hidden_layers // scale_factor)
|
||||
]
|
||||
)
|
||||
self.residual_merger = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
|
||||
)
|
||||
for _ in range(config.num_hidden_layers // scale_factor)
|
||||
]
|
||||
)
|
||||
self.rest_decoder = nn.ModuleList(
|
||||
[DinoConditionedLayer(config, False) for _ in range(4)]
|
||||
)
|
||||
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.decoder = DinoV3ViTDecoder(config)
|
||||
|
||||
@@ -348,8 +360,22 @@ class UTransformer(nn.Module):
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
|
||||
2, dim=-1
|
||||
)
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
for i, layer_module in enumerate(self.rest_decoder):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
x = x + reversed_residual[i]
|
||||
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
|
||||
94
src/rf.py
94
src/rf.py
@@ -25,6 +25,7 @@ class RF:
|
||||
def __init__(self, model, fm="otcfm", loss="mse"):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.iter = 0
|
||||
|
||||
sigma = 0.0
|
||||
if fm == "otcfm":
|
||||
@@ -97,9 +98,14 @@ class RF:
|
||||
total_d_loss.backward(retain_graph=True)
|
||||
self.optimizer_D.step()
|
||||
|
||||
def forward(self, gt, cloud):
|
||||
t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt) # type: ignore
|
||||
vt, _ = self.model(xt, t)
|
||||
def forward(self, gt, cloud, condition=False):
|
||||
t, xt, ut = self.fm.sample_location_and_conditional_flow( # type: ignore
|
||||
cloud if not condition else torch.randn_like(cloud), gt
|
||||
)
|
||||
if condition:
|
||||
vt = self.model(xt, t, cloud)
|
||||
else:
|
||||
vt, _ = self.model(xt, t)
|
||||
|
||||
if self.loss == "mse":
|
||||
loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
|
||||
@@ -116,10 +122,12 @@ class RF:
|
||||
}
|
||||
loss = mse + lpips * 2.0
|
||||
elif self.loss == "gan_lpips_mse":
|
||||
self.gan_loss(
|
||||
denormalize(gt),
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
)
|
||||
self.iter += 1
|
||||
# if self.iter % 4 == 0:
|
||||
# self.gan_loss(
|
||||
# denormalize(gt),
|
||||
# denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
# )
|
||||
mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
|
||||
lpips = self.lpips(
|
||||
denormalize(gt) * 2 - 1,
|
||||
@@ -129,12 +137,9 @@ class RF:
|
||||
denormalize(gt) * 2 - 1,
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
|
||||
)
|
||||
gan = (
|
||||
-self.discriminator(
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
).mean(-1)
|
||||
* 0.01
|
||||
)
|
||||
# gan = -self.discriminator(
|
||||
# denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
# ).mean(-1)
|
||||
ssim = 1 - ms_ssim(
|
||||
denormalize(gt),
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
@@ -145,10 +150,10 @@ class RF:
|
||||
"train/mse": mse.mean().item(),
|
||||
"train/lpips": lpips.mean().item(),
|
||||
"train/alexlpips": alexlpips.mean().item(),
|
||||
"train/gan": gan.mean().item(),
|
||||
# "train/gan": gan.mean().item(),
|
||||
"train/ssim": ssim.mean().item(),
|
||||
}
|
||||
loss = mse + lpips * 4.0 + gan + alexlpips + ssim
|
||||
loss = mse + lpips * 2.0 + alexlpips + ssim
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
@@ -158,43 +163,28 @@ class RF:
|
||||
return loss.mean(), ttloss, loss_list
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, cloud, tol=1e-5, integration="dopri5") -> torch.Tensor:
|
||||
def sample(
|
||||
self, cloud, tol=1e-5, integration="dopri5", condition=False
|
||||
) -> torch.Tensor:
|
||||
t_span = torch.linspace(0, 1, 2, device=cloud.device)
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t)[0],
|
||||
cloud,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
atol=tol,
|
||||
method=integration,
|
||||
)
|
||||
if condition:
|
||||
x = torch.randn_like(cloud)
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t, cloud),
|
||||
x,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
atol=tol,
|
||||
method=integration,
|
||||
)
|
||||
else:
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t)[0],
|
||||
cloud,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
atol=tol,
|
||||
method=integration,
|
||||
)
|
||||
|
||||
return [traj[i] for i in range(traj.shape[0])] # type: ignore
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(self, z1, sample_steps=50):
|
||||
b = z1.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
|
||||
images = [z1]
|
||||
z = z1
|
||||
|
||||
for i in range(sample_steps, 0, -1):
|
||||
t_current = i / sample_steps
|
||||
t_next = (i - 1) / sample_steps
|
||||
|
||||
t_current_tensor = torch.tensor([t_current] * b, device=z.device)
|
||||
|
||||
v_current, _ = self.model(z, t_current_tensor)
|
||||
z_pred = z - dt * v_current
|
||||
|
||||
t_next_tensor = torch.tensor([t_next] * b, device=z.device)
|
||||
v_next, _ = self.model(z_pred, t_next_tensor)
|
||||
|
||||
v_avg = 0.5 * (v_current + v_next)
|
||||
|
||||
z = z - dt * v_avg
|
||||
|
||||
images.append(z)
|
||||
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user