upscale on middle

neulus
2025-10-26 15:42:35 +09:00
parent 6ab33ceb83
commit 09c1c2220a
2 changed files with 83 additions and 72 deletions

View File

@@ -137,6 +137,7 @@ for epoch in range(start_epoch, total_epoch):
     wandb.log(epoch_metrics)
     if (epoch + 1) % 50 == 0:
+        with autocast(dtype=torch.bfloat16):
         rf.model.eval()
         psnr_sum = 0
         ssim_sum = 0
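The new autocast line only pays off if the evaluation body ends up inside the context manager (as shown, the context lines would need re-indenting). A minimal sketch of the presumable intent, using the script's own names (rf, psnr_sum, ssim_sum) with a hypothetical loader and metric helpers; the bare autocast(dtype=torch.bfloat16) in the diff matches torch.cuda.amp.autocast, spelled here as the equivalent torch.autocast:

    import torch

    if (epoch + 1) % 50 == 0:
        # Run the periodic eval pass in bfloat16 autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            rf.model.eval()
            psnr_sum, ssim_sum = 0.0, 0.0
            with torch.no_grad():
                for batch in val_loader:        # hypothetical validation loader
                    pred = rf.model(batch)      # hypothetical call signature
                    psnr_sum += psnr(pred, batch)  # hypothetical metric helper
                    ssim_sum += ssim(pred, batch)  # hypothetical metric helper
        rf.model.train()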

View File

@@ -4,7 +4,6 @@ from typing import Optional
 import einops
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from safetensors import safe_open
@@ -162,14 +161,14 @@ class DinoDecoderLayer(DINOv3ViTLayer):
     def __init__(self, config: DINOv3ViTConfig, depth: int):
         super().__init__(config)
-        hidden_size = config.hidden_size // (16**depth)
+        hidden_size = config.hidden_size // (4**depth)
         hacky_config = copy.copy(config)
         hacky_config.hidden_size = hidden_size
-        hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)
+        hacky_config.intermediate_size = hacky_config.intermediate_size // (3**depth)
         self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
         self.attention = PlainAttention(
-            hidden_size, config.num_attention_heads // (4**depth)
+            hidden_size, config.num_attention_heads // (2**depth)
         )  # head scaling law?
         self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
         self.drop_path = (
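The new divisors make each decoder depth narrow more gently: hidden width shrinks 4x per depth (was 16x), the MLP 3x (was 16x), and head count 2x (was 4x). A worked example with assumed ViT-B-like numbers (hidden 768, intermediate 3072, 12 heads), for illustration only:

    # Assumed ViT-B-like config values.
    hidden_size, intermediate_size, num_heads = 768, 3072, 12

    for depth in range(3):
        print(
            depth,
            hidden_size // (4**depth),        # 768, 192, 48
            intermediate_size // (3**depth),  # 3072, 1024, 341
            num_heads // (2**depth),          # 12, 6, 3
        )

Under the old 16**depth rule the same config would bottom out at 768 // 256 = 3 channels at depth 2.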
@@ -235,7 +234,7 @@ class DinoDecoderLayer(DINOv3ViTLayer):
 class ResidualUpscaler(nn.Module):
     def __init__(
         self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
-    ):  # max depth 2 (4**2 = 16 = patch size)
+    ):  # max depth 2 (2**2 * 4 = 16 = patch size)
         super().__init__()

         def build_encoder(in_dim, num_layers=2):
@@ -285,7 +284,7 @@ class ResidualUpscaler(nn.Module):
         )
         self.v_downsample = nn.ModuleList(
             [
-                nn.Linear(config.hidden_size, config.hidden_size // (16**d))
+                nn.Linear(config.hidden_size, config.hidden_size // (4**d))
                 if d != 0
                 else nn.Identity()
                 for d in depth
@@ -296,7 +295,7 @@ class ResidualUpscaler(nn.Module):
                 CrossAttention(
                     bottleneck_dim,
                     bottleneck_dim,
-                    config.hidden_size // (16**d),
+                    config.hidden_size // (4**d),
                 )
                 if d != 0
                 else nn.Identity()
@@ -310,10 +309,14 @@ class ResidualUpscaler(nn.Module):
         self.rope = RoPE(bottleneck_dim)
+        # ok just shuffle it; no dont
+        # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]

     def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
         # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
         # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
         assert self.config.patch_size is not None
         image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
         global_shift, global_scale = self.global_encode(
@@ -341,11 +344,11 @@ class ResidualUpscaler(nn.Module):
             local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
             local_q = self.q_norm[i](
                 einops.rearrange(
-                    F.adaptive_avg_pool2d(
+                    torch.nn.functional.adaptive_avg_pool2d(
                         q,
                         output_size=(
-                            image_h // self.config.patch_size * (4**depth),
-                            image_w // self.config.patch_size * (4**depth),
+                            image_h // self.config.patch_size * (2**depth),
+                            image_w // self.config.patch_size * (2**depth),
                         ),
                     ),
                     "b c h w -> b (h w) c",
@@ -355,7 +358,7 @@ class ResidualUpscaler(nn.Module):
                 (1 + global_scale)
                 * self.k_norm[i](
                     einops.rearrange(
-                        F.adaptive_avg_pool2d(
+                        torch.nn.functional.adaptive_avg_pool2d(
                             k,
                             output_size=(
                                 image_h // self.config.patch_size,
@@ -382,8 +385,9 @@ class DinoV3ViTDecoder(nn.Module):
         self.patch_size = config.patch_size
         self.projection = nn.Linear(
-            config.hidden_size // 16 // 16, config.num_channels, bias=True
+            config.hidden_size // 16, config.num_channels * 16, bias=True
         )
+        self.upscale = nn.PixelShuffle(4)
         nn.init.zeros_(self.projection.weight)
         nn.init.zeros_(
@@ -391,11 +395,16 @@ class DinoV3ViTDecoder(nn.Module):
         ) if self.projection.bias is not None else None

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-        return self.projection(
-            einops.rearrange(
-                x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
-            )
-        ).permute(0, 3, 1, 2)
+        return self.upscale(
+            self.projection(
+                einops.rearrange(
+                    x,
+                    "b (h w) d -> b h w d",
+                    h=image_size[0] // 4,
+                    w=image_size[1] // 4,
+                )
+            ).permute(0, 3, 1, 2)
+        )


 class UTransformer(nn.Module):
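The reworked decoder head, in effect: each token is projected to num_channels * 16 values and PixelShuffle(4) folds them into a 4x4 pixel block, instead of projecting straight to num_channels. A shape sketch with assumed sizes (hidden 768, 3 channels, a 64x64 token grid):

    import torch
    from torch import nn
    from einops import rearrange

    hidden, channels = 768, 3                            # assumed config values
    projection = nn.Linear(hidden // 16, channels * 16)  # 48 -> 48 per token
    upscale = nn.PixelShuffle(4)

    x = torch.randn(1, 64 * 64, hidden // 16)            # tokens on an assumed 64x64 grid
    y = rearrange(x, "b (h w) d -> b h w d", h=64, w=64)
    y = projection(y).permute(0, 3, 1, 2)                # (1, 48, 64, 64)
    out = upscale(y)                                     # (1, 3, 256, 256): a 4x4 block per token

This presumably also explains the new h=image_size[0] // 4 in the rearrange: the passed-in grid is 4x the token grid, with the final 4x supplied by the shuffle.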
@@ -412,14 +421,14 @@ class UTransformer(nn.Module):
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)

         def gen_rope(depth: int):
-            hidden_size = config.hidden_size // (16**depth)
+            hidden_size = config.hidden_size // (4**depth)
             hacky_config = copy.copy(config)
             hacky_config.hidden_size = hidden_size
             hacky_config.intermediate_size = hacky_config.intermediate_size // (
-                16**depth
+                3**depth
             )
             hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
-                4**depth
+                2**depth
             )
             return DINOv3ViTRopePositionEmbedding(hacky_config)
@@ -431,9 +440,10 @@ class UTransformer(nn.Module):
         )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
         self.residual_upscaler = ResidualUpscaler(
             config,
-            [0, 0, 1, 1, 2, 2],  # hardcoded, sorry
+            DEPTH_LAYER,  # hardcoded, sorry
             bottleneck_dim=128,
         )
         self.decoder_layers = nn.ModuleList(
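DEPTH_LAYER does double duty in the next two hunks: DEPTH_LAYER.count(depth) sets how many decoder blocks run at each depth, and sorted(set(...), key=DEPTH_LAYER.index) recovers the distinct depths in first-occurrence order. Evaluated on the hardcoded list:

    DEPTH_LAYER = [0, 0, 1, 1, 2, 2]

    depths = sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)  # [0, 1, 2]
    counts = [DEPTH_LAYER.count(d) for d in depths]           # [2, 2, 2]

This reproduces the old range(3) with two layers per depth, but driven by the single list instead of the config-derived (config.num_hidden_layers // scale_factor) // 3.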
@@ -441,10 +451,10 @@ class UTransformer(nn.Module):
                 nn.ModuleList(
                     [
                         DinoDecoderLayer(config, depth)
-                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                        for _ in range(DEPTH_LAYER.count(depth))
                     ]
                 )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
         self.residual_merger = nn.ModuleList(
@@ -454,22 +464,22 @@ class UTransformer(nn.Module):
                         nn.Sequential(
                             nn.SiLU(),
                             nn.Linear(
-                                config.hidden_size // (16**depth),
-                                2 * config.hidden_size // (16**depth),
+                                config.hidden_size // (4**depth),
+                                2 * config.hidden_size // (4**depth),
                             ),
                         )
-                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                        for _ in range(DEPTH_LAYER.count(depth))
                     ]
                 )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
-        self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
+        self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
         self.rest_decoder = nn.ModuleList(
             [DinoDecoderLayer(config, 2) for _ in range(4)]
         )
         self.decoder_norm = nn.LayerNorm(
-            (config.hidden_size // (16**2)), eps=config.layer_norm_eps
+            (config.hidden_size // (4**2)), eps=config.layer_norm_eps
        )
         self.decoder = DinoV3ViTDecoder(config)
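The arithmetic tying this hunk together: PixelShuffle(r) maps (b, c*r*r, h, w) to (b, c, r*h, r*w), so each PixelShuffle(2) stage trades a factor-4 channel reduction for a 2x spatial gain, which is why every 16**depth divisor above became 4**depth. A quick check with an assumed 768-dim feature map on a 16x16 grid:

    import torch
    from torch import nn

    x = torch.randn(1, 768, 16, 16)        # assumed depth-0 feature map
    for stage in [nn.PixelShuffle(2), nn.PixelShuffle(2)]:
        x = stage(x)
    print(x.shape)                         # (1, 48, 64, 64): 768 // 4**2 channels, 4x spatial
    # The final DinoV3ViTDecoder adds PixelShuffle(4): 2 * 2 * 4 = 16 = patch size.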
@@ -536,7 +546,7 @@ class UTransformer(nn.Module):
                     "b (h w) d -> b d h w",
                     h=pixel_values.shape[-2]
                     // (self.config.patch_size)
-                    * (4**depth),
+                    * (2**depth),
                 )
             ),
             "b d h w -> b (h w) d",
@@ -550,8 +560,8 @@ class UTransformer(nn.Module):
             (
                 1,
                 1,
-                pixel_values.shape[-2] * (4 ** (depth + 1)),
-                pixel_values.shape[-1] * (4 ** (depth + 1)),
+                pixel_values.shape[-2] * (2 ** (depth + 1)),
+                pixel_values.shape[-1] * (2 ** (depth + 1)),
             ),
             device=x.device,
         ).to(self.embeddings.patch_embeddings.weight.dtype)