try things
This commit is contained in:
@@ -313,12 +313,19 @@ class ResidualUpscaler(nn.Module):
|
||||
# self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
|
||||
# residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
|
||||
# residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, 1 + self.config.num_register_tokens + seq, d), (b, 1 + self.config.num_register_tokens + seq, d)]
|
||||
# objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
|
||||
assert self.config.patch_size is not None
|
||||
|
||||
image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
|
||||
|
||||
rest = [
|
||||
residual[:, : 1 + self.config.num_register_tokens] for residual in residuals
|
||||
]
|
||||
residuals = [
|
||||
residual[:, 1 + self.config.num_register_tokens :] for residual in residuals
|
||||
]
|
||||
|
||||
global_shift, global_scale = self.global_encode(
|
||||
einops.rearrange(
|
||||
torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
|
||||
@@ -336,9 +343,9 @@ class ResidualUpscaler(nn.Module):
|
||||
k = self.k_encoder(image_residual)
|
||||
|
||||
reformed_residual = []
|
||||
for i, (depth, residual) in enumerate(zip(self.depth, residuals)):
|
||||
for i, (depth, residual, rest) in enumerate(zip(self.depth, residuals, rest)):
|
||||
if depth == 0:
|
||||
reformed_residual.append(residual)
|
||||
reformed_residual.append(torch.cat((rest, residual), dim=1))
|
||||
continue
|
||||
|
||||
local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
|
||||
@@ -372,7 +379,13 @@ class ResidualUpscaler(nn.Module):
|
||||
) + local_shift
|
||||
local_v = self.v_downsample[i](residual)
|
||||
|
||||
reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v))
|
||||
local_rest = self.v_downsample[i](rest)
|
||||
|
||||
final_residual = torch.concat(
|
||||
(local_rest, self.cross_attn[i](local_q, local_k, local_v)), dim=1
|
||||
)
|
||||
|
||||
reformed_residual.append(final_residual)
|
||||
|
||||
return reformed_residual
|
||||
|
||||
@@ -475,6 +488,15 @@ class UTransformer(nn.Module):
|
||||
]
|
||||
)
|
||||
self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
|
||||
self.upsample_latent = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(
|
||||
config.hidden_size // (4**depth),
|
||||
config.hidden_size // (4 ** (depth + 1)),
|
||||
)
|
||||
for depth in range(2)
|
||||
]
|
||||
)
|
||||
self.rest_decoder = nn.ModuleList(
|
||||
[DinoDecoderLayer(config, 2) for _ in range(4)]
|
||||
)
|
||||
@@ -511,7 +533,7 @@ class UTransformer(nn.Module):
|
||||
residual = []
|
||||
for i, layer_module in enumerate(self.encoder_layers):
|
||||
if i % self.scale_factor == 0:
|
||||
residual.append(x[:, 1 + self.config.num_register_tokens :])
|
||||
residual.append(x)
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
@@ -521,7 +543,6 @@ class UTransformer(nn.Module):
|
||||
)
|
||||
|
||||
x = self.encoder_norm(x)
|
||||
x = x[:, 1 + self.config.num_register_tokens :]
|
||||
reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])
|
||||
|
||||
residual_idx = 0
|
||||
@@ -538,21 +559,29 @@ class UTransformer(nn.Module):
|
||||
).chunk(2, dim=-1)
|
||||
x = x * (1 + scale) + shift
|
||||
residual_idx += 1
|
||||
x = (
|
||||
rearrange(
|
||||
self.upsample[depth](
|
||||
rearrange(
|
||||
x,
|
||||
"b (h w) d -> b d h w",
|
||||
h=pixel_values.shape[-2]
|
||||
// (self.config.patch_size)
|
||||
* (2**depth),
|
||||
)
|
||||
),
|
||||
"b d h w -> b (h w) d",
|
||||
)
|
||||
if depth != 2
|
||||
else x
|
||||
x = torch.cat(
|
||||
(
|
||||
rearrange(
|
||||
self.upsample[depth](
|
||||
rearrange(
|
||||
x[:, 1 + self.config.num_register_tokens :],
|
||||
"b (h w) d -> b d h w",
|
||||
h=pixel_values.shape[-2]
|
||||
// (self.config.patch_size)
|
||||
* (2**depth),
|
||||
)
|
||||
),
|
||||
"b d h w -> b (h w) d",
|
||||
)
|
||||
if depth != 2
|
||||
else x[:, 1 + self.config.num_register_tokens :],
|
||||
self.upsample_latent[depth](
|
||||
x[:, : 1 + self.config.num_register_tokens]
|
||||
)
|
||||
if depth != 2
|
||||
else x[:, : 1 + self.config.num_register_tokens],
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
position_embeddings = (
|
||||
self.decode_ropes[depth](
|
||||
@@ -577,7 +606,7 @@ class UTransformer(nn.Module):
|
||||
attention_mask=None,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
x = x[:, 1 + self.config.num_register_tokens :]
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
|
||||
|
||||
Reference in New Issue
Block a user