v2; a lot of hacks

This commit is contained in:
neulus
2025-10-15 18:02:10 +09:00
parent e51017897d
commit 6ab33ceb83
7 changed files with 586 additions and 199 deletions

src/model/attention.py (new file, +113 lines)

@@ -0,0 +1,113 @@
# General-purpose attention backed by the FlashAttention kernel
from typing import Optional, Tuple
import einops
import torch
import torch.nn as nn
from flash_attn import flash_attn_func
from src.model.dino import apply_rotary_pos_emb, rotate_half
class PlainAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout_rate: float = 0.1,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
proj_bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim // self.num_heads
self.dropout = dropout_rate
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_bias)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_bias)
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor | None = None,
v: torch.Tensor | None = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
k = k if k is not None else q
v = v if v is not None else q
assert k is not None and v is not None
q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
q = einops.rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
if position_embeddings is not None:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin) # type: ignore
# flash attention applies the softmax scaling internally
o = flash_attn_func(
q.transpose(1, 2),
k.transpose(1, 2), # type: ignore
v.transpose(1, 2), # type: ignore
dropout_p=self.dropout if self.training else 0.0,
causal=False,
)
return self.o_proj(einops.rearrange(o, "b s h d -> b s (h d)").contiguous()) # type: ignore
class CrossAttention(nn.Module):
def __init__(self, query_dim: int, key_dim: int, value_dim: int):
super().__init__()
self.attention = nn.MultiheadAttention(
embed_dim=query_dim,
vdim=value_dim,
num_heads=1,
dropout=0.0,
batch_first=True,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
):
_, attn_score = self.attention(q, k, v, average_attn_weights=True)
return torch.einsum("b i j, b j d -> b i d", attn_score, v)
class RoPE(nn.Module):
def __init__(
self,
dim: int,
theta: int = 100,
):
super().__init__()
self.dim = dim
self.theta = theta
self.freqs = nn.Parameter(torch.empty(2, self.dim))
def _device_weight_init(self):
# Create freqs in 1d
freqs_1d = self.theta ** torch.linspace(0, -1, self.dim // 4)
# duplicate freqs for rotation pairs of channels
freqs_1d = torch.cat([freqs_1d, freqs_1d])
# First half of channels do x, second half do y
freqs_2d = torch.zeros(2, self.dim)
freqs_2d[0, : self.dim // 2] = freqs_1d
freqs_2d[1, -self.dim // 2 :] = freqs_1d
# multiply by 2*pi so these are angular frequencies
self.freqs.data.copy_(freqs_2d * 2 * torch.pi)
def forward(self, x: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
angle = coords @ self.freqs
return x * angle.cos() + rotate_half(x) * angle.sin()
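# --- Illustrative usage sketch (editor's addition, not part of the commit) ---
# Assumes flash-attn is installed and a CUDA device is available; all dims below
# are made up for the example.
if __name__ == "__main__":
    attn = PlainAttention(embed_dim=256, num_heads=4).cuda().half()
    x = torch.randn(2, 16 * 16, 256, device="cuda", dtype=torch.float16)
    out = attn(x)  # k and v default to q, so this is plain self-attention
    assert out.shape == x.shape

    rope = RoPE(dim=128).cuda()
    rope._device_weight_init()
    coords = torch.rand(2, 16 * 16, 2, device="cuda")  # (x, y) coordinates in [0, 1]
    feats = torch.randn(2, 16 * 16, 128, device="cuda")
    assert rope(feats, coords).shape == feats.shape  # rotation preserves the shape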


@@ -235,7 +235,7 @@ class DINOv3ViTAttention(nn.Module):
key_states,
value_states,
dropout_p=dropout_p,
softmax_scale=None, # Will use default 1/sqrt(headdim)
softmax_scale=None,
causal=False,
)

src/model/resnet.py (new file, +63 lines)

@@ -0,0 +1,63 @@
import torch.nn as nn
class ResBlock(nn.Module):
"""Basic Residual Block, adapted from magvit1"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
num_groups=8,
pad_mode="zeros",
norm_fn=None,
activation_fn=nn.SiLU,
use_conv_shortcut=False,
):
super(ResBlock, self).__init__()
self.use_conv_shortcut = use_conv_shortcut
self.norm1 = (
norm_fn(num_groups, in_channels) if norm_fn is not None else nn.Identity()
)
self.conv1 = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
padding_mode=pad_mode,
bias=False,
)
self.norm2 = (
norm_fn(num_groups, out_channels) if norm_fn is not None else nn.Identity()
)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
padding_mode=pad_mode,
bias=False,
)
self.activation_fn = activation_fn()
if in_channels != out_channels or use_conv_shortcut:  # the 1x1 shortcut must exist whenever forward may use it
self.shortcut = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
padding=0,
padding_mode=pad_mode,
bias=False,
)
def forward(self, x):
residual = x
x = self.norm1(x)
x = self.activation_fn(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.activation_fn(x)
x = self.conv2(x)
if self.use_conv_shortcut or residual.shape != x.shape:
residual = self.shortcut(residual)
return x + residual
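# --- Illustrative usage sketch (editor's addition, not part of the commit) ---
# A channel-changing block routes the residual through its 1x1 shortcut; the
# shapes below are arbitrary.
if __name__ == "__main__":
    import torch

    block = ResBlock(64, 128, num_groups=8, norm_fn=nn.GroupNorm)
    y = block(torch.randn(1, 64, 32, 32))
    assert y.shape == (1, 128, 32, 32)  # spatial size kept, channels changed via the shortcut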


@@ -1,19 +1,41 @@
import copy
import math
from typing import Optional
import einops
import torch
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
from src.model.attention import CrossAttention, PlainAttention, RoPE
from src.model.dino import (
DINOv3ViTAttention,
DINOv3ViTDropPath,
DINOv3ViTEmbeddings,
DINOv3ViTGatedMLP,
DINOv3ViTLayer,
DINOv3ViTLayerScale,
DINOv3ViTMLP,
DINOv3ViTRopePositionEmbedding,
)
from src.model.dit import modulate
from src.model.resnet import ResBlock
def create_coordinate(h, w, start=0, end=1, device="cuda:1", dtype=torch.float32):
# Create a grid of coordinates
x = torch.linspace(start, end, h, device=device, dtype=dtype)
y = torch.linspace(start, end, w, device=device, dtype=dtype)
# Create a 2D map using meshgrid
xx, yy = torch.meshgrid(x, y, indexing="ij")
# Stack the x and y coordinates to create the final map
coord_map = torch.stack([xx, yy], dim=-1)[None, ...]
coords = rearrange(coord_map, "b h w c -> b (h w) c", h=h, w=w)
return coords
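# Editor's shape check (illustrative, not part of the commit): create_coordinate
# yields one normalized (x, y) pair per token, flattened row-major over the grid.
assert create_coordinate(4, 6, device="cpu").shape == (1, 4 * 6, 2)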
class TimestepEmbedder(nn.Module):
@@ -76,24 +98,21 @@ class LabelEmbedder(nn.Module):
return embeddings
class DinoConditionedLayer(DINOv3ViTLayer):
def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
class DinoEncoderLayer(DINOv3ViTLayer):
def __init__(self, config: DINOv3ViTConfig):
super().__init__(config)
self.is_encoder = is_encoder
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cond = DINOv3ViTAttention(config)
self.layer_scale_cond = DINOv3ViTLayerScale(config)
# no init zeros!
if is_encoder:
nn.init.constant_(self.layer_scale_cond.lambda1, 0)
self.norm1.requires_grad_(False)
self.norm2.requires_grad_(False)
self.attention.requires_grad_(False)
self.mlp.requires_grad_(False)
self.layer_scale1.requires_grad_(False)
self.layer_scale2.requires_grad_(False)
nn.init.constant_(self.layer_scale_cond.lambda1, 0)
self.norm1.requires_grad_(False)
self.norm2.requires_grad_(False)
self.attention.requires_grad_(False)
self.mlp.requires_grad_(False)
self.layer_scale1.requires_grad_(False)
self.layer_scale2.requires_grad_(False)
def forward(
self,
@@ -102,79 +121,257 @@ class DinoConditionedLayer(DINOv3ViTLayer):
conditioning_input: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
do_condition: bool = True,
**kwargs,
) -> torch.Tensor:
assert position_embeddings is not None
assert conditioning_input is not None or not do_condition
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
if do_condition:
residual = hidden_states
hidden_states = self.norm_cond(hidden_states)
hidden_states = self.cond(
hidden_states,
conditioning_input,
hidden_states = (
self.drop_path(
self.layer_scale1(
self.attention(
self.norm1(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)
)
)
hidden_states = self.layer_scale_cond(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
+ hidden_states
)
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = self.drop_path(hidden_states) + residual
hidden_states = (
self.drop_path(
self.layer_scale_cond(
self.cond(
hidden_states,
self.norm_cond(conditioning_input),
)
)
)
+ hidden_states
)
hidden_states = (
self.drop_path(self.layer_scale2(self.mlp(self.norm2(hidden_states))))
+ hidden_states
)
return hidden_states
# class DinoV3ViTDecoder(nn.Module):
# def __init__(self, config: DINOv3ViTConfig):
# super().__init__()
# self.config = config
# self.num_channels_out = config.num_channels
class DinoDecoderLayer(DINOv3ViTLayer):
def __init__(self, config: DINOv3ViTConfig, depth: int):
super().__init__(config)
# self.projection = nn.Linear(
# config.hidden_size,
# self.num_channels_out * config.patch_size * config.patch_size,
# bias=True,
# )
hidden_size = config.hidden_size // (16**depth)
hacky_config = copy.copy(config)
hacky_config.hidden_size = hidden_size
hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)
# def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
# batch_size = x.shape[0]
self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.attention = PlainAttention(
hidden_size, config.num_attention_heads // (4**depth)
) # head scaling law?
self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
self.drop_path = (
DINOv3ViTDropPath(hacky_config.drop_path_rate)
if hacky_config.drop_path_rate > 0.0
else nn.Identity()
)
# num_special_tokens = 1 + self.config.num_register_tokens
# patch_tokens = x[:, num_special_tokens:, :]
self.norm2 = nn.LayerNorm(hidden_size, eps=hacky_config.layer_norm_eps)
# projected_tokens = self.projection(patch_tokens)
if config.use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(hacky_config)
else:
self.mlp = DINOv3ViTMLP(hacky_config)
self.layer_scale2 = DINOv3ViTLayerScale(hacky_config)
# p = self.config.patch_size
# c = self.num_channels_out
# h_grid = image_size[0] // p
# w_grid = image_size[1] // p
# adaLN conditioning (DiT-style shift/scale/gate from the conditioning vector)
self.adaln = nn.Sequential(
nn.SiLU(),
nn.Linear(config.hidden_size, 6 * hidden_size, bias=True),
)
# assert patch_tokens.shape[1] == h_grid * w_grid, (
# "Number of patches does not match image size."
# )
def forward(
self,
hidden_states: torch.Tensor,
*,
conditioning_input: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
conditioning_input = conditioning_input.squeeze(1) # type: ignore
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(
conditioning_input
).chunk(6, dim=-1)
# x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
hidden_states = (
self.drop_path(
gate_msa.unsqueeze(1)
* self.layer_scale1(
self.attention(
modulate(self.norm1(hidden_states), shift_msa, scale_msa),
position_embeddings=position_embeddings,
)
)
)
+ hidden_states
)
# x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
hidden_states = (
self.drop_path(
gate_mlp.unsqueeze(1)
* self.layer_scale2(
self.mlp(modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
)
)
+ hidden_states
)
# reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
return hidden_states
# return reconstructed_image
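# Editor's note (illustrative, not part of the commit): the adaLN path above follows
# the DiT recipe: the conditioning vector is projected to (shift, scale, gate)
# triples and each gated branch is added back residually. A minimal stand-in for
# `modulate`, assuming per-sample (b, d) shift/scale as in DiT:
def _modulate_sketch(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # broadcast shift/scale over the sequence dimension of x: (b, s, d)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)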
# let's try a conv decoder
class ResidualUpscaler(nn.Module):
def __init__(
self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
): # max depth 2 (4**2 = 16 = patch size)
super().__init__()
def build_encoder(in_dim, num_layers=2):
return nn.Sequential(
nn.Conv2d(
in_dim,
bottleneck_dim,
kernel_size=1,
padding=0,
padding_mode="reflect",
bias=False,
),
*[
ResBlock(
bottleneck_dim,
bottleneck_dim,
kernel_size=1,
num_groups=8,
pad_mode="reflect",
norm_fn=nn.GroupNorm,
activation_fn=nn.SiLU,
use_conv_shortcut=False,
)
for _ in range(num_layers)
],
)
self.config = config
self.depth = depth
self.global_encode = nn.Linear(
config.hidden_size * len(depth),
bottleneck_dim * 2,
)
self.local_encode = nn.ModuleList(
[
nn.Linear(config.hidden_size, bottleneck_dim * 2)
if d != 0
else nn.Identity()
for d in depth
]
)
self.q_norm = nn.ModuleList(
[nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
)
self.k_norm = nn.ModuleList(
[nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
)
self.v_downsample = nn.ModuleList(
[
nn.Linear(config.hidden_size, config.hidden_size // (16**d))
if d != 0
else nn.Identity()
for d in depth
]
)
self.cross_attn = nn.ModuleList(
[
CrossAttention(
bottleneck_dim,
bottleneck_dim,
config.hidden_size // (16**d),
)
if d != 0
else nn.Identity()
for d in depth
]
)
self.image_encoder = build_encoder(3)
self.q_encoder = build_encoder(bottleneck_dim)
self.k_encoder = build_encoder(bottleneck_dim)
self.rope = RoPE(bottleneck_dim)
def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
# residuals[0] is the deepest tap, residuals[-1] the shallowest; pixel_values is (b, 3, h, w), each residual is (b, seq, d)
# objective: given e.g. a (1024, 1024, 512) residual, build queries/keys from the image so cross-attention can redistribute it onto a finer token grid
assert self.config.patch_size is not None
image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
global_shift, global_scale = self.global_encode(
einops.rearrange(
torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
)
).chunk(2, dim=-1)  # note: the "global" shift/scale is still computed per patch token
image_residual = self.image_encoder(pixel_values) # messy; todo: cleanup
coords = create_coordinate(pixel_values.shape[-2], pixel_values.shape[-1])
image_residual = rearrange(image_residual, "b c h w -> b (h w) c")
image_residual = self.rope(image_residual, coords)
image_residual = rearrange(
image_residual, "b (h w) c -> b c h w", h=pixel_values.shape[-2]
)
q = self.q_encoder(image_residual)
k = self.k_encoder(image_residual)
reformed_residual = []
for i, (depth, residual) in enumerate(zip(self.depth, residuals)):
if depth == 0:
reformed_residual.append(residual)
continue
local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
local_q = self.q_norm[i](
einops.rearrange(
F.adaptive_avg_pool2d(
q,
output_size=(
image_h // self.config.patch_size * (4**depth),
image_w // self.config.patch_size * (4**depth),
),
),
"b c h w -> b (h w) c",
)
)
local_k = (1 + local_scale) * (
(1 + global_scale)
* self.k_norm[i](
einops.rearrange(
F.adaptive_avg_pool2d(
k,
output_size=(
image_h // self.config.patch_size,
image_w // self.config.patch_size,
),
),
"b c h w -> b (h w) c",
)
)
+ global_shift
) + local_shift
local_v = self.v_downsample[i](residual)
reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v))
return reformed_residual
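# Editor's sketch (illustrative, not part of the commit): each depth's query/key grid
# is just the pixel-level feature map average-pooled to that depth's token resolution.
# With a hypothetical 64x64 input, patch_size=16 and depth=1 this gives a 16x16 grid,
# i.e. 4x denser than the 4x4 patch grid:
_feat = torch.randn(1, 128, 64, 64)  # (b, bottleneck_dim, H, W)
_q1 = F.adaptive_avg_pool2d(_feat, (64 // 16 * 4, 64 // 16 * 4))
assert _q1.shape == (1, 128, 16, 16)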
class DinoV3ViTDecoder(nn.Module):
@@ -185,9 +382,8 @@ class DinoV3ViTDecoder(nn.Module):
self.patch_size = config.patch_size
self.projection = nn.Linear(
config.hidden_size, config.num_channels * (self.patch_size**2), bias=True
config.hidden_size // 16 // 16, config.num_channels, bias=True
)
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
nn.init.zeros_(self.projection.weight)
nn.init.zeros_(
@@ -195,69 +391,11 @@ class DinoV3ViTDecoder(nn.Module):
) if self.projection.bias is not None else None
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
batch_size = x.shape[0]
x = x[:, 1 + self.config.num_register_tokens :, :]
p = self.config.patch_size
h_grid = image_size[0] // p
w_grid = image_size[1] // p
assert x.shape[1] == h_grid * w_grid
x = self.projection(x)
x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
x = self.pixel_shuffle(x)
return x
# how about a transposed conv decoder?
# class DinoV3ViTDecoder(nn.Module):
# def __init__(self, config: DINOv3ViTConfig):
# super().__init__()
# self.config = config
# self.num_channels_out = config.num_channels
# self.patch_size = config.patch_size
# intermediate_channels = config.hidden_size // 4
# self.decoder_block = nn.Sequential(
# nn.ConvTranspose2d(
# in_channels=config.hidden_size,
# out_channels=intermediate_channels,
# kernel_size=self.patch_size,
# stride=self.patch_size,
# bias=True,
# ),
# nn.LayerNorm(intermediate_channels),
# nn.GELU(),
# nn.Conv2d(
# in_channels=intermediate_channels,
# out_channels=config.num_channels,
# kernel_size=1,
# bias=True,
# ),
# )
# nn.init.zeros_(self.decoder_block[-1].weight)
# nn.init.zeros_(self.decoder_block[-1].bias)
# def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
# batch_size = x.shape[0]
# x = x[:, 1 + self.config.num_register_tokens :, :]
# p = self.config.patch_size
# h_grid = image_size[0] // p
# w_grid = image_size[1] // p
# assert x.shape[1] == h_grid * w_grid
# x = x.transpose(1, 2).reshape(
# batch_size, self.config.hidden_size, h_grid, w_grid
# )
# x = self.decoder_block(x)
# return x
return self.projection(
einops.rearrange(
x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
)
).permute(0, 3, 1, 2)
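# Editor's shape sketch (illustrative, not part of the commit): after two PixelShuffle(4)
# stages the token grid equals the pixel grid, so this decoder is a per-pixel linear head.
# With a hypothetical hidden_size=1024, each token carries 1024 // 16 // 16 == 4 channels:
_x = torch.randn(2, 32 * 32, 4)  # (b, H*W, d)
_rgb = nn.Linear(4, 3)(einops.rearrange(_x, "b (h w) d -> b h w d", h=32, w=32))
assert _rgb.permute(0, 3, 1, 2).shape == (2, 3, 32, 32)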
class UTransformer(nn.Module):
@@ -268,41 +406,71 @@ class UTransformer(nn.Module):
self.config = config
self.scale_factor = scale_factor
assert config.num_hidden_layers % scale_factor == 0
assert config.num_hidden_layers % scale_factor == 0 and (config.num_hidden_layers // scale_factor) % 3 == 0
self.embeddings = DINOv3ViTEmbeddings(config)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
def gen_rope(depth: int):
hidden_size = config.hidden_size // (16**depth)
hacky_config = copy.copy(config)
hacky_config.hidden_size = hidden_size
hacky_config.intermediate_size = hacky_config.intermediate_size // (
16**depth
)
hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
4**depth
)
return DINOv3ViTRopePositionEmbedding(hacky_config)
self.decode_ropes = nn.ModuleList([gen_rope(i + 1) for i in range(2)])
self.t_embedder = TimestepEmbedder(config.hidden_size)
# self.y_embedder = LabelEmbedder(
# num_classes, config.hidden_size, config.drop_path_rate
# ) # disable cond for now
self.encoder_layers = nn.ModuleList(
[
DinoConditionedLayer(config, True)
for _ in range(config.num_hidden_layers)
]
[DinoEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.residual_upscaler = ResidualUpscaler(
config,
[0, 0, 1, 1, 2, 2], # hardcoded, sorry
bottleneck_dim=128,
)
self.decoder_layers = nn.ModuleList(
[
DinoConditionedLayer(config, False)
for _ in range(config.num_hidden_layers // scale_factor)
nn.ModuleList(
[
DinoDecoderLayer(config, depth)
for _ in range((config.num_hidden_layers // scale_factor) // 3)
]
)
for depth in range(3)
]
)
self.residual_merger = nn.ModuleList(
[
nn.Sequential(
nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
nn.ModuleList(
[
nn.Sequential(
nn.SiLU(),
nn.Linear(
config.hidden_size // (16**depth),
2 * config.hidden_size // (16**depth),
),
)
for _ in range((config.num_hidden_layers // scale_factor) // 3)
]
)
for _ in range(config.num_hidden_layers // scale_factor)
for depth in range(3)
]
)
self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
self.rest_decoder = nn.ModuleList(
[DinoConditionedLayer(config, False) for _ in range(4)]
[DinoDecoderLayer(config, 2) for _ in range(4)]
)
self.decoder_norm = nn.LayerNorm(
(config.hidden_size // (16**2)), eps=config.layer_norm_eps
)
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = DinoV3ViTDecoder(config)
# freeze pretrained
@@ -333,7 +501,7 @@ class UTransformer(nn.Module):
residual = []
for i, layer_module in enumerate(self.encoder_layers):
if i % self.scale_factor == 0:
residual.append(x)
residual.append(x[:, 1 + self.config.num_register_tokens :])
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
@@ -343,69 +511,66 @@ class UTransformer(nn.Module):
)
x = self.encoder_norm(x)
x = x[:, 1 + self.config.num_register_tokens :]
reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])
reversed_residual = residual[::-1]
for i, layer_module in enumerate(self.decoder_layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
do_condition=False,
residual_idx = 0
for depth, layers in enumerate(self.decoder_layers):
for i, layer_module in enumerate(layers): # type: ignore
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=None,
position_embeddings=position_embeddings,
)
shift, scale = self.residual_merger[depth][i]( # type: ignore
reversed_residual[residual_idx]
).chunk(2, dim=-1)
x = x * (1 + scale) + shift
residual_idx += 1
x = (
rearrange(
self.upsample[depth](
rearrange(
x,
"b (h w) d -> b d h w",
h=pixel_values.shape[-2]
// (self.config.patch_size)
* (4**depth),
)
),
"b d h w -> b (h w) d",
)
if depth != 2
else x
)
shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
2, dim=-1
position_embeddings = (
self.decode_ropes[depth](
torch.zeros(
(
1,
1,
pixel_values.shape[-2] * (4 ** (depth + 1)),
pixel_values.shape[-1] * (4 ** (depth + 1)),
),
device=x.device,
).to(self.embeddings.patch_embeddings.weight.dtype)
)
if depth != 2
else position_embeddings
)
x = x * (1 + scale) + shift
for i, layer_module in enumerate(self.rest_decoder):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
do_condition=False,
)
x = self.decoder_norm(x)
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
def get_residual(
self,
pixel_values: torch.Tensor,
time: Optional[torch.Tensor],
do_condition: bool,
):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
position_embeddings = self.rope_embeddings(pixel_values)
x = self.embeddings(pixel_values, bool_masked_pos=None)
if do_condition:
t = self.t_embedder(time).unsqueeze(1)
# y = self.y_embedder(cond, self.training).unsqueeze(1)
# conditioning_input = t.to(x.dtype) + y.to(x.dtype)
conditioning_input = t.to(x.dtype)
else:
conditioning_input = None
residual = []
for i, layer_module in enumerate(self.encoder_layers):
if i % self.scale_factor == 0:
residual.append(x)
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=None,
position_embeddings=position_embeddings,
do_condition=do_condition,
)
return residual
x = self.decoder_norm(x)
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
@staticmethod
def from_pretrained_backbone(name: str):