upscale on middle
main.py
@@ -137,6 +137,7 @@ for epoch in range(start_epoch, total_epoch):
     wandb.log(epoch_metrics)

     if (epoch + 1) % 50 == 0:
         with autocast(dtype=torch.bfloat16):
             rf.model.eval()
             psnr_sum = 0
             ssim_sum = 0
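For context on what psnr_sum accumulates below the cut, here is a minimal per-image PSNR sketch; the commit's actual metric code is not shown in this hunk, and the peak value of 1.0 is an assumption about the pixel range.

import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, peak: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(peak^2 / MSE); assumes images are scaled to [0, peak]
    mse = torch.mean((pred.float() - target.float()) ** 2)
    return 10 * torch.log10(peak**2 / mse)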
@@ -4,7 +4,6 @@ from typing import Optional

 import einops
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from safetensors import safe_open
@@ -162,14 +161,14 @@ class DinoDecoderLayer(DINOv3ViTLayer):
     def __init__(self, config: DINOv3ViTConfig, depth: int):
         super().__init__(config)

-        hidden_size = config.hidden_size // (16**depth)
+        hidden_size = config.hidden_size // (4**depth)
         hacky_config = copy.copy(config)
         hacky_config.hidden_size = hidden_size
-        hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)
+        hacky_config.intermediate_size = hacky_config.intermediate_size // (3**depth)

         self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
         self.attention = PlainAttention(
-            hidden_size, config.num_attention_heads // (4**depth)
+            hidden_size, config.num_attention_heads // (2**depth)
         ) # head scaling law?
         self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
         self.drop_path = (
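As a rough illustration of the new per-depth scaling (hidden // 4**depth, intermediate // 3**depth, heads // 2**depth), with placeholder config values that are not taken from the commit:

# Placeholder sizes (ViT-L-like); the real values come from DINOv3ViTConfig.
hidden_size, intermediate_size, num_heads = 1024, 4096, 16
for depth in range(3):
    h = hidden_size // (4**depth)        # 1024, 256, 64
    i = intermediate_size // (3**depth)  # 4096, 1365, 455
    n = num_heads // (2**depth)          # 16, 8, 4
    print(depth, h, i, n, h // n)        # head_dim halves per depth: 64, 32, 16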
@@ -235,7 +234,7 @@ class DinoDecoderLayer(DINOv3ViTLayer):
 class ResidualUpscaler(nn.Module):
     def __init__(
         self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
-    ): # max depth 2 (4**2 = 16 = patch size)
+    ): # max depth 2 (2**2 = 16 = patch size)
         super().__init__()

         def build_encoder(in_dim, num_layers=2):
@@ -285,7 +284,7 @@ class ResidualUpscaler(nn.Module):
         )
         self.v_downsample = nn.ModuleList(
             [
-                nn.Linear(config.hidden_size, config.hidden_size // (16**d))
+                nn.Linear(config.hidden_size, config.hidden_size // (4**d))
                 if d != 0
                 else nn.Identity()
                 for d in depth
@@ -296,7 +295,7 @@ class ResidualUpscaler(nn.Module):
                 CrossAttention(
                     bottleneck_dim,
                     bottleneck_dim,
-                    config.hidden_size // (16**d),
+                    config.hidden_size // (4**d),
                 )
                 if d != 0
                 else nn.Identity()
@@ -310,10 +309,14 @@ class ResidualUpscaler(nn.Module):

         self.rope = RoPE(bottleneck_dim)
+
+        # ok just shuffle it; no dont
+        # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]

     def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
         # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
         # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
         assert self.config.patch_size is not None

         image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]

         global_shift, global_scale = self.global_encode(
@@ -341,11 +344,11 @@ class ResidualUpscaler(nn.Module):
             local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
             local_q = self.q_norm[i](
                 einops.rearrange(
-                    F.adaptive_avg_pool2d(
+                    torch.nn.functional.adaptive_avg_pool2d(
                         q,
                         output_size=(
-                            image_h // self.config.patch_size * (4**depth),
-                            image_w // self.config.patch_size * (4**depth),
+                            image_h // self.config.patch_size * (2**depth),
+                            image_w // self.config.patch_size * (2**depth),
                         ),
                     ),
                     "b c h w -> b (h w) c",
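A minimal sketch of how the pooled query grid scales with depth under the new 2**depth factor; the 256x256 input, patch_size=16, and channel count below are made-up values, not taken from the commit.

import torch

image_h = image_w = 256
patch_size = 16
q = torch.randn(1, 128, 64, 64)  # hypothetical (b, c, h, w) query feature map
for depth in (0, 1, 2):
    out_h = image_h // patch_size * (2**depth)  # 16, 32, 64
    out_w = image_w // patch_size * (2**depth)
    pooled = torch.nn.functional.adaptive_avg_pool2d(q, output_size=(out_h, out_w))
    print(depth, tuple(pooled.shape))  # (1, 128, 16, 16) -> (1, 128, 32, 32) -> (1, 128, 64, 64)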
@@ -355,7 +358,7 @@ class ResidualUpscaler(nn.Module):
                 (1 + global_scale)
                 * self.k_norm[i](
                     einops.rearrange(
-                        F.adaptive_avg_pool2d(
+                        torch.nn.functional.adaptive_avg_pool2d(
                             k,
                             output_size=(
                                 image_h // self.config.patch_size,
@@ -382,8 +385,9 @@ class DinoV3ViTDecoder(nn.Module):
         self.patch_size = config.patch_size

         self.projection = nn.Linear(
-            config.hidden_size // 16 // 16, config.num_channels, bias=True
+            config.hidden_size // 16, config.num_channels * 16, bias=True
         )
+        self.upscale = nn.PixelShuffle(4)

         nn.init.zeros_(self.projection.weight)
         nn.init.zeros_(
@@ -391,11 +395,16 @@ class DinoV3ViTDecoder(nn.Module):
         ) if self.projection.bias is not None else None

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-        return self.projection(
+        return self.upscale(
+            self.projection(
             einops.rearrange(
-                x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
+                x,
+                "b (h w) d -> b h w d",
+                h=image_size[0] // 4,
+                w=image_size[1] // 4,
             )
         ).permute(0, 3, 1, 2)
+        )


 class UTransformer(nn.Module):
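A small shape-flow sketch of the rewritten decoder head (projection to num_channels * 16 followed by PixelShuffle(4)); the concrete sizes below (hidden // 16 = 64, 3 channels, a 64x64 target grid) are illustrative assumptions, not values from the commit.

import torch
import torch.nn as nn
import einops

b, d, num_channels = 1, 64, 3                           # d stands in for config.hidden_size // 16
grid_h = grid_w = 64                                    # image_size passed to forward
x = torch.randn(b, (grid_h // 4) * (grid_w // 4), d)    # (b, seq, d) with seq = (h/4)*(w/4)

projection = nn.Linear(d, num_channels * 16)
upscale = nn.PixelShuffle(4)

t = einops.rearrange(x, "b (h w) d -> b h w d", h=grid_h // 4, w=grid_w // 4)  # (1, 16, 16, 64)
t = projection(t)                # (1, 16, 16, 48)
t = t.permute(0, 3, 1, 2)        # (1, 48, 16, 16), channels-first for PixelShuffle
out = upscale(t)                 # (1, 3, 64, 64): 4x spatial, channels / 16
print(out.shape)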
@@ -412,14 +421,14 @@ class UTransformer(nn.Module):
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)

         def gen_rope(depth: int):
-            hidden_size = config.hidden_size // (16**depth)
+            hidden_size = config.hidden_size // (4**depth)
             hacky_config = copy.copy(config)
             hacky_config.hidden_size = hidden_size
             hacky_config.intermediate_size = hacky_config.intermediate_size // (
-                16**depth
+                3**depth
             )
             hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
-                4**depth
+                2**depth
             )
             return DINOv3ViTRopePositionEmbedding(hacky_config)

@@ -431,9 +440,10 @@ class UTransformer(nn.Module):
         )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

+        DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
         self.residual_upscaler = ResidualUpscaler(
             config,
-            [0, 0, 1, 1, 2, 2], # hardcoded, sorry
+            DEPTH_LAYER, # hardcoded, sorry
             bottleneck_dim=128,
         )
         self.decoder_layers = nn.ModuleList(
@@ -441,10 +451,10 @@ class UTransformer(nn.Module):
                 nn.ModuleList(
                     [
                         DinoDecoderLayer(config, depth)
-                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                        for _ in range(DEPTH_LAYER.count(depth))
                     ]
                 )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
         self.residual_merger = nn.ModuleList(
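For the DEPTH_LAYER-driven construction above, a quick check of what the two new expressions evaluate to for [0, 0, 1, 1, 2, 2] (illustration only):

DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
print([DEPTH_LAYER.count(d) for d in (0, 1, 2)])        # [2, 2, 2] layers per depth
print(sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index))  # [0, 1, 2], in first-seen order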
@@ -454,22 +464,22 @@ class UTransformer(nn.Module):
                 nn.Sequential(
                     nn.SiLU(),
                     nn.Linear(
-                        config.hidden_size // (16**depth),
-                        2 * config.hidden_size // (16**depth),
+                        config.hidden_size // (4**depth),
+                        2 * config.hidden_size // (4**depth),
                     ),
                 )
-                for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                for _ in range(DEPTH_LAYER.count(depth))
             ]
         )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
-        self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
+        self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
         self.rest_decoder = nn.ModuleList(
             [DinoDecoderLayer(config, 2) for _ in range(4)]
         )
         self.decoder_norm = nn.LayerNorm(
-            (config.hidden_size // (16**2)), eps=config.layer_norm_eps
+            (config.hidden_size // (4**2)), eps=config.layer_norm_eps
         )
         self.decoder = DinoV3ViTDecoder(config)

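A worked check of the overall token-to-pixel upscaling implied by these choices: the two PixelShuffle(2) stages in self.upsample and the final PixelShuffle(4) in DinoV3ViTDecoder multiply back to the patch size (the snippet is illustrative; patch_size=16 is taken from the "16 = patch size" comment above).

upsample_factors = [2, 2]   # self.upsample: two nn.PixelShuffle(2) stages
decoder_factor = 4          # DinoV3ViTDecoder: nn.PixelShuffle(4)
total = decoder_factor
for f in upsample_factors:
    total *= f              # each PixelShuffle(2) also divides channels by 4
print(total)                # 16, so each patch token maps back to a 16x16 pixel block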
@@ -536,7 +546,7 @@ class UTransformer(nn.Module):
                         "b (h w) d -> b d h w",
                         h=pixel_values.shape[-2]
                         // (self.config.patch_size)
-                        * (4**depth),
+                        * (2**depth),
                     )
                 ),
                 "b d h w -> b (h w) d",
@@ -550,8 +560,8 @@ class UTransformer(nn.Module):
                 (
                     1,
                     1,
-                    pixel_values.shape[-2] * (4 ** (depth + 1)),
-                    pixel_values.shape[-1] * (4 ** (depth + 1)),
+                    pixel_values.shape[-2] * (2 ** (depth + 1)),
+                    pixel_values.shape[-1] * (2 ** (depth + 1)),
                 ),
                 device=x.device,
             ).to(self.embeddings.patch_embeddings.weight.dtype)