diff --git a/main.py b/main.py
index f44cd7d..ab89deb 100644
--- a/main.py
+++ b/main.py
@@ -137,50 +137,51 @@ for epoch in range(start_epoch, total_epoch):
     wandb.log(epoch_metrics)
 
     if (epoch + 1) % 50 == 0:
-        rf.model.eval()
-        psnr_sum = 0
-        ssim_sum = 0
-        lpips_sum = 0
-        count = 0
+        with autocast(dtype=torch.bfloat16):
+            rf.model.eval()
+            psnr_sum = 0
+            ssim_sum = 0
+            lpips_sum = 0
+            count = 0
 
-        with torch.no_grad():
-            for i in tqdm(
-                range(0, len(test_dataset), batch_size),
-                desc=f"Benchmark {epoch + 1}/{total_epoch}",
-            ):
-                batch = test_dataset[i : i + batch_size]
-                images = rf.sample(batch["cloud"].to(device))
-                image = denormalize(images[-1]).clamp(0, 1)
-                original = denormalize(batch["gt"]).clamp(0, 1)
+            with torch.no_grad():
+                for i in tqdm(
+                    range(0, len(test_dataset), batch_size),
+                    desc=f"Benchmark {epoch + 1}/{total_epoch}",
+                ):
+                    batch = test_dataset[i : i + batch_size]
+                    images = rf.sample(batch["cloud"].to(device))
+                    image = denormalize(images[-1]).clamp(0, 1)
+                    original = denormalize(batch["gt"]).clamp(0, 1)
 
-                if i == 0:
-                    for step, demo in enumerate([images[0], images[-1]]):
-                        images = wandb.Image(
-                            make_grid(
-                                denormalize(demo).clamp(0, 1).float()[:4], nrow=2
-                            ),
-                            caption=f"step {step}",
-                        )
-                        wandb.log({"viz/decoded": images})
+                    if i == 0:
+                        for step, demo in enumerate([images[0], images[-1]]):
+                            viz = wandb.Image(  # don't shadow `images` here
+                                make_grid(
+                                    denormalize(demo).clamp(0, 1).float()[:4], nrow=2
+                                ),
+                                caption=f"step {step}",
+                            )
+                            wandb.log({"viz/decoded": viz})
 
-                psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
-                psnr_sum += psnr.sum().item()
-                ssim_sum += ssim.sum().item()
-                lpips_sum += lpips.sum().item()
-                count += image.shape[0]
+                    psnr, ssim, lpips = benchmark(image.float().cpu(), original.float().cpu())
+                    psnr_sum += psnr.sum().item()
+                    ssim_sum += ssim.sum().item()
+                    lpips_sum += lpips.sum().item()
+                    count += image.shape[0]
 
-        avg_psnr = psnr_sum / count
-        avg_ssim = ssim_sum / count
-        avg_lpips = lpips_sum / count
-        wandb.log(
-            {
-                "eval/psnr": avg_psnr,
-                "eval/ssim": avg_ssim,
-                "eval/lpips": avg_lpips,
-                "epoch": epoch + 1,
-            }
-        )
-        rf.model.train()
+            avg_psnr = psnr_sum / count
+            avg_ssim = ssim_sum / count
+            avg_lpips = lpips_sum / count
+            wandb.log(
+                {
+                    "eval/psnr": avg_psnr,
+                    "eval/ssim": avg_ssim,
+                    "eval/lpips": avg_lpips,
+                    "epoch": epoch + 1,
+                }
+            )
+            rf.model.train()
 
     torch.save(
         {
diff --git a/src/model/utransformer.py b/src/model/utransformer.py
index fcfc4a5..3a87acb 100644
--- a/src/model/utransformer.py
+++ b/src/model/utransformer.py
@@ -4,7 +4,6 @@ from typing import Optional
 
 import einops
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from safetensors import safe_open
@@ -162,14 +161,14 @@ class DinoDecoderLayer(DINOv3ViTLayer):
     def __init__(self, config: DINOv3ViTConfig, depth: int):
         super().__init__(config)
 
-        hidden_size = config.hidden_size // (16**depth)
+        hidden_size = config.hidden_size // (4**depth)
         hacky_config = copy.copy(config)
         hacky_config.hidden_size = hidden_size
-        hacky_config.intermediate_size = hacky_config.intermediate_size // (16**depth)
+        hacky_config.intermediate_size = hacky_config.intermediate_size // (3**depth)
 
         self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
         self.attention = PlainAttention(
-            hidden_size, config.num_attention_heads // (4**depth)
+            hidden_size, config.num_attention_heads // (2**depth)
         )  # head scaling law?
         self.layer_scale1 = DINOv3ViTLayerScale(hacky_config)
         self.drop_path = (
@@ -235,7 +234,7 @@ class DinoDecoderLayer(DINOv3ViTLayer):
 class ResidualUpscaler(nn.Module):
     def __init__(
         self, config: DINOv3ViTConfig, depth: list[int], bottleneck_dim: int = 128
-    ):  # max depth 2 (4**2 = 16 = patch size)
+    ):  # max depth 2 (2**2 * 4 = 16 = patch size; decoder's PixelShuffle(4) does the last 4x)
         super().__init__()
 
         def build_encoder(in_dim, num_layers=2):
@@ -285,7 +284,7 @@ class ResidualUpscaler(nn.Module):
         )
         self.v_downsample = nn.ModuleList(
             [
-                nn.Linear(config.hidden_size, config.hidden_size // (16**d))
+                nn.Linear(config.hidden_size, config.hidden_size // (4**d))
                 if d != 0
                 else nn.Identity()
                 for d in depth
@@ -296,7 +295,7 @@ class ResidualUpscaler(nn.Module):
                 CrossAttention(
                     bottleneck_dim,
                     bottleneck_dim,
-                    config.hidden_size // (16**d),
+                    config.hidden_size // (4**d),
                 )
                 if d != 0
                 else nn.Identity()
@@ -310,10 +309,14 @@ class ResidualUpscaler(nn.Module):
 
         self.rope = RoPE(bottleneck_dim)
 
+        # ok, just shuffle it; no, don't
+        # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]
+
     def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
         # residuals[0] => deepest, -1 => shallowest; pixel_values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
         # objective: given, say, a (1024, 1024, 512) residual, build well-shaped multi-head attention queries from it
         assert self.config.patch_size is not None
+
         image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
 
         global_shift, global_scale = self.global_encode(
@@ -341,11 +344,11 @@ class ResidualUpscaler(nn.Module):
             local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
             local_q = self.q_norm[i](
                 einops.rearrange(
-                    F.adaptive_avg_pool2d(
+                    torch.nn.functional.adaptive_avg_pool2d(
                         q,
                         output_size=(
-                            image_h // self.config.patch_size * (4**depth),
-                            image_w // self.config.patch_size * (4**depth),
+                            image_h // self.config.patch_size * (2**depth),
+                            image_w // self.config.patch_size * (2**depth),
                         ),
                     ),
                     "b c h w -> b (h w) c",
@@ -355,7 +358,7 @@ class ResidualUpscaler(nn.Module):
                 (1 + global_scale)
                 * self.k_norm[i](
                     einops.rearrange(
-                        F.adaptive_avg_pool2d(
+                        torch.nn.functional.adaptive_avg_pool2d(
                             k,
                             output_size=(
                                 image_h // self.config.patch_size,
@@ -382,8 +385,9 @@ class DinoV3ViTDecoder(nn.Module):
         self.patch_size = config.patch_size
 
         self.projection = nn.Linear(
-            config.hidden_size // 16 // 16, config.num_channels, bias=True
+            config.hidden_size // 16, config.num_channels * 16, bias=True
         )
+        self.upscale = nn.PixelShuffle(4)
 
         nn.init.zeros_(self.projection.weight)
         nn.init.zeros_(
@@ -391,11 +395,16 @@ class DinoV3ViTDecoder(nn.Module):
         ) if self.projection.bias is not None else None
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
-        return self.projection(
-            einops.rearrange(
-                x, "b (h w) d -> b h w d", h=image_size[0], w=image_size[1]
-            )
-        ).permute(0, 3, 1, 2)
+        return self.upscale(
+            self.projection(
+                einops.rearrange(
+                    x,
+                    "b (h w) d -> b h w d",
+                    h=image_size[0] // 4,
+                    w=image_size[1] // 4,
+                )
+            ).permute(0, 3, 1, 2)
+        )
 
 
 class UTransformer(nn.Module):
@@ -412,14 +421,14 @@ class UTransformer(nn.Module):
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
 
         def gen_rope(depth: int):
-            hidden_size = config.hidden_size // (16**depth)
+            hidden_size = config.hidden_size // (4**depth)
             hacky_config = copy.copy(config)
             hacky_config.hidden_size = hidden_size
             hacky_config.intermediate_size = hacky_config.intermediate_size // (
-                16**depth
+                3**depth
             )
             hacky_config.num_attention_heads = hacky_config.num_attention_heads // (
-                4**depth
+                2**depth
             )
             return DINOv3ViTRopePositionEmbedding(hacky_config)
@@ -431,9 +440,10 @@ class UTransformer(nn.Module):
         )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
+        DEPTH_LAYER = [0, 0, 1, 1, 2, 2]
         self.residual_upscaler = ResidualUpscaler(
             config,
-            [0, 0, 1, 1, 2, 2],  # hardcoded, sorry
+            DEPTH_LAYER,  # hardcoded, sorry
            bottleneck_dim=128,
         )
         self.decoder_layers = nn.ModuleList(
@@ -441,10 +451,10 @@ class UTransformer(nn.Module):
                 nn.ModuleList(
                     [
                         DinoDecoderLayer(config, depth)
-                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                        for _ in range(DEPTH_LAYER.count(depth))
                     ]
                 )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
         self.residual_merger = nn.ModuleList(
@@ -454,22 +464,22 @@ class UTransformer(nn.Module):
                         nn.Sequential(
                             nn.SiLU(),
                             nn.Linear(
-                                config.hidden_size // (16**depth),
-                                2 * config.hidden_size // (16**depth),
+                                config.hidden_size // (4**depth),
+                                2 * config.hidden_size // (4**depth),
                             ),
                         )
-                        for _ in range((config.num_hidden_layers // scale_factor) // 3)
+                        for _ in range(DEPTH_LAYER.count(depth))
                     ]
                 )
-                for depth in range(3)
+                for depth in sorted(set(DEPTH_LAYER), key=DEPTH_LAYER.index)
             ]
         )
-        self.upsample = nn.ModuleList([nn.PixelShuffle(4) for _ in range(2)])
+        self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
         self.rest_decoder = nn.ModuleList(
             [DinoDecoderLayer(config, 2) for _ in range(4)]
         )
         self.decoder_norm = nn.LayerNorm(
-            (config.hidden_size // (16**2)), eps=config.layer_norm_eps
+            (config.hidden_size // (4**2)), eps=config.layer_norm_eps
         )
         self.decoder = DinoV3ViTDecoder(config)
@@ -536,7 +546,7 @@ class UTransformer(nn.Module):
                             "b (h w) d -> b d h w",
                             h=pixel_values.shape[-2]
                             // (self.config.patch_size)
-                            * (4**depth),
+                            * (2**depth),
                         )
                     ),
                     "b d h w -> b (h w) d",
@@ -550,8 +560,8 @@ class UTransformer(nn.Module):
                 (
                     1,
                     1,
-                    pixel_values.shape[-2] * (4 ** (depth + 1)),
-                    pixel_values.shape[-1] * (4 ** (depth + 1)),
+                    pixel_values.shape[-2] * (2 ** (depth + 1)),
+                    pixel_values.shape[-1] * (2 ** (depth + 1)),
                 ),
                 device=x.device,
             ).to(self.embeddings.patch_embeddings.weight.dtype)
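
Note on the main.py hunk: the bare autocast(dtype=torch.bfloat16) call only works if autocast was imported from torch.cuda.amp (now deprecated); that import sits outside this hunk, so this is an assumption. A minimal sketch of the same bf16 eval pattern with the explicit device_type that torch.autocast requires, using a stand-in model and input (the real loop is the eval block above; requires a CUDA device):

import torch

model = torch.nn.Linear(8, 8).cuda().eval()
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    with torch.no_grad():
        y = model(x)  # matmuls run in bf16 under autocast

# upcast before CPU-side metrics: bf16 tensors can trip CPU kernels
# (e.g. LPIPS convolutions), and bf16 rounding would leak into PSNR/SSIM,
# hence the image.float().cpu() before benchmark() in the hunk above
y = y.float().cpu()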
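
Sanity check for the new scaling scheme: each decoder stage now upsamples with PixelShuffle(2) (channels shrink by 4, matching the hidden_size // (4**depth) widths), and DinoV3ViTDecoder finishes with PixelShuffle(4), so 2 * 2 * 4 recovers the 16px DINOv3 patch. A standalone sketch, assuming hidden_size=768, patch_size=16, num_channels=3, and a 256x256 input purely for illustration:

import torch
import torch.nn as nn

hidden_size, patch, num_channels = 768, 16, 3
h = w = 256 // patch  # 16x16 token grid

x = torch.randn(1, hidden_size, h, w)
for depth in range(2):  # the two PixelShuffle(2) stages in self.upsample
    x = nn.PixelShuffle(2)(x)
    assert x.shape[1] == hidden_size // (4 ** (depth + 1))  # channels / 4 per stage

# DinoV3ViTDecoder: Linear(hidden_size // 16 -> num_channels * 16), then PixelShuffle(4);
# a 1x1 conv stands in for the channel-last Linear in this shape check
proj = nn.Conv2d(hidden_size // 16, num_channels * 16, kernel_size=1)
out = nn.PixelShuffle(4)(proj(x))
assert out.shape == (1, num_channels, 256, 256)  # 2 * 2 * 4 == 16 == patch size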