v2; a lot of hacks

src/model/attention.py (new file, 113 lines added)
@@ -0,0 +1,113 @@
# General-purpose attention (self- or cross-) with a flash-attention flavor
from typing import Optional, Tuple

import einops
import torch
import torch.nn as nn
from flash_attn import flash_attn_func

from src.model.dino import apply_rotary_pos_emb, rotate_half


class PlainAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout_rate: float = 0.1,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        proj_bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.dropout = dropout_rate

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_bias)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor | None = None,
        v: torch.Tensor | None = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Fall back to self-attention when no keys/values are given.
        k = k if k is not None else q
        v = v if v is not None else q
        assert k is not None and v is not None

        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        q = einops.rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)  # type: ignore

        # flash attn handles softmax scaling (1/sqrt(head_dim)) too
        o = flash_attn_func(
            q.transpose(1, 2),  # flash_attn expects (b, s, h, d)
            k.transpose(1, 2),  # type: ignore
            v.transpose(1, 2),  # type: ignore
            dropout_p=self.dropout if self.training else 0.0,
            causal=False,
        )

        return self.o_proj(einops.rearrange(o, "b s h d -> b s (h d)").contiguous())  # type: ignore
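
# Usage sketch (illustrative, not part of the committed code):
#   attn = PlainAttention(embed_dim=768, num_heads=12)
#   out = attn(x)          # self-attention, x: (batch, seq, 768)
#   out = attn(q, k, v)    # cross-attention with separate key/value streams
# flash_attn_func expects fp16/bf16 CUDA tensors, so inputs are assumed to be
# half precision on GPU.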


class CrossAttention(nn.Module):
    def __init__(self, query_dim: int, key_dim: int, value_dim: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=query_dim,
            kdim=key_dim,
            vdim=value_dim,
            num_heads=1,
            dropout=0.0,
            batch_first=True,
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ):
        # Use MultiheadAttention only for the attention map, then apply it to the
        # raw (unprojected) values so the output keeps the value dimension.
        _, attn_score = self.attention(q, k, v, average_attn_weights=True)
        return torch.einsum("b i j, b j d -> b i d", attn_score, v)


class RoPE(nn.Module):
    def __init__(
        self,
        dim: int,
        theta: int = 100,
    ):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.freqs = nn.Parameter(torch.empty(2, self.dim))

    def _device_weight_init(self):
        # Create freqs in 1d
        freqs_1d = self.theta ** torch.linspace(0, -1, self.dim // 4)
        # duplicate freqs for rotation pairs of channels
        freqs_1d = torch.cat([freqs_1d, freqs_1d])
        # First half of channels do x, second half do y
        freqs_2d = torch.zeros(2, self.dim)
        freqs_2d[0, : self.dim // 2] = freqs_1d
        freqs_2d[1, -self.dim // 2 :] = freqs_1d
        # it's an angular freq here
        self.freqs.data.copy_(freqs_2d * 2 * torch.pi)

    def forward(self, x: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        angle = coords @ self.freqs
        return x * angle.cos() + rotate_half(x) * angle.sin()
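
# Usage sketch (assumed shapes): `coords` is a (batch, seq, 2) grid of (x, y)
# positions, e.g. from `create_coordinate(h, w)` in the U-transformer module, and
# `x` is a (batch, seq, dim) feature map; `rope(x, coords)` applies the 2D rotary
# embedding (first half of the channels driven by x, second half by y), as done
# in ResidualUpscaler.forward.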

@@ -235,7 +235,7 @@ class DINOv3ViTAttention(nn.Module):
            key_states,
            value_states,
            dropout_p=dropout_p,
            softmax_scale=None,  # Will use default 1/sqrt(headdim)
            softmax_scale=None,
            causal=False,
        )

src/model/resnet.py (new file, 63 lines added)
@@ -0,0 +1,63 @@
import torch.nn as nn


class ResBlock(nn.Module):
    """Basic Residual Block, adapted from magvit1"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        num_groups=8,
        pad_mode="zeros",
        norm_fn=None,
        activation_fn=nn.SiLU,
        use_conv_shortcut=False,
    ):
        super().__init__()
        self.use_conv_shortcut = use_conv_shortcut
        self.norm1 = (
            norm_fn(num_groups, in_channels) if norm_fn is not None else nn.Identity()
        )
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            padding_mode=pad_mode,
            bias=False,
        )
        self.norm2 = (
            norm_fn(num_groups, out_channels) if norm_fn is not None else nn.Identity()
        )
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            padding_mode=pad_mode,
            bias=False,
        )
        self.activation_fn = activation_fn()
        # Define the 1x1 shortcut whenever forward() can need it (explicit conv
        # shortcut or a channel change); otherwise self.shortcut would be missing.
        if use_conv_shortcut or in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                padding=0,
                padding_mode=pad_mode,
                bias=False,
            )

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.activation_fn(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.activation_fn(x)
        x = self.conv2(x)
        if self.use_conv_shortcut or residual.shape != x.shape:
            residual = self.shortcut(residual)
        return x + residual
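
# Usage sketch (illustrative values): ResBlock(64, 128, norm_fn=nn.GroupNorm)
# maps (B, 64, H, W) to (B, 128, H, W); the 1x1 shortcut projection keeps the
# residual addition shape-compatible when the channel count changes.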

@@ -1,19 +1,41 @@
import copy
import math
from typing import Optional

import einops
import torch
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.attention import CrossAttention, PlainAttention, RoPE
from src.model.dino import (
    DINOv3ViTAttention,
    DINOv3ViTDropPath,
    DINOv3ViTEmbeddings,
    DINOv3ViTGatedMLP,
    DINOv3ViTLayer,
    DINOv3ViTLayerScale,
    DINOv3ViTMLP,
    DINOv3ViTRopePositionEmbedding,
)
from src.model.dit import modulate
from src.model.resnet import ResBlock


def create_coordinate(h, w, start=0, end=1, device="cuda:1", dtype=torch.float32):
    # Create a grid of coordinates
    x = torch.linspace(start, end, h, device=device, dtype=dtype)
    y = torch.linspace(start, end, w, device=device, dtype=dtype)
    # Create a 2D map using meshgrid
    xx, yy = torch.meshgrid(x, y, indexing="ij")
    # Stack the x and y coordinates to create the final map
    coord_map = torch.stack([xx, yy], dim=-1)[None, ...]
    coords = rearrange(coord_map, "b h w c -> b (h w) c", h=h, w=w)
    return coords
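
# e.g. create_coordinate(4, 4) yields a (1, 16, 2) tensor of (x, y) positions in
# [0, 1]; note the hard-coded default device="cuda:1".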


class TimestepEmbedder(nn.Module):
@@ -76,24 +98,21 @@ class LabelEmbedder(nn.Module):
        return embeddings


class DinoConditionedLayer(DINOv3ViTLayer):
    def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
class DinoEncoderLayer(DINOv3ViTLayer):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__(config)
        self.is_encoder = is_encoder

        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = DINOv3ViTAttention(config)
        self.layer_scale_cond = DINOv3ViTLayerScale(config)

        # no init zeros!
        if is_encoder:
            nn.init.constant_(self.layer_scale_cond.lambda1, 0)
            self.norm1.requires_grad_(False)
            self.norm2.requires_grad_(False)
            self.attention.requires_grad_(False)
            self.mlp.requires_grad_(False)
            self.layer_scale1.requires_grad_(False)
            self.layer_scale2.requires_grad_(False)
        nn.init.constant_(self.layer_scale_cond.lambda1, 0)
        self.norm1.requires_grad_(False)
        self.norm2.requires_grad_(False)
        self.attention.requires_grad_(False)
        self.mlp.requires_grad_(False)
        self.layer_scale1.requires_grad_(False)
        self.layer_scale2.requires_grad_(False)

    def forward(
        self,
@@ -102,79 +121,257 @@ class DinoConditionedLayer(DINOv3ViTLayer):
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        do_condition: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        assert position_embeddings is not None
        assert conditioning_input is not None or not do_condition

        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        if do_condition:
            residual = hidden_states
            hidden_states = self.norm_cond(hidden_states)
            hidden_states = self.cond(
                hidden_states,
                conditioning_input,
        hidden_states = (
            self.drop_path(
                self.layer_scale1(
                    self.attention(
                        self.norm1(hidden_states),
                        attention_mask=attention_mask,
                        position_embeddings=position_embeddings,
                    )
                )
            )
            hidden_states = self.layer_scale_cond(hidden_states)
            hidden_states = self.drop_path(hidden_states) + residual
            + hidden_states
        )

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual
        hidden_states = (
            self.drop_path(
                self.layer_scale_cond(
                    self.cond(
                        hidden_states,
                        self.norm_cond(conditioning_input),
                    )
                )
            )
            + hidden_states
        )

        hidden_states = (
            self.drop_path(self.layer_scale2(self.mlp(self.norm2(hidden_states))))
            + hidden_states
        )

        return hidden_states


# class DinoV3ViTDecoder(nn.Module):
#     def __init__(self, config: DINOv3ViTConfig):
#         super().__init__()
#         self.config = config
#         self.num_channels_out = config.num_channels
class DinoDecoderLayer(DINOv3ViTLayer):
    def __init__(self, config: DINOv3ViTConfig, depth: int):
        super().__init__(config)

#         self.projection = nn.Linear(
#             config.hidden_size,
#             self.num_channels_out * config.patch_size * config.patch_size,
#             bias=True,
#         )
        hidden_size = config.hidden_size // (16**depth)
        hacky_config = copy.copy(config)
        hacky_config.hidden_size = hidden_size
        hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)

#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
#         batch_size = x.shape[0]
        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = PlainAttention(
            hidden_size, config.num_attention_heads // (4**depth)
        )  # head scaling law?
        self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
        self.drop_path = (
            DINOv3ViTDropPath(hacky_config.drop_path_rate)
            if hacky_config.drop_path_rate > 0.0
            else nn.Identity()
        )

#         num_special_tokens = 1 + self.config.num_register_tokens
#         patch_tokens = x[:, num_special_tokens:, :]
        self.norm2 = nn.LayerNorm(hidden_size, eps=hacky_config.layer_norm_eps)

#         projected_tokens = self.projection(patch_tokens)
        if config.use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(hacky_config)
        else:
            self.mlp = DINOv3ViTMLP(hacky_config)
        self.layer_scale2 = DINOv3ViTLayerScale(hacky_config)

#         p = self.config.patch_size
#         c = self.num_channels_out
#         h_grid = image_size[0] // p
#         w_grid = image_size[1] // p
        # adaln
        self.adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 6 * hidden_size, bias=True),
        )

#         assert patch_tokens.shape[1] == h_grid * w_grid, (
#             "Number of patches does not match image size."
#         )
    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        conditioning_input = conditioning_input.squeeze(1)  # type: ignore
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(
            conditioning_input
        ).chunk(6, dim=-1)

#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
        hidden_states = (
            self.drop_path(
                gate_msa.unsqueeze(1)
                * self.layer_scale1(
                    self.attention(
                        modulate(self.norm1(hidden_states), shift_msa, scale_msa),
                        position_embeddings=position_embeddings,
                    )
                )
            )
            + hidden_states
        )

#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
        hidden_states = (
            self.drop_path(
                gate_mlp.unsqueeze(1)
                * self.layer_scale2(
                    self.mlp(modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
                )
            )
            + hidden_states
        )

#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
        return hidden_states
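
# Note (assumption): `modulate` from src.model.dit is taken to be the usual
# DiT-style adaLN helper, roughly x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),
# which matches how the shift/scale/gate chunks from `self.adaln` are consumed
# above and how the residual merger later applies `x * (1 + scale) + shift`.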

#         return reconstructed_image

# let's try a conv decoder
class ResidualUpscaler(nn.Module):
    def __init__(
        self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
    ):  # max depth 2 (4**2 = 16 = patch size)
        super().__init__()

        def build_encoder(in_dim, num_layers=2):
            return nn.Sequential(
                nn.Conv2d(
                    in_dim,
                    bottleneck_dim,
                    kernel_size=1,
                    padding=0,
                    padding_mode="reflect",
                    bias=False,
                ),
                *[
                    ResBlock(
                        bottleneck_dim,
                        bottleneck_dim,
                        kernel_size=1,
                        num_groups=8,
                        pad_mode="reflect",
                        norm_fn=nn.GroupNorm,
                        activation_fn=nn.SiLU,
                        use_conv_shortcut=False,
                    )
                    for _ in range(num_layers)
                ],
            )

        self.config = config
        self.depth = depth
        self.global_encode = nn.Linear(
            config.hidden_size * len(depth),
            bottleneck_dim * 2,
        )
        self.local_encode = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, bottleneck_dim * 2)
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.q_norm = nn.ModuleList(
            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
        )
        self.k_norm = nn.ModuleList(
            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
        )
        self.v_downsample = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.hidden_size // (16**d))
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )
        self.cross_attn = nn.ModuleList(
            [
                CrossAttention(
                    bottleneck_dim,
                    bottleneck_dim,
                    config.hidden_size // (16**d),
                )
                if d != 0
                else nn.Identity()
                for d in depth
            ]
        )

        self.image_encoder = build_encoder(3)
        self.q_encoder = build_encoder(bottleneck_dim)
        self.k_encoder = build_encoder(bottleneck_dim)

        self.rope = RoPE(bottleneck_dim)

    def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
        # residuals[0] => deepest, residuals[-1] => shallowest;
        # pixel_values: (b, 3, h, w), residuals: list of (b, seq, d)
        # objective: given, say, a (1024, 1024, 512) residual, build queries that
        # let multi-head attention upscale it well
        assert self.config.patch_size is not None
        image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]

        global_shift, global_scale = self.global_encode(
            einops.rearrange(
                torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
            )
        ).chunk(2, dim=-1)  # patch-level global btw

        image_residual = self.image_encoder(pixel_values)  # messy; todo: clean up
        coords = create_coordinate(pixel_values.shape[-2], pixel_values.shape[-1])
        image_residual = rearrange(image_residual, "b c h w -> b (h w) c")
        image_residual = self.rope(image_residual, coords)
        image_residual = rearrange(
            image_residual, "b (h w) c -> b c h w", h=pixel_values.shape[-2]
        )
        q = self.q_encoder(image_residual)
        k = self.k_encoder(image_residual)

        reformed_residual = []
        for i, (depth, residual) in enumerate(zip(self.depth, residuals)):
            if depth == 0:
                reformed_residual.append(residual)
                continue

            local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
            local_q = self.q_norm[i](
                einops.rearrange(
                    F.adaptive_avg_pool2d(
                        q,
                        output_size=(
                            image_h // self.config.patch_size * (4**depth),
                            image_w // self.config.patch_size * (4**depth),
                        ),
                    ),
                    "b c h w -> b (h w) c",
                )
            )
            local_k = (1 + local_scale) * (
                (1 + global_scale)
                * self.k_norm[i](
                    einops.rearrange(
                        F.adaptive_avg_pool2d(
                            k,
                            output_size=(
                                image_h // self.config.patch_size,
                                image_w // self.config.patch_size,
                            ),
                        ),
                        "b c h w -> b (h w) c",
                    )
                )
                + global_shift
            ) + local_shift
            local_v = self.v_downsample[i](residual)

            reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v))

        return reformed_residual
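
# Rough shape sketch (assuming patch_size=16): for a depth-d entry, queries come
# from image features pooled to (H/16 * 4**d, W/16 * 4**d), keys from features
# pooled to the (H/16, W/16) patch grid, and values are the encoder residual
# projected down to hidden_size // 16**d, so each reformed residual carries
# 16**d times as many tokens as the patch grid at a correspondingly smaller width.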


class DinoV3ViTDecoder(nn.Module):
@@ -185,9 +382,8 @@ class DinoV3ViTDecoder(nn.Module):
        self.patch_size = config.patch_size

        self.projection = nn.Linear(
            config.hidden_size, config.num_channels * (self.patch_size**2), bias=True
            config.hidden_size // 16 // 16, config.num_channels, bias=True
        )
        self.pixel_shuffle = nn.PixelShuffle(self.patch_size)

        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(
@@ -195,69 +391,11 @@ class DinoV3ViTDecoder(nn.Module):
        ) if self.projection.bias is not None else None

    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        batch_size = x.shape[0]

        x = x[:, 1 + self.config.num_register_tokens :, :]

        p = self.config.patch_size
        h_grid = image_size[0] // p
        w_grid = image_size[1] // p

        assert x.shape[1] == h_grid * w_grid
        x = self.projection(x)
        x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
        x = self.pixel_shuffle(x)

        return x


# how about a transposed conv decoder
# class DinoV3ViTDecoder(nn.Module):
#     def __init__(self, config: DINOv3ViTConfig):
#         super().__init__()
#         self.config = config
#         self.num_channels_out = config.num_channels
#         self.patch_size = config.patch_size

#         intermediate_channels = config.hidden_size // 4

#         self.decoder_block = nn.Sequential(
#             nn.ConvTranspose2d(
#                 in_channels=config.hidden_size,
#                 out_channels=intermediate_channels,
#                 kernel_size=self.patch_size,
#                 stride=self.patch_size,
#                 bias=True,
#             ),
#             nn.LayerNorm(intermediate_channels),
#             nn.GELU(),
#             nn.Conv2d(
#                 in_channels=intermediate_channels,
#                 out_channels=config.num_channels,
#                 kernel_size=1,
#                 bias=True,
#             ),
#         )

#         nn.init.zeros_(self.decoder_block[-1].weight)
#         nn.init.zeros_(self.decoder_block[-1].bias)

#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
#         batch_size = x.shape[0]

#         x = x[:, 1 + self.config.num_register_tokens :, :]

#         p = self.config.patch_size
#         h_grid = image_size[0] // p
#         w_grid = image_size[1] // p
#         assert x.shape[1] == h_grid * w_grid

#         x = x.transpose(1, 2).reshape(
#             batch_size, self.config.hidden_size, h_grid, w_grid
#         )
#         x = self.decoder_block(x)

#         return x
        return self.projection(
            einops.rearrange(
                x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
            )
        ).permute(0, 3, 1, 2)


class UTransformer(nn.Module):
@@ -268,41 +406,71 @@ class UTransformer(nn.Module):
        self.config = config
        self.scale_factor = scale_factor

        assert config.num_hidden_layers % scale_factor == 0
        assert config.num_hidden_layers % scale_factor % 3 == 0

        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)

        def gen_rope(depth: int):
            hidden_size = config.hidden_size // (16**depth)
            hacky_config = copy.copy(config)
            hacky_config.hidden_size = hidden_size
            hacky_config.intermediate_size = hacky_config.intermediate_size // (
                16**depth
            )
            hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
                4**depth
            )
            return DINOv3ViTRopePositionEmbedding(hacky_config)

        self.decode_ropes = nn.ModuleList([gen_rope(i + 1) for i in range(2)])
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        # self.y_embedder = LabelEmbedder(
        #     num_classes, config.hidden_size, config.drop_path_rate
        # )  # disable cond for now

        self.encoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, True)
                for _ in range(config.num_hidden_layers)
            ]
            [DinoEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.residual_upscaler = ResidualUpscaler(
            config,
            [0, 0, 1, 1, 2, 2],  # hardcoded, sorry
            bottleneck_dim=128,
        )
        self.decoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, False)
                for _ in range(config.num_hidden_layers // scale_factor)
                nn.ModuleList(
                    [
                        DinoDecoderLayer(config, depth)
                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
                    ]
                )
                for depth in range(3)
            ]
        )
        self.residual_merger = nn.ModuleList(
            [
                nn.Sequential(
                    nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
                nn.ModuleList(
                    [
                        nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(
                                config.hidden_size // (16**depth),
                                2 * config.hidden_size // (16**depth),
                            ),
                        )
                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
                    ]
                )
                for _ in range(config.num_hidden_layers // scale_factor)
                for depth in range(3)
            ]
        )
        self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
        self.rest_decoder = nn.ModuleList(
            [DinoConditionedLayer(config, False) for _ in range(4)]
            [DinoDecoderLayer(config, 2) for _ in range(4)]
        )
        self.decoder_norm = nn.LayerNorm(
            (config.hidden_size // (16**2)), eps=config.layer_norm_eps
        )
        self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = DinoV3ViTDecoder(config)

        # freeze pretrained
@@ -333,7 +501,7 @@ class UTransformer(nn.Module):
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)
                residual.append(x[:, 1 + self.config.num_register_tokens :])
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
@@ -343,69 +511,66 @@
            )

        x = self.encoder_norm(x)
        x = x[:, 1 + self.config.num_register_tokens :]
        reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])

        reversed_residual = residual[::-1]
        for i, layer_module in enumerate(self.decoder_layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
        residual_idx = 0
        for depth, layers in enumerate(self.decoder_layers):
            for i, layer_module in enumerate(layers):  # type: ignore
                x = layer_module(
                    x,
                    conditioning_input=conditioning_input,
                    attention_mask=None,
                    position_embeddings=position_embeddings,
                )
                shift, scale = self.residual_merger[depth][i](  # type: ignore
                    reversed_residual[residual_idx]
                ).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
                residual_idx += 1
            x = (
                rearrange(
                    self.upsample[depth](
                        rearrange(
                            x,
                            "b (h w) d -> b d h w",
                            h=pixel_values.shape[-2]
                            // (self.config.patch_size)
                            * (4**depth),
                        )
                    ),
                    "b d h w -> b (h w) d",
                )
                if depth != 2
                else x
            )
            shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
                2, dim=-1
            position_embeddings = (
                self.decode_ropes[depth](
                    torch.zeros(
                        (
                            1,
                            1,
                            pixel_values.shape[-2] * (4 ** (depth + 1)),
                            pixel_values.shape[-1] * (4 ** (depth + 1)),
                        ),
                        device=x.device,
                    ).to(self.embeddings.patch_embeddings.weight.dtype)
                )
                if depth != 2
                else position_embeddings
            )
            x = x * (1 + scale) + shift

        for i, layer_module in enumerate(self.rest_decoder):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
            )

        x = self.decoder_norm(x)

        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    def get_residual(
        self,
        pixel_values: torch.Tensor,
        time: Optional[torch.Tensor],
        do_condition: bool,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)

        x = self.embeddings(pixel_values, bool_masked_pos=None)

        if do_condition:
            t = self.t_embedder(time).unsqueeze(1)
            # y = self.y_embedder(cond, self.training).unsqueeze(1)
            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
            conditioning_input = t.to(x.dtype)
        else:
            conditioning_input = None

        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)

            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=None,
                position_embeddings=position_embeddings,
                do_condition=do_condition,
            )

        return residual
        x = self.decoder_norm(x)

        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    @staticmethod
    def from_pretrained_backbone(name: str):