diff --git a/src/model/utransformer.py b/src/model/utransformer.py index 3a87acb..2d37754 100644 --- a/src/model/utransformer.py +++ b/src/model/utransformer.py @@ -313,12 +313,19 @@ class ResidualUpscaler(nn.Module): # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)] def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]): - # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)] + # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, 1 + self.config.num_register_tokens + seq, d), (b, 1 + self.config.num_register_tokens + seq, d)] # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well assert self.config.patch_size is not None image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1] + rest = [ + residual[:, : 1 + self.config.num_register_tokens] for residual in residuals + ] + residuals = [ + residual[:, 1 + self.config.num_register_tokens :] for residual in residuals + ] + global_shift, global_scale = self.global_encode( einops.rearrange( torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)" @@ -336,9 +343,9 @@ class ResidualUpscaler(nn.Module): k = self.k_encoder(image_residual) reformed_residual = [] - for i, (depth, residual) in enumerate(zip(self.depth, residuals)): + for i, (depth, residual, rest) in enumerate(zip(self.depth, residuals, rest)): if depth == 0: - reformed_residual.append(residual) + reformed_residual.append(torch.cat((rest, residual), dim=1)) continue local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1) @@ -372,7 +379,13 @@ class ResidualUpscaler(nn.Module): ) + local_shift local_v = self.v_downsample[i](residual) - reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v)) + local_rest = self.v_downsample[i](rest) + + final_residual = torch.concat( + (local_rest, self.cross_attn[i](local_q, local_k, local_v)), dim=1 + ) + + reformed_residual.append(final_residual) return reformed_residual @@ -475,6 +488,15 @@ class UTransformer(nn.Module): ] ) self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)]) + self.upsample_latent = nn.ModuleList( + [ + nn.Linear( + config.hidden_size // (4**depth), + config.hidden_size // (4 ** (depth + 1)), + ) + for depth in range(2) + ] + ) self.rest_decoder = nn.ModuleList( [DinoDecoderLayer(config, 2) for _ in range(4)] ) @@ -511,7 +533,7 @@ class UTransformer(nn.Module): residual = [] for i, layer_module in enumerate(self.encoder_layers): if i % self.scale_factor == 0: - residual.append(x[:, 1 + self.config.num_register_tokens :]) + residual.append(x) layer_head_mask = head_mask[i] if head_mask is not None else None x = layer_module( x, @@ -521,7 +543,6 @@ class UTransformer(nn.Module): ) x = self.encoder_norm(x) - x = x[:, 1 + self.config.num_register_tokens :] reversed_residual = self.residual_upscaler(pixel_values, residual[::-1]) residual_idx = 0 @@ -538,21 +559,29 @@ class UTransformer(nn.Module): ).chunk(2, dim=-1) x = x * (1 + scale) + shift residual_idx += 1 - x = ( - rearrange( - self.upsample[depth]( - rearrange( - x, - "b (h w) d -> b d h w", - h=pixel_values.shape[-2] - // (self.config.patch_size) - * (2**depth), - ) - ), - "b d h w -> b (h w) d", - ) - if depth != 2 - else x + x = torch.cat( + ( + rearrange( + self.upsample[depth]( + rearrange( + x[:, 1 + self.config.num_register_tokens :], + "b (h w) d -> b d h w", + h=pixel_values.shape[-2] + // (self.config.patch_size) + * (2**depth), + ) + ), + "b d h w -> b (h w) d", + ) + if depth != 2 + else x[:, 1 + self.config.num_register_tokens :], + self.upsample_latent[depth]( + x[:, : 1 + self.config.num_register_tokens] + ) + if depth != 2 + else x[:, : 1 + self.config.num_register_tokens], + ), + dim=1, ) position_embeddings = ( self.decode_ropes[depth]( @@ -577,7 +606,7 @@ class UTransformer(nn.Module): attention_mask=None, position_embeddings=position_embeddings, ) - + x = x[:, 1 + self.config.num_register_tokens :] x = self.decoder_norm(x) return self.decoder(x, image_size=pixel_values.shape[-2:]), residual