try things

neulus
2025-10-26 20:49:00 +09:00
parent 09c1c2220a
commit f51313a56c


@@ -313,12 +313,19 @@ class ResidualUpscaler(nn.Module):
         # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]

     def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
-        # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
+        # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, 1 + self.config.num_register_tokens + seq, d), (b, 1 + self.config.num_register_tokens + seq, d)]
         # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
         assert self.config.patch_size is not None
         image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
+        rest = [
+            residual[:, : 1 + self.config.num_register_tokens] for residual in residuals
+        ]
+        residuals = [
+            residual[:, 1 + self.config.num_register_tokens :] for residual in residuals
+        ]
         global_shift, global_scale = self.global_encode(
             einops.rearrange(
                 torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
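
Note: the upscaler now receives the full token sequence, so its first step is to peel the CLS + register block (rest) off the patch tokens before any spatial reasoning. A minimal sketch of that split, with an assumed register count and made-up shapes:

    import torch

    num_register_tokens = 4                                        # assumed, for illustration
    residual = torch.randn(2, 1 + num_register_tokens + 256, 512)  # (b, 1 + R + seq, d)

    rest = residual[:, : 1 + num_register_tokens]      # CLS + registers -> (2, 5, 512)
    patches = residual[:, 1 + num_register_tokens :]   # patch tokens    -> (2, 256, 512)
    assert rest.shape[1] + patches.shape[1] == residual.shape[1]
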
@@ -336,9 +343,9 @@ class ResidualUpscaler(nn.Module):
         k = self.k_encoder(image_residual)

         reformed_residual = []
-        for i, (depth, residual) in enumerate(zip(self.depth, residuals)):
+        for i, (depth, residual, rest) in enumerate(zip(self.depth, residuals, rest)):
             if depth == 0:
-                reformed_residual.append(residual)
+                reformed_residual.append(torch.cat((rest, residual), dim=1))
                 continue
             local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
@@ -372,7 +379,13 @@ class ResidualUpscaler(nn.Module):
             ) + local_shift
             local_v = self.v_downsample[i](residual)
-            reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v))
+            local_rest = self.v_downsample[i](rest)
+            final_residual = torch.concat(
+                (local_rest, self.cross_attn[i](local_q, local_k, local_v)), dim=1
+            )
+            reformed_residual.append(final_residual)

         return reformed_residual
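
Note: at depth 0 the rest block is concatenated back unchanged; at deeper depths it goes through the same v_downsample projection as the values so its width matches the cross-attended patch tokens. A sketch of the re-attachment, assuming v_downsample acts like a linear projection over the last dim (the 512 -> 128 sizes are made up):

    import torch
    import torch.nn as nn

    v_downsample = nn.Linear(512, 128)   # stand-in for self.v_downsample[i]
    rest = torch.randn(2, 5, 512)        # CLS + 4 registers at encoder width
    attended = torch.randn(2, 256, 128)  # stand-in for self.cross_attn[i](q, k, v)

    local_rest = v_downsample(rest)                            # (2, 5, 128)
    final_residual = torch.cat((local_rest, attended), dim=1)  # (2, 261, 128), rest first
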
@@ -475,6 +488,15 @@ class UTransformer(nn.Module):
             ]
         )
         self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
+        self.upsample_latent = nn.ModuleList(
+            [
+                nn.Linear(
+                    config.hidden_size // (4**depth),
+                    config.hidden_size // (4 ** (depth + 1)),
+                )
+                for depth in range(2)
+            ]
+        )
         self.rest_decoder = nn.ModuleList(
             [DinoDecoderLayer(config, 2) for _ in range(4)]
         )
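
Note on the 4**depth factor: nn.PixelShuffle(2) maps (b, d, h, w) to (b, d // 4, 2h, 2w), so the patch tokens lose 4x channel width per decode depth; the new Linear layers shrink the register latents by the same factor so both halves of the later concatenation keep equal width. A sketch with an assumed hidden_size:

    import torch
    import torch.nn as nn

    hidden_size = 1024                               # assumed for illustration
    shuffle = nn.PixelShuffle(2)                     # as in self.upsample
    proj = nn.Linear(hidden_size, hidden_size // 4)  # as in self.upsample_latent[0]

    patches = torch.randn(2, hidden_size, 16, 16)    # (b, d, h, w)
    registers = torch.randn(2, 5, hidden_size)       # (b, 1 + R, d)

    up = shuffle(patches)   # (2, 256, 32, 32): channels / 4, spatial x 2
    lat = proj(registers)   # (2, 5, 256): width tracks the shuffled channels
    assert up.shape[1] == lat.shape[-1]
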
@@ -511,7 +533,7 @@ class UTransformer(nn.Module):
         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
             if i % self.scale_factor == 0:
-                residual.append(x[:, 1 + self.config.num_register_tokens :])
+                residual.append(x)
             layer_head_mask = head_mask[i] if head_mask is not None else None
             x = layer_module(
                 x,
@@ -521,7 +543,6 @@ class UTransformer(nn.Module):
             )
         x = self.encoder_norm(x)
-        x = x[:, 1 + self.config.num_register_tokens :]

         reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])
         residual_idx = 0
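
Note: these two removals mean the encoder now caches the full token sequence (CLS + registers + patches) at every scale_factor-th layer, deferring the split to ResidualUpscaler.forward. Shape-wise, with an assumed R = 4:

    import torch

    R = 4
    x = torch.randn(2, 1 + R + 256, 768)   # hypothetical encoder state

    old_cached = x[:, 1 + R :]   # before: (2, 256, 768), patches only
    new_cached = x               # after:  (2, 261, 768), full sequence
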
@@ -538,21 +559,29 @@ class UTransformer(nn.Module):
                 ).chunk(2, dim=-1)
                 x = x * (1 + scale) + shift
                 residual_idx += 1
-            x = (
-                rearrange(
-                    self.upsample[depth](
-                        rearrange(
-                            x,
-                            "b (h w) d -> b d h w",
-                            h=pixel_values.shape[-2]
-                            // (self.config.patch_size)
-                            * (2**depth),
-                        )
-                    ),
-                    "b d h w -> b (h w) d",
-                )
-                if depth != 2
-                else x
-            )
+            x = torch.cat(
+                (
+                    rearrange(
+                        self.upsample[depth](
+                            rearrange(
+                                x[:, 1 + self.config.num_register_tokens :],
+                                "b (h w) d -> b d h w",
+                                h=pixel_values.shape[-2]
+                                // (self.config.patch_size)
+                                * (2**depth),
+                            )
+                        ),
+                        "b d h w -> b (h w) d",
+                    )
+                    if depth != 2
+                    else x[:, 1 + self.config.num_register_tokens :],
+                    self.upsample_latent[depth](
+                        x[:, : 1 + self.config.num_register_tokens]
+                    )
+                    if depth != 2
+                    else x[:, : 1 + self.config.num_register_tokens],
+                ),
+                dim=1,
+            )
             position_embeddings = (
                 self.decode_ropes[depth](
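
Note: each decode step now pixel-shuffles only the patch tokens while the CLS + register block rides along through upsample_latent (both parts pass through unchanged at depth 2, where no further upsampling happens). As written, the cat places the spatial tokens first, so the leading block ends up at the tail after this step. A miniature of one depth-0 step, with assumed dims:

    import torch
    import torch.nn as nn
    from einops import rearrange

    n_rest, d, h = 5, 1024, 16              # 1 CLS + 4 registers; dims assumed
    x = torch.randn(2, n_rest + h * h, d)   # leading tokens first, as produced upstream

    upsample = nn.PixelShuffle(2)           # self.upsample[0]
    upsample_latent = nn.Linear(d, d // 4)  # self.upsample_latent[0]

    spatial = rearrange(x[:, n_rest:], "b (h w) d -> b d h w", h=h)  # (2, 1024, 16, 16)
    spatial = rearrange(upsample(spatial), "b d h w -> b (h w) d")   # (2, 1024, 256)
    latent = upsample_latent(x[:, :n_rest])                          # (2, 5, 256)
    x = torch.cat((spatial, latent), dim=1)  # (2, 1029, 256): leading tokens now trail
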
@@ -577,7 +606,7 @@ class UTransformer(nn.Module):
                 attention_mask=None,
                 position_embeddings=position_embeddings,
             )
+        x = x[:, 1 + self.config.num_register_tokens :]
         x = self.decoder_norm(x)
         return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
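
Note: with the registers kept through the decoder stack, the new slice drops the 1 + num_register_tokens leading tokens just before the final norm and pixel decoder, which consume only spatial tokens; the slice assumes the non-spatial block sits at the front at this point. A minimal sketch with assumed sizes:

    import torch
    import torch.nn as nn

    R, d = 4, 64                      # assumed register count and width
    x = torch.randn(2, (1 + R) + 1024, d)
    decoder_norm = nn.LayerNorm(d)    # stand-in for self.decoder_norm

    x = x[:, 1 + R :]     # drop the 1 + R leading tokens
    x = decoder_norm(x)   # (2, 1024, 64): spatial tokens only, for the pixel decoder
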