Add DiT-LLaMA model code (src/model/dit.py)
src/model/dit.py | 387 | Normal file
@@ -0,0 +1,387 @@
# Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory
# Modeling code for the DiT-LLaMA model.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.dino import DINOv3ViTModel


def modulate(x, shift, scale):
    # x: (N, L, D); shift/scale: (N, D), broadcast over the sequence dimension.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


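# Sinusoidal timestep embedder (same scheme as the original DiT): a scalar
# timestep t is expanded into `frequency_embedding_size` sin/cos features and
# then projected to `hidden_size` by a two-layer SiLU MLP.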
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb


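# Class-label embedder with classifier-free-guidance dropout: with probability
# `dropout_prob`, a label is replaced by the extra "null" index `num_classes`.
# Note that this module is not referenced elsewhere in this file; DiT_Llama
# below conditions on DINOv3 features instead.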
class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        if force_drop_ids is None:
            # Sample the drop mask directly on the labels' device instead of
            # forcing a round-trip through CUDA.
            drop_ids = (
                torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            )
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


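# Multi-head self-attention with LayerNorm on the full Q/K projections
# (applied before the head split) and rotary position embeddings applied per
# head in complex form; the attention itself is non-causal
# F.scaled_dot_product_attention.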
class Attention(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()

        self.n_heads = n_heads
        self.n_rep = 1
        self.head_dim = dim // n_heads

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

        self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
        self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)

    @staticmethod
    def reshape_for_broadcast(freqs_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        # assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        _freqs_cis = freqs_cis[: x.shape[1]]
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return _freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(xq, xk, freqs_cis):
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        freqs_cis_xq = Attention.reshape_for_broadcast(freqs_cis, xq_)
        freqs_cis_xk = Attention.reshape_for_broadcast(freqs_cis, xk_)

        xq_out = torch.view_as_real(xq_ * freqs_cis_xq).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis_xk).flatten(3)
        return xq_out, xk_out

    def forward(self, x, freqs_cis):
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        dtype = xq.dtype

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)

        xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        xq, xk = xq.to(dtype), xk.to(dtype)

        output = F.scaled_dot_product_attention(
            xq.permute(0, 2, 1, 3),
            xk.permute(0, 2, 1, 3),
            xv.permute(0, 2, 1, 3),
            dropout_p=0.0,
            is_causal=False,
        ).permute(0, 2, 1, 3)
        output = output.flatten(-2)

        return self.wo(output)


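# SwiGLU feed-forward block (SiLU-gated, three bias-free projections). The
# hidden width is 2/3 of the requested `hidden_dim`, optionally rescaled by
# `ffn_dim_multiplier`, then rounded up to a multiple of `multiple_of`.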
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))


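# Pre-norm transformer block with adaLN conditioning: when `adaln_input` is
# given, a SiLU + Linear head predicts per-sample shift/scale/gate values for
# both the attention and feed-forward sub-layers (6 * dim values in total);
# without it, the block falls back to plain residual attention/FFN.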
class TransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id,
        dim,
        n_heads,
        multiple_of,
        ffn_dim_multiplier,
        norm_eps,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
        self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(dim, 1024), 6 * dim, bias=True),
        )

    def forward(self, x, freqs_cis, adaln_input=None):
        if adaln_input is not None:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(adaln_input).chunk(6, dim=1)
            )

            x = x + gate_msa.unsqueeze(1) * self.attention(
                modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
            )
            x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
                modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
            )
        else:
            x = x + self.attention(self.attention_norm(x), freqs_cis)
            x = x + self.feed_forward(self.ffn_norm(x))

        return x


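# Output head: adaLN-modulated LayerNorm followed by a linear projection to
# patch_size * patch_size * out_channels. Both weight and bias of the
# projection are zero-initialized, so the model's output starts at zero.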
class FinalLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True),
        )
        # Zero-init the output projection.
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


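# DiT-LLaMA backbone: a small conv stem, linear patch embedding, a stack of
# adaLN transformer blocks with rotary position embeddings, and a FinalLayer
# that maps tokens back to image patches. The conditioning vector is the sum
# of the timestep embedding and a linear projection of pooled DINOv3 features;
# the DINOv3 encoder itself is kept frozen.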
class DiT_Llama(nn.Module):
    def __init__(
        self,
        dino_cfg: DINOv3ViTConfig,
        in_channels=3,
        input_size=32,
        patch_size=2,
        dim=512,
        n_layers=5,
        n_heads=16,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = in_channels
        self.input_size = input_size
        self.patch_size = patch_size

        self.init_conv_seq = nn.Sequential(
            nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
            nn.SiLU(),
            nn.GroupNorm(32, dim // 2),
            nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
            nn.SiLU(),
            nn.GroupNorm(32, dim // 2),
        )

        self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
        nn.init.constant_(self.x_embedder.bias, 0)

        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        # Frozen DINOv3 encoder used as the conditioning ("y") embedder; its
        # pooled output is projected to the adaLN conditioning width.
        self.y_embedder = DINOv3ViTModel(dino_cfg)
        self.thing = nn.Linear(dino_cfg.hidden_size, min(dim, 1024))

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels)

        self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 16384)

        # freeze the DINOv3 backbone
        self.y_embedder.requires_grad_(False)

    def unpatchify(self, x):
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def forward(self, x, t, y):
        self.freqs_cis = self.freqs_cis.to(x.device)

        x = self.init_conv_seq(x)

        x = self.patchify(x)
        x = self.x_embedder(x)

        t = self.t_embedder(t)  # (N, D)
        y = self.thing(self.y_embedder(y)["pooler_output"])  # (N, D)
        adaln_input = t.to(x.dtype) + y.to(x.dtype)

        for layer in self.layers:
            x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        # DiT-style classifier-free guidance: the batch is assumed to be
        # ordered as [conditional half; unconditional half] along dim 0.
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    @staticmethod
    def precompute_freqs_cis(dim, end, theta=10000.0):
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    @staticmethod
    def from_pretrained_backbone(
        name: str,
        in_channels=3,
        input_size=32,
        patch_size=2,
        dim=512,
        n_layers=5,
        n_heads=16,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
    ):
        # Builds a DiT_Llama around a pretrained DINOv3 backbone pulled from
        # the Hugging Face Hub. NOTE: the device is currently hardcoded to
        # "cuda:1" for both the model and the downloaded weights.
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = DiT_Llama(
            config,
            in_channels,
            input_size,
            patch_size,
            dim,
            n_layers,
            n_heads,
            multiple_of,
            ffn_dim_multiplier,
            norm_eps,
        ).to("cuda:1")

        weight_dict = {}
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
        ) as f:
            for key in f.keys():
                weight_dict[key] = f.get_tensor(key)

        instance.y_embedder.load_state_dict(weight_dict, strict=True)

        return instance


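# Minimal smoke-test sketch (not part of the training pipeline). It assumes the
# default DINOv3ViTConfig() constructor yields a usable config whose
# `image_size` and `num_channels` attributes describe the conditioning images
# expected by DINOv3ViTModel, and it runs on CPU with random tensors.
if __name__ == "__main__":
    cfg = DINOv3ViTConfig()
    model = DiT_Llama(cfg, in_channels=3, input_size=32, patch_size=2, dim=512)

    x = torch.randn(2, 3, 32, 32)  # noisy images to denoise
    t = torch.rand(2)  # diffusion timesteps
    # Conditioning images for the DINOv3 encoder (shape is an assumption).
    y = torch.randn(2, cfg.num_channels, cfg.image_size, cfg.image_size)

    out = model(x, t, y)
    print(out.shape)  # expected: torch.Size([2, 3, 32, 32])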