diff --git a/main.py b/main.py index 5892150..f44cd7d 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import math import os +import lovely_tensors as lt import torch import torch.optim as optim from torch.cuda.amp import autocast @@ -14,11 +15,13 @@ from src.dataset.preprocess import denormalize from src.model.utransformer import UTransformer from src.rf import RF +lt.monkey_patch() + train_dataset, test_dataset = get_dataset() device = "cuda:1" -batch_size = 8 * 4 * 2 +batch_size = 16 accumulation_steps = 2 total_epoch = 500 diff --git a/pyproject.toml b/pyproject.toml index cdffa96..e33def7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,8 @@ dependencies = [ "tqdm>=4.67.1", "transformers>=4.56.2", "wandb[media]>=0.22.0", - "flash-attn" + "flash-attn", + "lovely-tensors>=0.1.19", ] [tool.uv.extra-build-dependencies] diff --git a/src/model/attention.py b/src/model/attention.py new file mode 100644 index 0000000..d436f4e --- /dev/null +++ b/src/model/attention.py @@ -0,0 +1,113 @@ +# GENERAL attention with flash attention flavor +from typing import Optional, Tuple + +import einops +import torch +import torch.nn as nn +from flash_attn import flash_attn_func + +from src.model.dino import apply_rotary_pos_emb, rotate_half + + +class PlainAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout_rate: float = 0.1, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + proj_bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // self.num_heads + + self.dropout = dropout_rate + + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_bias) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_bias) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=proj_bias) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor | None = None, + v: torch.Tensor | None = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + k = k if k is not None else q + v = v if v is not None else q + assert k is not None and v is not None + + q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v) + q = einops.rearrange(q, "b s (h d) -> b h s d", h=self.num_heads) + k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_heads) + v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_heads) + + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) # type: ignore + + # flash attn handles sftmx scaling too + o = flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), # type: ignore + v.transpose(1, 2), # type: ignore + dropout_p=self.dropout if self.training else 0.0, + causal=False, + ) + + return self.o_proj(einops.rearrange(o, "b s h d -> b s (h d)").contiguous()) # type: ignore + + +class CrossAttention(nn.Module): + def __init__(self, query_dim: int, key_dim: int, value_dim: int): + super().__init__() + self.attention = nn.MultiheadAttention( + embed_dim=query_dim, + vdim=value_dim, + num_heads=1, + dropout=0.0, + batch_first=True, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ): + _, attn_score = self.attention(q, k, v, average_attn_weights=True) + return torch.einsum("b i j, b j d -> b i d", attn_score, v) + + +class RoPE(nn.Module): + def __init__( + 
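Review note: the new PlainAttention rearranges q/k/v to (b, h, s, d) for the rotary embedding and then transposes to (b, s, h, d) before calling flash_attn_func, which expects that layout and applies the 1/sqrt(head_dim) softmax scaling itself. A minimal sketch of that layout, checked against PyTorch's scaled_dot_product_attention; flash-attn being optional and the fallback path are assumptions of this sketch, not part of the patch.

```python
# Sketch: the (batch, seq, heads, head_dim) layout flash_attn_func expects, checked
# against PyTorch SDPA, which consumes (batch, heads, seq, head_dim) directly.
# Assumes flash-attn may be missing; falls back to SDPA in that case.
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
    HAS_FLASH = torch.cuda.is_available()
except ImportError:
    HAS_FLASH = False

b, h, s, d = 2, 8, 64, 64
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Reference path: SDPA takes (b, h, s, d) and applies 1/sqrt(d) scaling internally.
ref = F.scaled_dot_product_attention(q, k, v)

if HAS_FLASH:
    # flash_attn_func wants (b, s, h, d) in fp16/bf16 on GPU and also scales by
    # 1/sqrt(d), which is why PlainAttention transposes dims 1 and 2 around the call.
    qf, kf, vf = (t.transpose(1, 2).cuda().half() for t in (q, k, v))
    out = flash_attn_func(qf, kf, vf, dropout_p=0.0, causal=False)
    print((out.transpose(1, 2).float().cpu() - ref).abs().max())  # small fp16 gap
else:
    print("flash-attn unavailable; SDPA reference only:", ref.shape)
```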
self, + dim: int, + theta: int = 100, + ): + super().__init__() + self.dim = dim + self.theta = theta + self.freqs = nn.Parameter(torch.empty(2, self.dim)) + + def _device_weight_init(self): + # Create freqs in 1d + freqs_1d = self.theta ** torch.linspace(0, -1, self.dim // 4) + # duplicate freqs for rotation pairs of channels + freqs_1d = torch.cat([freqs_1d, freqs_1d]) + # First half of channels do x, second half do y + freqs_2d = torch.zeros(2, self.dim) + freqs_2d[0, : self.dim // 2] = freqs_1d + freqs_2d[1, -self.dim // 2 :] = freqs_1d + # it's an angular freq here + self.freqs.data.copy_(freqs_2d * 2 * torch.pi) + + def forward(self, x: torch.Tensor, coords: torch.Tensor) -> torch.Tensor: + angle = coords @ self.freqs + return x * angle.cos() + rotate_half(x) * angle.sin() diff --git a/src/model/dino.py b/src/model/dino.py index eb915ad..3aa3e36 100644 --- a/src/model/dino.py +++ b/src/model/dino.py @@ -235,7 +235,7 @@ class DINOv3ViTAttention(nn.Module): key_states, value_states, dropout_p=dropout_p, - softmax_scale=None, # Will use default 1/sqrt(headdim) + softmax_scale=None, causal=False, ) diff --git a/src/model/resnet.py b/src/model/resnet.py new file mode 100644 index 0000000..21214fc --- /dev/null +++ b/src/model/resnet.py @@ -0,0 +1,63 @@ +import torch.nn as nn + + +class ResBlock(nn.Module): + """Basic Residual Block, adapted from magvit1""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + num_groups=8, + pad_mode="zeros", + norm_fn=None, + activation_fn=nn.SiLU, + use_conv_shortcut=False, + ): + super(ResBlock, self).__init__() + self.use_conv_shortcut = use_conv_shortcut + self.norm1 = ( + norm_fn(num_groups, in_channels) if norm_fn is not None else nn.Identity() + ) + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode=pad_mode, + bias=False, + ) + self.norm2 = ( + norm_fn(num_groups, out_channels) if norm_fn is not None else nn.Identity() + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode=pad_mode, + bias=False, + ) + self.activation_fn = activation_fn() + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + padding=0, + padding_mode=pad_mode, + bias=False, + ) + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.activation_fn(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.activation_fn(x) + x = self.conv2(x) + if self.use_conv_shortcut or residual.shape != x.shape: + residual = self.shortcut(residual) + return x + residual diff --git a/src/model/utransformer.py b/src/model/utransformer.py index ee156d5..fcfc4a5 100644 --- a/src/model/utransformer.py +++ b/src/model/utransformer.py @@ -1,19 +1,41 @@ +import copy import math from typing import Optional +import einops import torch +import torch.nn.functional as F +from einops import rearrange from huggingface_hub import hf_hub_download from safetensors import safe_open from torch import nn from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig +from src.model.attention import CrossAttention, PlainAttention, RoPE from src.model.dino import ( DINOv3ViTAttention, + DINOv3ViTDropPath, DINOv3ViTEmbeddings, + DINOv3ViTGatedMLP, DINOv3ViTLayer, DINOv3ViTLayerScale, + DINOv3ViTMLP, DINOv3ViTRopePositionEmbedding, ) +from src.model.dit import modulate +from src.model.resnet import ResBlock + + +def create_coordinate(h, w, start=0, end=1, 
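Review note: in the RoPE module above, `freqs` is a (2, dim) table where the first half of the channels rotates with the x coordinate and the second half with the y coordinate, so `coords @ freqs` gives a per-token, per-channel angle and the rotation is applied as `x * cos + rotate_half(x) * sin`. A minimal sketch of that construction; the local `rotate_half` assumes the usual split-halves convention of the HF rotary helpers (the real module imports it from src.model.dino).

```python
# Sketch of the 2D RoPE frequency layout: x coordinates drive the first half of the
# channels, y coordinates the second half, and rotation pairs share a frequency.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Assumed convention: negate the second half and swap halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


dim, theta = 16, 100
freqs_1d = theta ** torch.linspace(0, -1, dim // 4)
freqs_1d = torch.cat([freqs_1d, freqs_1d])        # duplicate for rotation pairs
freqs = torch.zeros(2, dim)
freqs[0, : dim // 2] = freqs_1d                   # x drives the first half
freqs[1, -dim // 2 :] = freqs_1d                  # y drives the second half
freqs = freqs * 2 * torch.pi                      # angular frequency

coords = torch.rand(1, 10, 2)                     # (batch, tokens, xy) in [0, 1]
angle = coords @ freqs                            # (batch, tokens, dim)

x = torch.randn(1, 10, dim)
x_rot = x * angle.cos() + rotate_half(x) * angle.sin()
print(x_rot.shape)                                # (1, 10, 16)

# At the origin the angle is zero, so the embedding reduces to the identity there.
zero_angle = torch.zeros(1, 1, 2) @ freqs
x0 = torch.randn(1, 1, dim)
print(torch.allclose(x0 * zero_angle.cos() + rotate_half(x0) * zero_angle.sin(), x0))
```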
device="cuda:1", dtype=torch.float32): + # Create a grid of coordinates + x = torch.linspace(start, end, h, device=device, dtype=dtype) + y = torch.linspace(start, end, w, device=device, dtype=dtype) + # Create a 2D map using meshgrid + xx, yy = torch.meshgrid(x, y, indexing="ij") + # Stack the x and y coordinates to create the final map + coord_map = torch.stack([xx, yy], dim=-1)[None, ...] + coords = rearrange(coord_map, "b h w c -> b (h w) c", h=h, w=w) + return coords class TimestepEmbedder(nn.Module): @@ -76,24 +98,21 @@ class LabelEmbedder(nn.Module): return embeddings -class DinoConditionedLayer(DINOv3ViTLayer): - def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False): +class DinoEncoderLayer(DINOv3ViTLayer): + def __init__(self, config: DINOv3ViTConfig): super().__init__(config) - self.is_encoder = is_encoder self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.cond = DINOv3ViTAttention(config) self.layer_scale_cond = DINOv3ViTLayerScale(config) - # no init zeros! - if is_encoder: - nn.init.constant_(self.layer_scale_cond.lambda1, 0) - self.norm1.requires_grad_(False) - self.norm2.requires_grad_(False) - self.attention.requires_grad_(False) - self.mlp.requires_grad_(False) - self.layer_scale1.requires_grad_(False) - self.layer_scale2.requires_grad_(False) + nn.init.constant_(self.layer_scale_cond.lambda1, 0) + self.norm1.requires_grad_(False) + self.norm2.requires_grad_(False) + self.attention.requires_grad_(False) + self.mlp.requires_grad_(False) + self.layer_scale1.requires_grad_(False) + self.layer_scale2.requires_grad_(False) def forward( self, @@ -102,79 +121,257 @@ class DinoConditionedLayer(DINOv3ViTLayer): conditioning_input: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - do_condition: bool = True, **kwargs, ) -> torch.Tensor: assert position_embeddings is not None - assert conditioning_input is not None or not do_condition - residual = hidden_states - hidden_states = self.norm1(hidden_states) - hidden_states = self.attention( - hidden_states, - attention_mask=attention_mask, - position_embeddings=position_embeddings, - ) - hidden_states = self.layer_scale1(hidden_states) - hidden_states = self.drop_path(hidden_states) + residual - - if do_condition: - residual = hidden_states - hidden_states = self.norm_cond(hidden_states) - hidden_states = self.cond( - hidden_states, - conditioning_input, + hidden_states = ( + self.drop_path( + self.layer_scale1( + self.attention( + self.norm1(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + ) ) - hidden_states = self.layer_scale_cond(hidden_states) - hidden_states = self.drop_path(hidden_states) + residual + + hidden_states + ) - residual = hidden_states - hidden_states = self.norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.layer_scale2(hidden_states) - hidden_states = self.drop_path(hidden_states) + residual + hidden_states = ( + self.drop_path( + self.layer_scale_cond( + self.cond( + hidden_states, + self.norm_cond(conditioning_input), + ) + ) + ) + + hidden_states + ) + + hidden_states = ( + self.drop_path(self.layer_scale2(self.mlp(self.norm2(hidden_states)))) + + hidden_states + ) return hidden_states -# class DinoV3ViTDecoder(nn.Module): -# def __init__(self, config: DINOv3ViTConfig): -# super().__init__() -# self.config = config -# self.num_channels_out = config.num_channels +class 
DinoDecoderLayer(DINOv3ViTLayer): + def __init__(self, config: DINOv3ViTConfig, depth: int): + super().__init__(config) -# self.projection = nn.Linear( -# config.hidden_size, -# self.num_channels_out * config.patch_size * config.patch_size, -# bias=True, -# ) + hidden_size = config.hidden_size // (16**depth) + hacky_config = copy.copy(config) + hacky_config.hidden_size = hidden_size + hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth) -# def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor: -# batch_size = x.shape[0] + self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.attention = PlainAttention( + hidden_size, config.num_attention_heads // (4**depth) + ) # head scaling law? + self.layer_scale1 = DINOv3ViTLayerScale(hacky_config) + self.drop_path = ( + DINOv3ViTDropPath(hacky_config.drop_path_rate) + if hacky_config.drop_path_rate > 0.0 + else nn.Identity() + ) -# num_special_tokens = 1 + self.config.num_register_tokens -# patch_tokens = x[:, num_special_tokens:, :] + self.norm2 = nn.LayerNorm(hidden_size, eps=hacky_config.layer_norm_eps) -# projected_tokens = self.projection(patch_tokens) + if config.use_gated_mlp: + self.mlp = DINOv3ViTGatedMLP(hacky_config) + else: + self.mlp = DINOv3ViTMLP(hacky_config) + self.layer_scale2 = DINOv3ViTLayerScale(hacky_config) -# p = self.config.patch_size -# c = self.num_channels_out -# h_grid = image_size[0] // p -# w_grid = image_size[1] // p + # adaln + self.adaln = nn.Sequential( + nn.SiLU(), + nn.Linear(config.hidden_size, 6 * hidden_size, bias=True), + ) -# assert patch_tokens.shape[1] == h_grid * w_grid, ( -# "Number of patches does not match image size." -# ) + def forward( + self, + hidden_states: torch.Tensor, + *, + conditioning_input: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + conditioning_input = conditioning_input.squeeze(1) # type: ignore + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln( + conditioning_input + ).chunk(6, dim=-1) -# x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c) + hidden_states = ( + self.drop_path( + gate_msa.unsqueeze(1) + * self.layer_scale1( + self.attention( + modulate(self.norm1(hidden_states), shift_msa, scale_msa), + position_embeddings=position_embeddings, + ) + ) + ) + + hidden_states + ) -# x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped) + hidden_states = ( + self.drop_path( + gate_mlp.unsqueeze(1) + * self.layer_scale2( + self.mlp(modulate(self.norm2(hidden_states), shift_mlp, scale_mlp)) + ) + ) + + hidden_states + ) -# reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p) + return hidden_states -# return reconstructed_image -# lets try conv decoder +class ResidualUpscaler(nn.Module): + def __init__( + self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128 + ): # max depth 2 (4**2 = 16 = patch size) + super().__init__() + + def build_encoder(in_dim, num_layers=2): + return nn.Sequential( + nn.Conv2d( + in_dim, + bottleneck_dim, + kernel_size=1, + padding=0, + padding_mode="reflect", + bias=False, + ), + *[ + ResBlock( + bottleneck_dim, + bottleneck_dim, + kernel_size=1, + num_groups=8, + pad_mode="reflect", + norm_fn=nn.GroupNorm, + activation_fn=nn.SiLU, + use_conv_shortcut=False, + ) + for _ in range(num_layers) + ], + ) + + self.config = config + self.depth = depth + 
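Review note: DinoDecoderLayer conditions through DiT-style adaLN: one SiLU+Linear head maps the conditioning vector to six chunks (shift/scale/gate for the attention sub-block and for the MLP sub-block), and `modulate` applies `x * (1 + scale) + shift` before each sub-block while the gate weights the residual. `modulate` is imported from src.model.dit in the patch; the definition below assumes the standard DiT formulation, and the sizes are stand-in values.

```python
# Sketch of the adaLN conditioning path used by the decoder layers.
import torch
import torch.nn as nn


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: (b, seq, dim); shift/scale: (b, dim). Assumed DiT definition.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


cond_dim, dim, b, s = 768, 48, 2, 16            # e.g. hidden_size and hidden_size // 16
adaln = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 6 * dim, bias=True))

cond = torch.randn(b, cond_dim)                 # timestep embedding, squeezed to (b, cond_dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaln(cond).chunk(6, dim=-1)

x = torch.randn(b, s, dim)
attn_in = modulate(x, shift_msa, scale_msa)     # what the attention sub-block sees
x = x + gate_msa.unsqueeze(1) * attn_in         # gated residual, as in the decoder layer
print(x.shape)                                  # (2, 16, 48)
```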
self.global_encode = nn.Linear( + config.hidden_size * len(depth), + bottleneck_dim * 2, + ) + self.local_encode = nn.ModuleList( + [ + nn.Linear(config.hidden_size, bottleneck_dim * 2) + if d != 0 + else nn.Identity() + for d in depth + ] + ) + self.q_norm = nn.ModuleList( + [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth] + ) + self.k_norm = nn.ModuleList( + [nn.LayerNorm(bottleneck_dim) if d != 0 else nn.Identity() for d in depth] + ) + self.v_downsample = nn.ModuleList( + [ + nn.Linear(config.hidden_size, config.hidden_size // (16**d)) + if d != 0 + else nn.Identity() + for d in depth + ] + ) + self.cross_attn = nn.ModuleList( + [ + CrossAttention( + bottleneck_dim, + bottleneck_dim, + config.hidden_size // (16**d), + ) + if d != 0 + else nn.Identity() + for d in depth + ] + ) + + self.image_encoder = build_encoder(3) + self.q_encoder = build_encoder(bottleneck_dim) + self.k_encoder = build_encoder(bottleneck_dim) + + self.rope = RoPE(bottleneck_dim) + + def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]): + # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)] + # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well + assert self.config.patch_size is not None + image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1] + + global_shift, global_scale = self.global_encode( + einops.rearrange( + torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)" + ) + ).chunk(2, dim=-1) # patch-level global btw + + image_residual = self.image_encoder(pixel_values) # messy; todo: cleanup + coords = create_coordinate(pixel_values.shape[-2], pixel_values.shape[-1]) + image_residual = rearrange(image_residual, "b c h w -> b (h w) c") + image_residual = self.rope(image_residual, coords) + image_residual = rearrange( + image_residual, "b (h w) c -> b c h w", h=pixel_values.shape[-2] + ) + q = self.q_encoder(image_residual) + k = self.k_encoder(image_residual) + + reformed_residual = [] + for i, (depth, residual) in enumerate(zip(self.depth, residuals)): + if depth == 0: + reformed_residual.append(residual) + continue + + local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1) + local_q = self.q_norm[i]( + einops.rearrange( + F.adaptive_avg_pool2d( + q, + output_size=( + image_h // self.config.patch_size * (4**depth), + image_w // self.config.patch_size * (4**depth), + ), + ), + "b c h w -> b (h w) c", + ) + ) + local_k = (1 + local_scale) * ( + (1 + global_scale) + * self.k_norm[i]( + einops.rearrange( + F.adaptive_avg_pool2d( + k, + output_size=( + image_h // self.config.patch_size, + image_w // self.config.patch_size, + ), + ), + "b c h w -> b (h w) c", + ) + ) + + global_shift + ) + local_shift + local_v = self.v_downsample[i](residual) + + reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v)) + + return reformed_residual class DinoV3ViTDecoder(nn.Module): @@ -185,9 +382,8 @@ class DinoV3ViTDecoder(nn.Module): self.patch_size = config.patch_size self.projection = nn.Linear( - config.hidden_size, config.num_channels * (self.patch_size**2), bias=True + config.hidden_size // 16 // 16, config.num_channels, bias=True ) - self.pixel_shuffle = nn.PixelShuffle(self.patch_size) nn.init.zeros_(self.projection.weight) nn.init.zeros_( @@ -195,69 +391,11 @@ class DinoV3ViTDecoder(nn.Module): ) if self.projection.bias is not None else None def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> 
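Review note: the CrossAttention block used here keeps only the attention weights from nn.MultiheadAttention and multiplies them with the raw values, so the output dimensionality follows the value dim (hidden_size // 16**depth) rather than the query dim. That lets pixel-space queries on a finer grid gather the coarse encoder residual into an upscaled residual. A shape sketch with stand-in sizes for the real config values.

```python
# Shape sketch of ResidualUpscaler's cross-attention: weights from MultiheadAttention,
# output formed directly as attn_score @ v, so out dim == value dim.
import torch
import torch.nn as nn

query_dim, value_dim = 128, 48                  # bottleneck_dim, hidden_size // 16**depth
b, q_len, kv_len = 2, 16 * 16, 16               # finer query grid vs. coarse patch grid

attention = nn.MultiheadAttention(
    embed_dim=query_dim, vdim=value_dim, num_heads=1, dropout=0.0, batch_first=True
)

q = torch.randn(b, q_len, query_dim)            # pixel-space queries, pooled to 4**depth x patch grid
k = torch.randn(b, kv_len, query_dim)           # pixel-space keys, pooled to the patch grid
v = torch.randn(b, kv_len, value_dim)           # downsampled encoder residual tokens

_, attn_score = attention(q, k, v, average_attn_weights=True)
out = torch.einsum("b i j, b j d -> b i d", attn_score, v)
print(attn_score.shape, out.shape)              # (2, 256, 16) (2, 256, 48)
```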
torch.Tensor: - batch_size = x.shape[0] - - x = x[:, 1 + self.config.num_register_tokens :, :] - - p = self.config.patch_size - h_grid = image_size[0] // p - w_grid = image_size[1] // p - - assert x.shape[1] == h_grid * w_grid - x = self.projection(x) - x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2) - x = self.pixel_shuffle(x) - - return x - - -# how about transposed conv decoderclass DinoV3ViTDecoder(nn.Module): -# class DinoV3ViTDecoder(nn.Module): -# def __init__(self, config: DINOv3ViTConfig): -# super().__init__() -# self.config = config -# self.num_channels_out = config.num_channels -# self.patch_size = config.patch_size - -# intermediate_channels = config.hidden_size // 4 - -# self.decoder_block = nn.Sequential( -# nn.ConvTranspose2d( -# in_channels=config.hidden_size, -# out_channels=intermediate_channels, -# kernel_size=self.patch_size, -# stride=self.patch_size, -# bias=True, -# ), -# nn.LayerNorm(intermediate_channels), -# nn.GELU(), -# nn.Conv2d( -# in_channels=intermediate_channels, -# out_channels=config.num_channels, -# kernel_size=1, -# bias=True, -# ), -# ) - -# nn.init.zeros_(self.decoder_block[-1].weight) -# nn.init.zeros_(self.decoder_block[-1].bias) - -# def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor: -# batch_size = x.shape[0] - -# x = x[:, 1 + self.config.num_register_tokens :, :] - -# p = self.config.patch_size -# h_grid = image_size[0] // p -# w_grid = image_size[1] // p -# assert x.shape[1] == h_grid * w_grid - -# x = x.transpose(1, 2).reshape( -# batch_size, self.config.hidden_size, h_grid, w_grid -# ) -# x = self.decoder_block(x) - -# return x + return self.projection( + einops.rearrange( + x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1] + ) + ).permute(0, 3, 1, 2) class UTransformer(nn.Module): @@ -268,41 +406,71 @@ class UTransformer(nn.Module): self.config = config self.scale_factor = scale_factor - assert config.num_hidden_layers % scale_factor == 0 + assert config.num_hidden_layers % scale_factor % 3 == 0 self.embeddings = DINOv3ViTEmbeddings(config) self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config) + + def gen_rope(depth: int): + hidden_size = config.hidden_size // (16**depth) + hacky_config = copy.copy(config) + hacky_config.hidden_size = hidden_size + hacky_config.intermediate_size = hacky_config.intermediate_size // ( + 16**depth + ) + hacky_config.num_attention_heads = hacky_config.num_attention_heads // ( + 4**depth + ) + return DINOv3ViTRopePositionEmbedding(hacky_config) + + self.decode_ropes = nn.ModuleList([gen_rope(i + 1) for i in range(2)]) self.t_embedder = TimestepEmbedder(config.hidden_size) - # self.y_embedder = LabelEmbedder( - # num_classes, config.hidden_size, config.drop_path_rate - # ) # disable cond for now self.encoder_layers = nn.ModuleList( - [ - DinoConditionedLayer(config, True) - for _ in range(config.num_hidden_layers) - ] + [DinoEncoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.residual_upscaler = ResidualUpscaler( + config, + [0, 0, 1, 1, 2, 2], # hardcoded, sorry + bottleneck_dim=128, + ) self.decoder_layers = nn.ModuleList( [ - DinoConditionedLayer(config, False) - for _ in range(config.num_hidden_layers // scale_factor) + nn.ModuleList( + [ + DinoDecoderLayer(config, depth) + for _ in range((config.num_hidden_layers // scale_factor) // 3) + ] + ) + for depth in range(3) ] ) self.residual_merger = nn.ModuleList( [ - nn.Sequential( - nn.SiLU(), 
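Review note: the rewritten DinoV3ViTDecoder no longer pixel-shuffles patch tokens, because after the two PixelShuffle(4) stages in the decoder the sequence already has one token per pixel with hidden_size // 256 channels; decoding is just a zero-initialized per-pixel linear plus a channel-first permute. A shape sketch with stand-in sizes.

```python
# Sketch of the per-pixel linear decoder: (b, h*w, hidden//256) tokens -> (b, 3, h, w).
import einops
import torch
import torch.nn as nn

hidden_size, num_channels = 768, 3
b, h, w = 2, 32, 32                              # toy full-resolution image size

projection = nn.Linear(hidden_size // 16 // 16, num_channels, bias=True)
nn.init.zeros_(projection.weight)
nn.init.zeros_(projection.bias)                  # the model starts by predicting zeros

tokens = torch.randn(b, h * w, hidden_size // 256)   # one token per pixel after upsampling
image = projection(
    einops.rearrange(tokens, "b (h w) d -> b h w d", h=h, w=w)
).permute(0, 3, 1, 2)
print(image.shape)                               # (2, 3, 32, 32)
```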
nn.Linear(config.hidden_size, 2 * config.hidden_size) + nn.ModuleList( + [ + nn.Sequential( + nn.SiLU(), + nn.Linear( + config.hidden_size // (16**depth), + 2 * config.hidden_size // (16**depth), + ), + ) + for _ in range((config.num_hidden_layers // scale_factor) // 3) + ] ) - for _ in range(config.num_hidden_layers // scale_factor) + for depth in range(3) ] ) + self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)]) self.rest_decoder = nn.ModuleList( - [DinoConditionedLayer(config, False) for _ in range(4)] + [DinoDecoderLayer(config, 2) for _ in range(4)] + ) + self.decoder_norm = nn.LayerNorm( + (config.hidden_size // (16**2)), eps=config.layer_norm_eps ) - self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.decoder = DinoV3ViTDecoder(config) # freeze pretrained @@ -333,7 +501,7 @@ class UTransformer(nn.Module): residual = [] for i, layer_module in enumerate(self.encoder_layers): if i % self.scale_factor == 0: - residual.append(x) + residual.append(x[:, 1 + self.config.num_register_tokens :]) layer_head_mask = head_mask[i] if head_mask is not None else None x = layer_module( x, @@ -343,69 +511,66 @@ class UTransformer(nn.Module): ) x = self.encoder_norm(x) + x = x[:, 1 + self.config.num_register_tokens :] + reversed_residual = self.residual_upscaler(pixel_values, residual[::-1]) - reversed_residual = residual[::-1] - for i, layer_module in enumerate(self.decoder_layers): - layer_head_mask = head_mask[i] if head_mask is not None else None - x = layer_module( - x, - conditioning_input=conditioning_input, - attention_mask=layer_head_mask, - position_embeddings=position_embeddings, - do_condition=False, + residual_idx = 0 + for depth, layers in enumerate(self.decoder_layers): + for i, layer_module in enumerate(layers): # type: ignore + x = layer_module( + x, + conditioning_input=conditioning_input, + attention_mask=None, + position_embeddings=position_embeddings, + ) + shift, scale = self.residual_merger[depth][i]( # type: ignore + reversed_residual[residual_idx] + ).chunk(2, dim=-1) + x = x * (1 + scale) + shift + residual_idx += 1 + x = ( + rearrange( + self.upsample[depth]( + rearrange( + x, + "b (h w) d -> b d h w", + h=pixel_values.shape[-2] + // (self.config.patch_size) + * (4**depth), + ) + ), + "b d h w -> b (h w) d", + ) + if depth != 2 + else x ) - shift, scale = self.residual_merger[i](reversed_residual[i]).chunk( - 2, dim=-1 + position_embeddings = ( + self.decode_ropes[depth]( + torch.zeros( + ( + 1, + 1, + pixel_values.shape[-2] * (4 ** (depth + 1)), + pixel_values.shape[-1] * (4 ** (depth + 1)), + ), + device=x.device, + ).to(self.embeddings.patch_embeddings.weight.dtype) + ) + if depth != 2 + else position_embeddings ) - x = x * (1 + scale) + shift for i, layer_module in enumerate(self.rest_decoder): - layer_head_mask = head_mask[i] if head_mask is not None else None - x = layer_module( - x, - conditioning_input=conditioning_input, - attention_mask=layer_head_mask, - position_embeddings=position_embeddings, - do_condition=False, - ) - - x = self.decoder_norm(x) - - return self.decoder(x, image_size=pixel_values.shape[-2:]), residual - - def get_residual( - self, - pixel_values: torch.Tensor, - time: Optional[torch.Tensor], - do_condition: bool, - ): - pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) - position_embeddings = self.rope_embeddings(pixel_values) - - x = self.embeddings(pixel_values, bool_masked_pos=None) - - if do_condition: - t = self.t_embedder(time).unsqueeze(1) - # y = 
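Review note: between decoder depths the token sequence is reshaped into a (b, d, h, w) grid and passed through nn.PixelShuffle(4), trading a factor-16 channel reduction for a 4x4 spatial upsample; that is the bookkeeping behind hidden_size // 16**depth in the decoder layers and the per-depth RoPE tables. A sketch of the shape arithmetic with a typical ViT-B configuration assumed.

```python
# Sketch: two PixelShuffle(4) steps take the patch grid to the pixel grid,
# dividing the channel dim by 16 at each step.
import torch
import torch.nn as nn
from einops import rearrange

hidden_size, patch_size = 768, 16
b, img_h = 2, 256
h = img_h // patch_size                          # 16x16 patch grid at depth 0

x = torch.randn(b, h * h, hidden_size)
upsample = nn.PixelShuffle(4)

for depth in range(2):                           # patch grid -> 4x grid -> pixel grid
    grid = rearrange(x, "b (h w) d -> b d h w", h=h * 4**depth)
    grid = upsample(grid)                        # channels / 16, spatial x 4
    x = rearrange(grid, "b d h w -> b (h w) d")
    print(depth, x.shape)                        # dims 48 then 3; tokens 64**2 then 256**2
```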
self.y_embedder(cond, self.training).unsqueeze(1) - # conditioning_input = t.to(x.dtype) + y.to(x.dtype) - conditioning_input = t.to(x.dtype) - else: - conditioning_input = None - - residual = [] - for i, layer_module in enumerate(self.encoder_layers): - if i % self.scale_factor == 0: - residual.append(x) - x = layer_module( x, conditioning_input=conditioning_input, attention_mask=None, position_embeddings=position_embeddings, - do_condition=do_condition, ) - return residual + x = self.decoder_norm(x) + + return self.decoder(x, image_size=pixel_values.shape[-2:]), residual @staticmethod def from_pretrained_backbone(name: str): diff --git a/uv.lock b/uv.lock index 1c58421..048d20f 100644 --- a/uv.lock +++ b/uv.lock @@ -298,6 +298,7 @@ dependencies = [ { name = "datasets" }, { name = "einops" }, { name = "flash-attn" }, + { name = "lovely-tensors" }, { name = "lpips" }, { name = "pyright" }, { name = "python-lsp-server" }, @@ -320,6 +321,7 @@ requires-dist = [ { name = "datasets", specifier = ">=4.1.1" }, { name = "einops", specifier = ">=0.8.1" }, { name = "flash-attn" }, + { name = "lovely-tensors", specifier = ">=0.1.19" }, { name = "lpips", specifier = ">=0.1.4" }, { name = "pyright", specifier = ">=1.1.405" }, { name = "python-lsp-server", specifier = ">=1.13.1" }, @@ -524,6 +526,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fastcore" +version = "1.8.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/af/03bfd50e750139cb84c969ce628ad77ab06505d0834f791b090432f2486e/fastcore-1.8.12.tar.gz", hash = "sha256:c6febf2c689ea365d2edb708dcfce687dbc7e05437a5962c9187a9f170c82ff7", size = 81170, upload-time = "2025-09-29T00:56:34.855Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/56/202cddea9ea549be21867dbc1a98ead8660e61e943374a28339d02aa5b4a/fastcore-1.8.12-py3-none-any.whl", hash = "sha256:7b9bfb8c564f307313971bbb27a4542e945e48e8149361d8a5e4e14f7d436937", size = 84190, upload-time = "2025-09-29T00:56:33.086Z" }, +] + [[package]] name = "filelock" version = "3.19.1" @@ -995,6 +1009,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/73/3d757cb3fc16f0f9794dd289bcd0c4a031d9cf54d8137d6b984b2d02edf3/lightning_utilities-0.15.2-py3-none-any.whl", hash = "sha256:ad3ab1703775044bbf880dbf7ddaaac899396c96315f3aa1779cec9d618a9841", size = 29431, upload-time = "2025-08-06T13:57:38.046Z" }, ] +[[package]] +name = "lovely-numpy" +version = "0.2.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastcore" }, + { name = "matplotlib" }, + { name = "numpy" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/1c/be708c2fcca304da657a6e808ddb7488c83fbac7808e1c4bd5c30dbce025/lovely_numpy-0.2.16.tar.gz", hash = "sha256:fda8ca67134f48814ddb9efe9eb38dffc74c0849d69eb2dd89d0b2757e14b628", size = 24407, upload-time = "2025-10-04T15:52:55.261Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/ef/93f039e84d7ab39b5b5946c2338509e0accfc62e4ec0bd918d7e64a00cc0/lovely_numpy-0.2.16-py3-none-any.whl", hash = "sha256:cc7ff99a18f79d03bbe1f8727e315037b4f5db2a9cc8d04be29be43876eeb0fb", size = 24425, upload-time = 
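Review note on the new lovely-tensors dependency (added to pyproject.toml/uv.lock and enabled in main.py via lt.monkey_patch()): it replaces the default tensor repr with a one-line summary, which is convenient while debugging the multi-resolution shapes above. A usage sketch assuming lovely-tensors >= 0.1.19 as pinned; the exact summary formatting shown in the comment is approximate.

```python
# Sketch of what lt.monkey_patch() buys during debugging: compact tensor reprs.
import torch
import lovely_tensors as lt

lt.monkey_patch()                                # patches torch.Tensor repr globally

x = torch.randn(2, 3, 256, 256)
print(x)                                         # one-line summary: shape, count, range, mean/std
print(x[0, 0, :2, :2].shape)                     # tensor ops are untouched; only the repr changes
```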
"2025-10-04T15:52:53.932Z" }, +] + +[[package]] +name = "lovely-tensors" +version = "0.1.19" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lovely-numpy" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/c2/8c21ea15038a2e98079285c52a180e53ea7f0aadcf5b1fa58c1cadfb68da/lovely_tensors-0.1.19.tar.gz", hash = "sha256:0a9cec41c6a13d7de3ca3688c10f61991071352116c5303c3e62c91febf32016", size = 22558, upload-time = "2025-10-04T15:56:33.085Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/ae/93fd971b7a067ddca543302e7632509a36078b48ad0d08643bc17ac35cf3/lovely_tensors-0.1.19-py3-none-any.whl", hash = "sha256:b0fee4f5cfbb0494be6bbce0963bf8e39edc1e4ef89afe922c30edfb55d65237", size = 19394, upload-time = "2025-10-04T15:56:31.682Z" }, +] + [[package]] name = "lpips" version = "0.1.4"