v2; a lot of hacks
@@ -1,19 +1,41 @@
import copy
import math
from typing import Optional

import einops
import torch
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.attention import CrossAttention, PlainAttention, RoPE
from src.model.dino import (
    DINOv3ViTAttention,
    DINOv3ViTDropPath,
    DINOv3ViTEmbeddings,
    DINOv3ViTGatedMLP,
    DINOv3ViTLayer,
    DINOv3ViTLayerScale,
    DINOv3ViTMLP,
    DINOv3ViTRopePositionEmbedding,
)
from src.model.dit import modulate
from src.model.resnet import ResBlock


def create_coordinate(h, w, start=0, end=1, device="cuda:1", dtype=torch.float32):
    # Create a grid of coordinates
    x = torch.linspace(start, end, h, device=device, dtype=dtype)
    y = torch.linspace(start, end, w, device=device, dtype=dtype)
    # Create a 2D map using meshgrid
    xx, yy = torch.meshgrid(x, y, indexing="ij")
    # Stack the x and y coordinates to create the final map
    coord_map = torch.stack([xx, yy], dim=-1)[None, ...]
    coords = rearrange(coord_map, "b h w c -> b (h w) c", h=h, w=w)
    return coords

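# sketch (not part of the commit): create_coordinate flattens a normalized
# meshgrid into one (x, y) pair per token, shape (1, h * w, 2), e.g.
#
#     create_coordinate(2, 2, device="cpu")[0]
#     # tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
#
# note the hard-coded device="cuda:1" default; callers on other devices have
# to pass device explicitly.
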
class TimestepEmbedder(nn.Module):
@@ -76,24 +98,21 @@ class LabelEmbedder(nn.Module):
        return embeddings


-class DinoConditionedLayer(DINOv3ViTLayer):
-    def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
+class DinoEncoderLayer(DINOv3ViTLayer):
+    def __init__(self, config: DINOv3ViTConfig):
        super().__init__(config)
-        self.is_encoder = is_encoder

        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = DINOv3ViTAttention(config)
        self.layer_scale_cond = DINOv3ViTLayerScale(config)

-        # no init zeros!
-        if is_encoder:
-            nn.init.constant_(self.layer_scale_cond.lambda1, 0)
-            self.norm1.requires_grad_(False)
-            self.norm2.requires_grad_(False)
-            self.attention.requires_grad_(False)
-            self.mlp.requires_grad_(False)
-            self.layer_scale1.requires_grad_(False)
-            self.layer_scale2.requires_grad_(False)
+        nn.init.constant_(self.layer_scale_cond.lambda1, 0)
+        self.norm1.requires_grad_(False)
+        self.norm2.requires_grad_(False)
+        self.attention.requires_grad_(False)
+        self.mlp.requires_grad_(False)
+        self.layer_scale1.requires_grad_(False)
+        self.layer_scale2.requires_grad_(False)
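# sketch (not part of the commit): zeroing lambda1 makes the new cond branch
# an identity at init, so the wrapped block initially reproduces the frozen
# pretrained layer, assuming DINOv3ViTLayerScale is standard LayerScale
# (out = lambda1 * x):
#
#     scale = DINOv3ViTLayerScale(config)
#     nn.init.constant_(scale.lambda1, 0)
#     x = torch.randn(1, 4, config.hidden_size)
#     assert torch.equal(scale(x), torch.zeros_like(x))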

    def forward(
        self,
@@ -102,79 +121,257 @@ class DinoConditionedLayer(DINOv3ViTLayer):
        conditioning_input: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        do_condition: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        assert position_embeddings is not None
        assert conditioning_input is not None or not do_condition

-        residual = hidden_states
-        hidden_states = self.norm1(hidden_states)
-        hidden_states = self.attention(
-            hidden_states,
-            attention_mask=attention_mask,
-            position_embeddings=position_embeddings,
-        )
-        hidden_states = self.layer_scale1(hidden_states)
-        hidden_states = self.drop_path(hidden_states) + residual
-
-        if do_condition:
-            residual = hidden_states
-            hidden_states = self.norm_cond(hidden_states)
-            hidden_states = self.cond(
-                hidden_states,
-                conditioning_input,
+        hidden_states = (
+            self.drop_path(
+                self.layer_scale1(
+                    self.attention(
+                        self.norm1(hidden_states),
+                        attention_mask=attention_mask,
+                        position_embeddings=position_embeddings,
+                    )
+                )
            )
-            hidden_states = self.layer_scale_cond(hidden_states)
-            hidden_states = self.drop_path(hidden_states) + residual
+            + hidden_states
+        )

-        residual = hidden_states
-        hidden_states = self.norm2(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = self.layer_scale2(hidden_states)
-        hidden_states = self.drop_path(hidden_states) + residual
+        hidden_states = (
+            self.drop_path(
+                self.layer_scale_cond(
+                    self.cond(
+                        hidden_states,
+                        self.norm_cond(conditioning_input),
+                    )
+                )
+            )
+            + hidden_states
+        )

+        hidden_states = (
+            self.drop_path(self.layer_scale2(self.mlp(self.norm2(hidden_states))))
+            + hidden_states
+        )

        return hidden_states

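# sketch (not part of the commit): the attention and mlp branches keep the
# pre-norm residual pattern, x = x + DropPath(LayerScale(block(LN(x)))), just
# inlined without the explicit `residual` temporaries; the cond branch also
# moves norm_cond from the hidden states onto the conditioning input, which
# does change behavior.
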
-# class DinoV3ViTDecoder(nn.Module):
-#     def __init__(self, config: DINOv3ViTConfig):
-#         super().__init__()
-#         self.config = config
-#         self.num_channels_out = config.num_channels
+class DinoDecoderLayer(DINOv3ViTLayer):
+    def __init__(self, config: DINOv3ViTConfig, depth: int):
+        super().__init__(config)

-#         self.projection = nn.Linear(
-#             config.hidden_size,
-#             self.num_channels_out * config.patch_size * config.patch_size,
-#             bias=True,
-#         )
+        hidden_size = config.hidden_size // (16**depth)
+        hacky_config = copy.copy(config)
+        hacky_config.hidden_size = hidden_size
+        hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)

-#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-#         batch_size = x.shape[0]
+        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+        self.attention = PlainAttention(
+            hidden_size, config.num_attention_heads // (4**depth)
+        )  # head scaling law?
+        self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
+        self.drop_path = (
+            DINOv3ViTDropPath(hacky_config.drop_path_rate)
+            if hacky_config.drop_path_rate > 0.0
+            else nn.Identity()
+        )

-#         num_special_tokens = 1 + self.config.num_register_tokens
-#         patch_tokens = x[:, num_special_tokens:, :]
+        self.norm2 = nn.LayerNorm(hidden_size, eps=hacky_config.layer_norm_eps)

-#         projected_tokens = self.projection(patch_tokens)
+        if config.use_gated_mlp:
+            self.mlp = DINOv3ViTGatedMLP(hacky_config)
+        else:
+            self.mlp = DINOv3ViTMLP(hacky_config)
+        self.layer_scale2 = DINOv3ViTLayerScale(hacky_config)

-#         p = self.config.patch_size
-#         c = self.num_channels_out
-#         h_grid = image_size[0] // p
-#         w_grid = image_size[1] // p
+        # adaln
+        self.adaln = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(config.hidden_size, 6 * hidden_size, bias=True),
+        )

-#         assert patch_tokens.shape[1] == h_grid * w_grid, (
-#             "Number of patches does not match image size."
-#         )
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        conditioning_input: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        conditioning_input = conditioning_input.squeeze(1)  # type: ignore
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(
+            conditioning_input
+        ).chunk(6, dim=-1)

-#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
+        hidden_states = (
+            self.drop_path(
+                gate_msa.unsqueeze(1)
+                * self.layer_scale1(
+                    self.attention(
+                        modulate(self.norm1(hidden_states), shift_msa, scale_msa),
+                        position_embeddings=position_embeddings,
+                    )
+                )
+            )
+            + hidden_states
+        )

-#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
+        hidden_states = (
+            self.drop_path(
+                gate_mlp.unsqueeze(1)
+                * self.layer_scale2(
+                    self.mlp(modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
+                )
+            )
+            + hidden_states
+        )

-#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+        return hidden_states

-#         return reconstructed_image

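# sketch (not part of the commit): the adaln head regresses six per-dim
# vectors from the conditioning embedding; assuming src.model.dit.modulate is
# the usual DiT helper,
#
#     def modulate(x, shift, scale):
#         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#
# each sub-block is shift/scaled before and gated (gate_*.unsqueeze(1)) after,
# as in DiT-style adaLN.
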
+# lets try conv decoder
+class ResidualUpscaler(nn.Module):
+    def __init__(
+        self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
+    ):  # max depth 2 (4**2 = 16 = patch size)
+        super().__init__()
+
+        def build_encoder(in_dim, num_layers=2):
+            return nn.Sequential(
+                nn.Conv2d(
+                    in_dim,
+                    bottleneck_dim,
+                    kernel_size=1,
+                    padding=0,
+                    padding_mode="reflect",
+                    bias=False,
+                ),
+                *[
+                    ResBlock(
+                        bottleneck_dim,
+                        bottleneck_dim,
+                        kernel_size=1,
+                        num_groups=8,
+                        pad_mode="reflect",
+                        norm_fn=nn.GroupNorm,
+                        activation_fn=nn.SiLU,
+                        use_conv_shortcut=False,
+                    )
+                    for _ in range(num_layers)
+                ],
+            )
+
+        self.config = config
+        self.depth = depth
+        self.global_encode = nn.Linear(
+            config.hidden_size * len(depth),
+            bottleneck_dim * 2,
+        )
+        self.local_encode = nn.ModuleList(
+            [
+                nn.Linear(config.hidden_size, bottleneck_dim * 2)
+                if d != 0
+                else nn.Identity()
+                for d in depth
+            ]
+        )
+        self.q_norm = nn.ModuleList(
+            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
+        )
+        self.k_norm = nn.ModuleList(
+            [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth]
+        )
+        self.v_downsample = nn.ModuleList(
+            [
+                nn.Linear(config.hidden_size, config.hidden_size // (16**d))
+                if d != 0
+                else nn.Identity()
+                for d in depth
+            ]
+        )
+        self.cross_attn = nn.ModuleList(
+            [
+                CrossAttention(
+                    bottleneck_dim,
+                    bottleneck_dim,
+                    config.hidden_size // (16**d),
+                )
+                if d != 0
+                else nn.Identity()
+                for d in depth
+            ]
+        )
+
+        self.image_encoder = build_encoder(3)
+        self.q_encoder = build_encoder(bottleneck_dim)
+        self.k_encoder = build_encoder(bottleneck_dim)
+
+        self.rope = RoPE(bottleneck_dim)
+
+    def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
+        # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
+        # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
+        assert self.config.patch_size is not None
+        image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
+
+        global_shift, global_scale = self.global_encode(
+            einops.rearrange(
+                torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
+            )
+        ).chunk(2, dim=-1)  # patch-level global btw
+
+        image_residual = self.image_encoder(pixel_values)  # messy; todo: cleanup
+        coords = create_coordinate(pixel_values.shape[-2], pixel_values.shape[-1])
+        image_residual = rearrange(image_residual, "b c h w -> b (h w) c")
+        image_residual = self.rope(image_residual, coords)
+        image_residual = rearrange(
+            image_residual, "b (h w) c -> b c h w", h=pixel_values.shape[-2]
+        )
+        q = self.q_encoder(image_residual)
+        k = self.k_encoder(image_residual)
+
+        reformed_residual = []
+        for i, (depth, residual) in enumerate(zip(self.depth, residuals)):
+            if depth == 0:
+                reformed_residual.append(residual)
+                continue
+
+            local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
+            local_q = self.q_norm[i](
+                einops.rearrange(
+                    F.adaptive_avg_pool2d(
+                        q,
+                        output_size=(
+                            image_h // self.config.patch_size * (4**depth),
+                            image_w // self.config.patch_size * (4**depth),
+                        ),
+                    ),
+                    "b c h w -> b (h w) c",
+                )
+            )
+            local_k = (1 + local_scale) * (
+                (1 + global_scale)
+                * self.k_norm[i](
+                    einops.rearrange(
+                        F.adaptive_avg_pool2d(
+                            k,
+                            output_size=(
+                                image_h // self.config.patch_size,
+                                image_w // self.config.patch_size,
+                            ),
+                        ),
+                        "b c h w -> b (h w) c",
+                    )
+                )
+                + global_shift
+            ) + local_shift
+            local_v = self.v_downsample[i](residual)
+
+            reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v))
+
+        return reformed_residual

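# sketch (not part of the commit): for patch_size=16 and depth d > 0, queries
# come from the image pooled to a (H/16 * 4**d) x (W/16 * 4**d) grid, keys
# from the same features pooled to the H/16 x W/16 patch grid (FiLM-ed by the
# global and local shift/scale), and values are the residual tokens projected
# down to hidden_size // 16**d channels; assuming CrossAttention(q_dim, k_dim,
# v_dim) returns value-width outputs, each reformed residual has one token per
# upscaled grid position at the decoder width for that depth.
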
class DinoV3ViTDecoder(nn.Module):
@@ -185,9 +382,8 @@ class DinoV3ViTDecoder(nn.Module):
        self.patch_size = config.patch_size

        self.projection = nn.Linear(
-            config.hidden_size, config.num_channels * (self.patch_size**2), bias=True
+            config.hidden_size // 16 // 16, config.num_channels, bias=True
        )
-        self.pixel_shuffle = nn.PixelShuffle(self.patch_size)

        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(
@@ -195,69 +391,11 @@ class DinoV3ViTDecoder(nn.Module):
        ) if self.projection.bias is not None else None

    def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-        batch_size = x.shape[0]
-
-        x = x[:, 1 + self.config.num_register_tokens :, :]
-
-        p = self.config.patch_size
-        h_grid = image_size[0] // p
-        w_grid = image_size[1] // p
-
-        assert x.shape[1] == h_grid * w_grid
-        x = self.projection(x)
-        x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
-        x = self.pixel_shuffle(x)
-
-        return x
-
-
-# how about transposed conv decoder
-# class DinoV3ViTDecoder(nn.Module):
-#     def __init__(self, config: DINOv3ViTConfig):
-#         super().__init__()
-#         self.config = config
-#         self.num_channels_out = config.num_channels
-#         self.patch_size = config.patch_size
-
-#         intermediate_channels = config.hidden_size // 4
-
-#         self.decoder_block = nn.Sequential(
-#             nn.ConvTranspose2d(
-#                 in_channels=config.hidden_size,
-#                 out_channels=intermediate_channels,
-#                 kernel_size=self.patch_size,
-#                 stride=self.patch_size,
-#                 bias=True,
-#             ),
-#             nn.LayerNorm(intermediate_channels),
-#             nn.GELU(),
-#             nn.Conv2d(
-#                 in_channels=intermediate_channels,
-#                 out_channels=config.num_channels,
-#                 kernel_size=1,
-#                 bias=True,
-#             ),
-#         )
-
-#         nn.init.zeros_(self.decoder_block[-1].weight)
-#         nn.init.zeros_(self.decoder_block[-1].bias)
-
-#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-#         batch_size = x.shape[0]
-
-#         x = x[:, 1 + self.config.num_register_tokens :, :]
-
-#         p = self.config.patch_size
-#         h_grid = image_size[0] // p
-#         w_grid = image_size[1] // p
-#         assert x.shape[1] == h_grid * w_grid
-
-#         x = x.transpose(1, 2).reshape(
-#             batch_size, self.config.hidden_size, h_grid, w_grid
-#         )
-#         x = self.decoder_block(x)
-
-#         return x
+        return self.projection(
+            einops.rearrange(
+                x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
+            )
+        ).permute(0, 3, 1, 2)

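# sketch (not part of the commit): nn.PixelShuffle(r) maps (b, c * r * r, h, w)
# to (b, c, h * r, w * r), so the two PixelShuffle(4) stages in UTransformer
# below take tokens from the patch grid to pixel resolution (4 * 4 = 16 =
# patch_size) while dividing channels by 16 each time; that is why the new
# projection reads config.hidden_size // 16 // 16 channels per pixel token and
# the old pixel_shuffle member is gone.
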
class UTransformer(nn.Module):
@@ -268,41 +406,71 @@ class UTransformer(nn.Module):
        self.config = config
        self.scale_factor = scale_factor

        assert config.num_hidden_layers % scale_factor == 0
+        assert (config.num_hidden_layers // scale_factor) % 3 == 0

        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)

+        def gen_rope(depth: int):
+            hidden_size = config.hidden_size // (16**depth)
+            hacky_config = copy.copy(config)
+            hacky_config.hidden_size = hidden_size
+            hacky_config.intermediate_size = hacky_config.intermediate_size // (
+                16**depth
+            )
+            hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
+                4**depth
+            )
+            return DINOv3ViTRopePositionEmbedding(hacky_config)
+
+        self.decode_ropes = nn.ModuleList([gen_rope(i + 1) for i in range(2)])
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        # self.y_embedder = LabelEmbedder(
        #     num_classes, config.hidden_size, config.drop_path_rate
        # )  # disable cond for now

        self.encoder_layers = nn.ModuleList(
-            [
-                DinoConditionedLayer(config, True)
-                for _ in range(config.num_hidden_layers)
-            ]
+            [DinoEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

+        self.residual_upscaler = ResidualUpscaler(
+            config,
+            [0, 0, 1, 1, 2, 2],  # hardcoded, sorry
+            bottleneck_dim=128,
+        )
        self.decoder_layers = nn.ModuleList(
            [
-                DinoConditionedLayer(config, False)
-                for _ in range(config.num_hidden_layers // scale_factor)
+                nn.ModuleList(
+                    [
+                        DinoDecoderLayer(config, depth)
+                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                    ]
+                )
+                for depth in range(3)
            ]
        )
        self.residual_merger = nn.ModuleList(
            [
-                nn.Sequential(
-                    nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
+                nn.ModuleList(
+                    [
+                        nn.Sequential(
+                            nn.SiLU(),
+                            nn.Linear(
+                                config.hidden_size // (16**depth),
+                                2 * config.hidden_size // (16**depth),
+                            ),
+                        )
+                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                    ]
                )
-                for _ in range(config.num_hidden_layers // scale_factor)
+                for depth in range(3)
            ]
        )
+        self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
        self.rest_decoder = nn.ModuleList(
-            [DinoConditionedLayer(config, False) for _ in range(4)]
+            [DinoDecoderLayer(config, 2) for _ in range(4)]
        )
+        self.decoder_norm = nn.LayerNorm(
+            (config.hidden_size // (16**2)), eps=config.layer_norm_eps
+        )
-        self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = DinoV3ViTDecoder(config)

        # freeze pretrained
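# sketch (not part of the commit): per decode depth d the token width shrinks
# by 16**d and the head count by 4**d while PixelShuffle(4) grows the token
# grid 4x per side, e.g. with hidden_size=1024 and num_attention_heads=16:
#
#     depth 0: dim 1024, 16 heads;  depth 1: dim 64, 4 heads;  depth 2: dim 4, 1 head
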
@@ -333,7 +501,7 @@ class UTransformer(nn.Module):
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
-                residual.append(x)
+                residual.append(x[:, 1 + self.config.num_register_tokens :])
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
@@ -343,69 +511,66 @@ class UTransformer(nn.Module):
            )

        x = self.encoder_norm(x)
+        x = x[:, 1 + self.config.num_register_tokens :]
+        reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])

-        reversed_residual = residual[::-1]
-        for i, layer_module in enumerate(self.decoder_layers):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-            x = layer_module(
-                x,
-                conditioning_input=conditioning_input,
-                attention_mask=layer_head_mask,
-                position_embeddings=position_embeddings,
-                do_condition=False,
-            )
-            shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
-                2, dim=-1
-            )
-            x = x * (1 + scale) + shift
+        residual_idx = 0
+        for depth, layers in enumerate(self.decoder_layers):
+            for i, layer_module in enumerate(layers):  # type: ignore
+                x = layer_module(
+                    x,
+                    conditioning_input=conditioning_input,
+                    attention_mask=None,
+                    position_embeddings=position_embeddings,
+                )
+                shift, scale = self.residual_merger[depth][i](  # type: ignore
+                    reversed_residual[residual_idx]
+                ).chunk(2, dim=-1)
+                x = x * (1 + scale) + shift
+                residual_idx += 1
+            x = (
+                rearrange(
+                    self.upsample[depth](
+                        rearrange(
+                            x,
+                            "b (h w) d -> b d h w",
+                            h=pixel_values.shape[-2]
+                            // (self.config.patch_size)
+                            * (4**depth),
+                        )
+                    ),
+                    "b d h w -> b (h w) d",
+                )
+                if depth != 2
+                else x
+            )
+            position_embeddings = (
+                self.decode_ropes[depth](
+                    torch.zeros(
+                        (
+                            1,
+                            1,
+                            pixel_values.shape[-2] * (4 ** (depth + 1)),
+                            pixel_values.shape[-1] * (4 ** (depth + 1)),
+                        ),
+                        device=x.device,
+                    ).to(self.embeddings.patch_embeddings.weight.dtype)
+                )
+                if depth != 2
+                else position_embeddings
+            )

        for i, layer_module in enumerate(self.rest_decoder):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
            )

+        x = self.decoder_norm(x)
+
+        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

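# sketch (not part of the commit): the decode path is a coarse-to-fine
# U-shape: encoder residuals collected every scale_factor layers are remapped
# by ResidualUpscaler, injected FiLM-style (x = x * (1 + scale) + shift) after
# each decoder layer, and between depths the token grid is upsampled by
# PixelShuffle while the RoPE tables are rebuilt for the new resolution.
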
+    def get_residual(
+        self,
+        pixel_values: torch.Tensor,
+        time: Optional[torch.Tensor],
+        do_condition: bool,
+    ):
+        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+        position_embeddings = self.rope_embeddings(pixel_values)
+
+        x = self.embeddings(pixel_values, bool_masked_pos=None)
+
+        if do_condition:
+            t = self.t_embedder(time).unsqueeze(1)
+            # y = self.y_embedder(cond, self.training).unsqueeze(1)
+            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+            conditioning_input = t.to(x.dtype)
+        else:
+            conditioning_input = None
+
+        residual = []
+        for i, layer_module in enumerate(self.encoder_layers):
+            if i % self.scale_factor == 0:
+                residual.append(x)
+
+            x = layer_module(
+                x,
+                conditioning_input=conditioning_input,
+                attention_mask=None,
+                position_embeddings=position_embeddings,
+                do_condition=do_condition,
+            )
+
+        return residual
-        x = self.decoder_norm(x)
-
-        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    @staticmethod
    def from_pretrained_backbone(name: str):