upscale on middle

commit 09c1c2220a
parent 6ab33ceb83
Author: neulus
Date: 2025-10-26 15:42:35 +09:00
2 changed files with 83 additions and 72 deletions


@@ -4,7 +4,6 @@ from typing import Optional
import einops
import torch
- import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open
@@ -162,14 +161,14 @@ class DinoDecoderLayer(DINOv3ViTLayer):
def __init__(self, config: DINOv3ViTConfig, depth: int):
super().__init__(config)
- hidden_size = config.hidden_size // (16**depth)
+ hidden_size = config.hidden_size // (4**depth)
hacky_config = copy.copy(config)
hacky_config.hidden_size = hidden_size
- hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)
+ hacky_config.intermediate_size = hacky_config.intermediate_size // (3**depth)
self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.attention = PlainAttention(
- hidden_size, config.num_attention_heads // (4**depth)
+ hidden_size, config.num_attention_heads // (2**depth)
) # head scaling law?
self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
self.drop_path = (
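
For intuition: the revised scaling shrinks the layer width by 4x per depth while only halving the head count, so the per-head dimension halves at each depth instead of collapsing by 4x as before. A minimal sanity check of that arithmetic, assuming ViT-L-like values (hidden_size=1024, 16 heads; illustrative, not read from the actual config):

# Sanity check of the revised DinoDecoderLayer scaling (assumed values).
hidden_size, num_heads = 1024, 16  # illustrative ViT-L-like numbers

for depth in range(3):  # max depth 2
    width = hidden_size // (4**depth)   # 1024, 256, 64
    heads = num_heads // (2**depth)     # 16, 8, 4
    print(f"depth={depth}: width={width}, heads={heads}, head_dim={width // heads}")
# head_dim halves per depth: 64, 32, 16 -- the "head scaling law?" asked about above
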
@@ -235,7 +234,7 @@ class DinoDecoderLayer(DINOv3ViTLayer):
class ResidualUpscaler(nn.Module):
def __init__(
self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
- ): # max depth 2 (4**2 = 16 = patch size)
+ ): # max depth 2 (2**2 * 4 = 16 = patch size, counting the decoder's final PixelShuffle(4))
super().__init__()
def build_encoder(in_dim, num_layers=2):
@@ -285,7 +284,7 @@ class ResidualUpscaler(nn.Module):
)
self.v_downsample = nn.ModuleList(
[
- nn.Linear(config.hidden_size, config.hidden_size // (16**d))
+ nn.Linear(config.hidden_size, config.hidden_size // (4**d))
if d != 0
else nn.Identity()
for d in depth
@@ -296,7 +295,7 @@ class ResidualUpscaler(nn.Module):
CrossAttention(
bottleneck_dim,
bottleneck_dim,
- config.hidden_size // (16**d),
+ config.hidden_size // (4**d),
)
if d != 0
else nn.Identity()
@@ -310,10 +309,14 @@ class ResidualUpscaler(nn.Module):
self.rope = RoPE(bottleneck_dim)
# ok just shuffle it; no dont
# self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]
def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
# residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
# objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
assert self.config.patch_size is not None
image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
global_shift, global_scale = self.global_encode(
@@ -341,11 +344,11 @@ class ResidualUpscaler(nn.Module):
local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
local_q = self.q_norm[i](
einops.rearrange(
- F.adaptive_avg_pool2d(
+ torch.nn.functional.adaptive_avg_pool2d(
q,
output_size=(
- image_h // self.config.patch_size * (4**depth),
- image_w // self.config.patch_size * (4**depth),
+ image_h // self.config.patch_size * (2**depth),
+ image_w // self.config.patch_size * (2**depth),
),
),
"b c h w -> b (h w) c",
@@ -355,7 +358,7 @@ class ResidualUpscaler(nn.Module):
(1 + global_scale)
* self.k_norm[i](
einops.rearrange(
- F.adaptive_avg_pool2d(
+ torch.nn.functional.adaptive_avg_pool2d(
k,
output_size=(
image_h // self.config.patch_size,
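
Both changed call sites use the same pattern: resize a (b, c, h, w) feature map to the per-depth token grid with adaptive_avg_pool2d, then flatten it to (b, seq, c) tokens. A standalone sketch of the pattern, with illustrative shapes:

import einops
import torch

b, c = 2, 128
q = torch.randn(b, c, 64, 64)  # illustrative feature map

# Pool to the target token grid, then flatten to a token sequence.
pooled = torch.nn.functional.adaptive_avg_pool2d(q, output_size=(16, 16))
tokens = einops.rearrange(pooled, "b c h w -> b (h w) c")
print(tokens.shape)  # torch.Size([2, 256, 128])
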
@@ -382,8 +385,9 @@ class DinoV3ViTDecoder(nn.Module):
self.patch_size = config.patch_size
self.projection = nn.Linear(
- config.hidden_size // 16 // 16, config.num_channels, bias=True
+ config.hidden_size // 16, config.num_channels * 16, bias=True
)
+ self.upscale = nn.PixelShuffle(4)
nn.init.zeros_(self.projection.weight)
nn.init.zeros_(
@@ -391,11 +395,16 @@ class DinoV3ViTDecoder(nn.Module):
) if self.projection.bias is not None else None
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
- return self.projection(
-     einops.rearrange(
-         x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
-     )
- ).permute(0, 3, 1, 2)
+ return self.upscale(
+     self.projection(
+         einops.rearrange(
+             x,
+             "b (h w) d -> b h w d",
+             h=image_size[0] // 4,
+             w=image_size[1] // 4,
+         )
+     ).permute(0, 3, 1, 2)
+ )
class UTransformer(nn.Module):
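
DinoV3ViTDecoder now projects each depth-2 token (hidden_size // 16 dims) to num_channels * 16 values, and PixelShuffle(4) trades those extra channels for the final 4x spatial upscale. A shape walk-through of the new forward, assuming hidden_size=1024, 3 channels, and a 32x32 target (values are illustrative):

import einops
import torch
import torch.nn as nn

hidden_size, num_channels = 1024, 3     # assumed, not read from the config
projection = nn.Linear(hidden_size // 16, num_channels * 16, bias=True)
upscale = nn.PixelShuffle(4)

h, w = 32, 32                            # target resolution; token grid is 4x coarser
x = torch.randn(2, (h // 4) * (w // 4), hidden_size // 16)

feat = projection(
    einops.rearrange(x, "b (h w) d -> b h w d", h=h // 4, w=w // 4)
).permute(0, 3, 1, 2)                    # (2, 48, 8, 8): 48 = 3 * 4**2
print(upscale(feat).shape)               # torch.Size([2, 3, 32, 32])
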
@@ -412,14 +421,14 @@ class UTransformer(nn.Module):
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
def gen_rope(depth: int):
- hidden_size = config.hidden_size // (16**depth)
+ hidden_size = config.hidden_size // (4**depth)
hacky_config = copy.copy(config)
hacky_config.hidden_size = hidden_size
hacky_config.intermediate_size = hacky_config.intermediate_size // (
- 16**depth
+ 3**depth
)
hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
- 4**depth
+ 2**depth
)
return DINOv3ViTRopePositionEmbedding(hacky_config)
@@ -431,9 +440,10 @@ class UTransformer(nn.Module):
)
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
self.residual_upscaler = ResidualUpscaler(
config,
- [0, 0, 1, 1, 2, 2], # hardcoded, sorry
+ DEPTH_LAYER, # hardcoded, sorry
bottleneck_dim=128,
)
self.decoder_layers = nn.ModuleList(
@@ -441,10 +451,10 @@ class UTransformer(nn.Module):
nn.ModuleList(
[
DinoDecoderLayer(config, depth)
- for _ in range((config.num_hidden_layers // scale_factor) // 3)
+ for _ in range(DEPTH_LAYER.count(depth))
]
)
- for depth in range(3)
+ for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
]
)
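
The stage layout is now derived from DEPTH_LAYER itself: .count(depth) gives the number of layers per stage, and sorted(set(...), key=DEPTH_LAYER.index) walks the distinct depths in first-occurrence order. A quick check:

DEPTH_LAYER = [0, 0, 1, 1, 2, 2]

stages = sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
print(stages)                                      # [0, 1, 2]
print({d: DEPTH_LAYER.count(d) for d in stages})   # {0: 2, 1: 2, 2: 2}
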
self.residual_merger = nn.ModuleList(
@@ -454,22 +464,22 @@ class UTransformer(nn.Module):
nn.Sequential(
nn.SiLU(),
nn.Linear(
- config.hidden_size // (16**depth),
- 2 * config.hidden_size // (16**depth),
+ config.hidden_size // (4**depth),
+ 2 * config.hidden_size // (4**depth),
),
)
- for _ in range((config.num_hidden_layers // scale_factor) // 3)
+ for _ in range(DEPTH_LAYER.count(depth))
]
)
- for depth in range(3)
+ for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
]
)
- self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
+ self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
self.rest_decoder = nn.ModuleList(
[DinoDecoderLayer(config, 2) for _ in range(4)]
)
self.decoder_norm = nn.LayerNorm(
- (config.hidden_size // (16**2)), eps=config.layer_norm_eps
+ (config.hidden_size // (4**2)), eps=config.layer_norm_eps
)
self.decoder = DinoV3ViTDecoder(config)
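
With the upsample steps reduced to PixelShuffle(2), the 16x return to pixel resolution is now split between the stages and the decoder (a patch_size of 16 is assumed here):

patch_size = 16           # assumed
stage_upscale = 2 * 2     # two nn.PixelShuffle(2) steps in self.upsample
decoder_upscale = 4       # nn.PixelShuffle(4) inside DinoV3ViTDecoder
assert stage_upscale * decoder_upscale == patch_size
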
@@ -536,7 +546,7 @@ class UTransformer(nn.Module):
"b (h w) d -> b d h w",
h=pixel_values.shape[-2]
// (self.config.patch_size)
- * (4**depth),
+ * (2**depth),
)
),
"b d h w -> b (h w) d",
@@ -550,8 +560,8 @@ class UTransformer(nn.Module):
(
1,
1,
- pixel_values.shape[-2] * (4 ** (depth + 1)),
- pixel_values.shape[-1] * (4 ** (depth + 1)),
+ pixel_values.shape[-2] * (2 ** (depth + 1)),
+ pixel_values.shape[-1] * (2 ** (depth + 1)),
),
device=x.device,
).to(self.embeddings.patch_embeddings.weight.dtype)