upscale on middle

main.py
@@ -137,50 +137,51 @@ for epoch in range(start_epoch, total_epoch):
     wandb.log(epoch_metrics)
 
     if (epoch + 1) % 50 == 0:
-        rf.model.eval()
-        psnr_sum = 0
-        ssim_sum = 0
-        lpips_sum = 0
-        count = 0
-
-        with torch.no_grad():
-            for i in tqdm(
-                range(0, len(test_dataset), batch_size),
-                desc=f"Benchmark {epoch + 1}/{total_epoch}",
-            ):
-                batch = test_dataset[i : i + batch_size]
-                images = rf.sample(batch["cloud"].to(device))
-                image = denormalize(images[-1]).clamp(0, 1)
-                original = denormalize(batch["gt"]).clamp(0, 1)
-
-                if i == 0:
-                    for step, demo in enumerate([images[0], images[-1]]):
-                        images = wandb.Image(
-                            make_grid(
-                                denormalize(demo).clamp(0, 1).float()[:4], nrow=2
-                            ),
-                            caption=f"step {step}",
-                        )
-                        wandb.log({"viz/decoded": images})
-
-                psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
-                psnr_sum += psnr.sum().item()
-                ssim_sum += ssim.sum().item()
-                lpips_sum += lpips.sum().item()
-                count += image.shape[0]
-
-        avg_psnr = psnr_sum / count
-        avg_ssim = ssim_sum / count
-        avg_lpips = lpips_sum / count
-        wandb.log(
-            {
-                "eval/psnr": avg_psnr,
-                "eval/ssim": avg_ssim,
-                "eval/lpips": avg_lpips,
-                "epoch": epoch + 1,
-            }
-        )
-        rf.model.train()
+        with autocast(dtype=torch.bfloat16):
+            rf.model.eval()
+            psnr_sum = 0
+            ssim_sum = 0
+            lpips_sum = 0
+            count = 0
+
+            with torch.no_grad():
+                for i in tqdm(
+                    range(0, len(test_dataset), batch_size),
+                    desc=f"Benchmark {epoch + 1}/{total_epoch}",
+                ):
+                    batch = test_dataset[i : i + batch_size]
+                    images = rf.sample(batch["cloud"].to(device))
+                    image = denormalize(images[-1]).clamp(0, 1)
+                    original = denormalize(batch["gt"]).clamp(0, 1)
+
+                    if i == 0:
+                        for step, demo in enumerate([images[0], images[-1]]):
+                            images = wandb.Image(
+                                make_grid(
+                                    denormalize(demo).clamp(0, 1).float()[:4], nrow=2
+                                ),
+                                caption=f"step {step}",
+                            )
+                            wandb.log({"viz/decoded": images})
+
+                    psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
+                    psnr_sum += psnr.sum().item()
+                    ssim_sum += ssim.sum().item()
+                    lpips_sum += lpips.sum().item()
+                    count += image.shape[0]
+
+            avg_psnr = psnr_sum / count
+            avg_ssim = ssim_sum / count
+            avg_lpips = lpips_sum / count
+            wandb.log(
+                {
+                    "eval/psnr": avg_psnr,
+                    "eval/ssim": avg_ssim,
+                    "eval/lpips": avg_lpips,
+                    "epoch": epoch + 1,
+                }
+            )
+            rf.model.train()
 
     torch.save(
         {
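The only functional change in main.py is that the periodic benchmark pass now runs under bfloat16 autocast. A minimal sketch of the pattern; since the call site passes only `dtype`, the import is assumed to be `torch.cuda.amp.autocast` (the import line is outside this hunk):

    import torch
    from torch.cuda.amp import autocast  # assumed import; the diff only shows the call site

    model = torch.nn.Linear(8, 8).cuda()      # stand-in for rf.model
    inputs = torch.randn(2, 8, device="cuda")  # stand-in for a test batch

    model.eval()
    with autocast(dtype=torch.bfloat16):  # matmuls run in bf16; no GradScaler needed for bf16
        with torch.no_grad():             # metrics only, no autograd graph
            preds = model(inputs)
    model.train()

The remaining hunks belong to the model file.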
@@ -4,7 +4,6 @@ from typing import Optional
 
 import einops
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from safetensors import safe_open
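Dropping the `F` alias explains the call-site edits further down: two hunks switch `F.adaptive_avg_pool2d` to the fully qualified name, which resolves to the same function object:

    import torch
    import torch.nn.functional as F

    # the alias and the qualified path are the same function
    assert F.adaptive_avg_pool2d is torch.nn.functional.adaptive_avg_pool2d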
@@ -162,14 +161,14 @@ class DinoDecoderLayer(DINOv3ViTLayer):
     def __init__(self, config: DINOv3ViTConfig, depth: int):
         super().__init__(config)
 
-        hidden_size = config.hidden_size // (16**depth)
+        hidden_size = config.hidden_size // (4**depth)
         hacky_config = copy.copy(config)
         hacky_config.hidden_size = hidden_size
-        hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)
+        hacky_config.intermediate_size = hacky_config.intermediate_size // (3**depth)
 
         self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
         self.attention = PlainAttention(
-            hidden_size, config.num_attention_heads // (4**depth)
+            hidden_size, config.num_attention_heads // (2**depth)
         )  # head scaling law?
         self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
         self.drop_path = (
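The net effect is a gentler width schedule per decoder depth. With assumed ViT-B-like numbers (hidden 768, 12 heads, MLP 3072; the real config values are not shown in this diff), the stages come out as:

    # hypothetical config values, for illustration only
    hidden_size, num_heads, intermediate = 768, 12, 3072

    for depth in range(3):
        h = hidden_size // (4**depth)   # was 16**depth
        n = num_heads // (2**depth)     # was 4**depth
        m = intermediate // (3**depth)  # was 16**depth
        print(f"depth {depth}: hidden {h}, heads {n}, head_dim {h // n}, mlp {m}")

    # depth 0: hidden 768, heads 12, head_dim 64, mlp 3072
    # depth 1: hidden 192, heads 6,  head_dim 32, mlp 1024
    # depth 2: hidden 48,  heads 3,  head_dim 16, mlp 341

Note that head_dim now halves per stage instead of collapsing by 4x, which is presumably what the "# head scaling law?" comment is wondering about.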
@@ -235,7 +234,7 @@ class DinoDecoderLayer(DINOv3ViTLayer):
 class ResidualUpscaler(nn.Module):
     def __init__(
         self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
-    ):  # max depth 2 (4**2 = 16 = patch size)
+    ):  # max depth 2 (2**2 * 4 = 16 = patch size, counting the decoder's final 4x shuffle)
         super().__init__()
 
         def build_encoder(in_dim, num_layers=2):
@@ -285,7 +284,7 @@ class ResidualUpscaler(nn.Module):
         )
         self.v_downsample = nn.ModuleList(
             [
-                nn.Linear(config.hidden_size, config.hidden_size // (16**d))
+                nn.Linear(config.hidden_size, config.hidden_size // (4**d))
                 if d != 0
                 else nn.Identity()
                 for d in depth
@@ -296,7 +295,7 @@ class ResidualUpscaler(nn.Module):
                 CrossAttention(
                     bottleneck_dim,
                     bottleneck_dim,
-                    config.hidden_size // (16**d),
+                    config.hidden_size // (4**d),
                 )
                 if d != 0
                 else nn.Identity()
@@ -310,10 +309,14 @@ class ResidualUpscaler(nn.Module):
 
         self.rope = RoPE(bottleneck_dim)
 
+        # ok just shuffle it; no dont
+        # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]
+
+
     def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
         # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
         # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
         assert self.config.patch_size is not None
 
         image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
 
         global_shift, global_scale = self.global_encode(
@@ -341,11 +344,11 @@ class ResidualUpscaler(nn.Module):
             local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
             local_q = self.q_norm[i](
                 einops.rearrange(
-                    F.adaptive_avg_pool2d(
+                    torch.nn.functional.adaptive_avg_pool2d(
                         q,
                         output_size=(
-                            image_h // self.config.patch_size * (4**depth),
-                            image_w // self.config.patch_size * (4**depth),
+                            image_h // self.config.patch_size * (2**depth),
+                            image_w // self.config.patch_size * (2**depth),
                         ),
                     ),
                     "b c h w -> b (h w) c",
@@ -355,7 +358,7 @@ class ResidualUpscaler(nn.Module):
                 (1 + global_scale)
                 * self.k_norm[i](
                     einops.rearrange(
-                        F.adaptive_avg_pool2d(
+                        torch.nn.functional.adaptive_avg_pool2d(
                             k,
                             output_size=(
                                 image_h // self.config.patch_size,
@@ -382,8 +385,9 @@ class DinoV3ViTDecoder(nn.Module):
         self.patch_size = config.patch_size
 
         self.projection = nn.Linear(
-            config.hidden_size // 16 // 16, config.num_channels, bias=True
+            config.hidden_size // 16, config.num_channels * 16, bias=True
         )
+        self.upscale = nn.PixelShuffle(4)
 
         nn.init.zeros_(self.projection.weight)
         nn.init.zeros_(
@@ -391,11 +395,16 @@ class DinoV3ViTDecoder(nn.Module):
         ) if self.projection.bias is not None else None
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-        return self.projection(
-            einops.rearrange(
-                x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
-            )
-        ).permute(0, 3, 1, 2)
+        return self.upscale(
+            self.projection(
+                einops.rearrange(
+                    x,
+                    "b (h w) d -> b h w d",
+                    h=image_size[0] // 4,
+                    w=image_size[1] // 4,
+                )
+            ).permute(0, 3, 1, 2)
+        )
 
 
 class UTransformer(nn.Module):
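A quick shape check of the new decoder tail, under assumed sizes (hidden_size 768, num_channels 3, a 256x256 input; none of these are stated in the diff): tokens now arrive at stride 4 after the two 2x shuffles, and the final PixelShuffle(4) regroups the 48 projected channels into 3 channels at full resolution.

    import torch
    import torch.nn as nn
    import einops

    # assumed sizes: hidden_size=768, num_channels=3, 256x256 image
    proj = nn.Linear(768 // 16, 3 * 16)  # 48 -> 48: channels arranged for the final shuffle
    up = nn.PixelShuffle(4)              # (b, 3*16, h, w) -> (b, 3, 4h, 4w)

    x = torch.randn(2, (256 // 4) * (256 // 4), 768 // 16)  # stride-4 token grid
    y = einops.rearrange(x, "b (h w) d -> b h w d", h=256 // 4, w=256 // 4)
    y = up(proj(y).permute(0, 3, 1, 2))  # (2, 48, 64, 64) -> (2, 3, 256, 256)
    assert y.shape == (2, 3, 256, 256)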
@@ -412,14 +421,14 @@ class UTransformer(nn.Module):
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
 
         def gen_rope(depth: int):
-            hidden_size = config.hidden_size // (16**depth)
+            hidden_size = config.hidden_size // (4**depth)
             hacky_config = copy.copy(config)
             hacky_config.hidden_size = hidden_size
             hacky_config.intermediate_size = hacky_config.intermediate_size // (
-                16**depth
+                3**depth
             )
             hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
-                4**depth
+                2**depth
             )
             return DINOv3ViTRopePositionEmbedding(hacky_config)
 
@@ -431,9 +440,10 @@ class UTransformer(nn.Module):
         )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
         self.residual_upscaler = ResidualUpscaler(
             config,
-            [0, 0, 1, 1, 2, 2],  # hardcoded, sorry
+            DEPTH_LAYER,  # hardcoded, sorry
             bottleneck_dim=128,
         )
         self.decoder_layers = nn.ModuleList(
@@ -441,10 +451,10 @@ class UTransformer(nn.Module):
                 nn.ModuleList(
                     [
                         DinoDecoderLayer(config, depth)
-                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                        for _ in range(DEPTH_LAYER.count(depth))
                     ]
                 )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
         self.residual_merger = nn.ModuleList(
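This hunk and the next derive the per-depth layer counts from DEPTH_LAYER instead of the old hardcoded `range(3)` and division. A quick check of what the two new expressions evaluate to:

    DEPTH_LAYER = [0, 0, 1, 1, 2, 2]

    # unique depths in first-occurrence order -> [0, 1, 2]
    print(sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index))

    # layers built per depth -> [2, 2, 2]
    print([DEPTH_LAYER.count(d) for d in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)])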
@@ -454,22 +464,22 @@ class UTransformer(nn.Module):
                     nn.Sequential(
                         nn.SiLU(),
                         nn.Linear(
-                            config.hidden_size // (16**depth),
-                            2 * config.hidden_size // (16**depth),
+                            config.hidden_size // (4**depth),
+                            2 * config.hidden_size // (4**depth),
                         ),
                     )
-                    for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                    for _ in range(DEPTH_LAYER.count(depth))
                 ]
             )
-            for depth in range(3)
+            for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
-        self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
+        self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
         self.rest_decoder = nn.ModuleList(
             [DinoDecoderLayer(config, 2) for _ in range(4)]
         )
         self.decoder_norm = nn.LayerNorm(
-            (config.hidden_size // (16**2)), eps=config.layer_norm_eps
+            (config.hidden_size // (4**2)), eps=config.layer_norm_eps
         )
         self.decoder = DinoV3ViTDecoder(config)
 
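Switching the two mid-network upsamplers from PixelShuffle(4) to PixelShuffle(2) is what makes the new `4**depth` width schedule line up: each 2x shuffle divides channels by exactly 4. A sketch with an assumed hidden size of 768:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 768, 16, 16)  # depth-0 features on the patch grid (hidden=768 assumed)
    x = nn.PixelShuffle(2)(x)        # (1, 192, 32, 32): 768 // 4, matching depth 1
    x = nn.PixelShuffle(2)(x)        # (1, 48, 64, 64):  768 // 16, matching depth 2
    assert x.shape == (1, 48, 64, 64)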
@@ -536,7 +546,7 @@ class UTransformer(nn.Module):
                         "b (h w) d -> b d h w",
                         h=pixel_values.shape[-2]
                         // (self.config.patch_size)
-                        * (4**depth),
+                        * (2**depth),
                     )
                 ),
                 "b d h w -> b (h w) d",
@@ -550,8 +560,8 @@ class UTransformer(nn.Module):
                 (
                     1,
                     1,
-                    pixel_values.shape[-2] * (4 ** (depth + 1)),
-                    pixel_values.shape[-1] * (4 ** (depth + 1)),
+                    pixel_values.shape[-2] * (2 ** (depth + 1)),
+                    pixel_values.shape[-1] * (2 ** (depth + 1)),
                 ),
                 device=x.device,
             ).to(self.embeddings.patch_embeddings.weight.dtype)