try things
@@ -313,12 +313,19 @@ class ResidualUpscaler(nn.Module):
     # self.pixel_shuffle = [nn.PixelShuffle(2), nn.PixelShuffle(4)]
 
     def forward(self, pixel_values: torch.Tensor, residuals: list[torch.Tensor]):
-        # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, seq, d), (b, seq, d)]
+        # residual[0] => deepest, -1 => shallowest; pixel values (b, 3, h, w) / residuals [(b, 1 + self.config.num_register_tokens + seq, d), (b, 1 + self.config.num_register_tokens + seq, d)]
         # objective: say we have (1024, 1024, 512) residual. we want to make multi head attention query well
         assert self.config.patch_size is not None
 
         image_h, image_w = pixel_values.shape[-2], pixel_values.shape[-1]
 
+        rest = [
+            residual[:, : 1 + self.config.num_register_tokens] for residual in residuals
+        ]
+        residuals = [
+            residual[:, 1 + self.config.num_register_tokens :] for residual in residuals
+        ]
+
         global_shift, global_scale = self.global_encode(
             einops.rearrange(
                 torch.stack(residuals, dim=1), "b depth s h -> b s (depth h)"
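Note: the new `rest` / `residuals` split keeps the non-spatial prefix (the CLS token plus the register tokens) out of everything downstream that assumes an (h, w) patch grid. A minimal standalone sketch of that layout, with illustrative sizes rather than values from the real config:

    import torch

    num_register_tokens = 4                        # illustrative, not the real config value
    prefix = 1 + num_register_tokens               # CLS + registers: no spatial layout
    tokens = torch.randn(2, prefix + 16 * 16, 32)  # (b, prefix + seq, d)

    rest = tokens[:, :prefix]     # (2, 5, 32)   set aside, never reshaped to a grid
    patches = tokens[:, prefix:]  # (2, 256, 32) safe to rearrange to (b, d, 16, 16)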
@@ -336,9 +343,9 @@ class ResidualUpscaler(nn.Module):
         k = self.k_encoder(image_residual)
 
         reformed_residual = []
-        for i, (depth, residual) in enumerate(zip(self.depth, residuals)):
+        for i, (depth, residual, rest) in enumerate(zip(self.depth, residuals, rest)):
             if depth == 0:
-                reformed_residual.append(residual)
+                reformed_residual.append(torch.cat((rest, residual), dim=1))
                 continue
 
             local_shift, local_scale = self.local_encode[i](residual).chunk(2, dim=-1)
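Note: the loop header rebinds `rest` as a per-iteration tuple target, shadowing the list built above. This is safe because `zip` captures its arguments once, before the first rebinding; a minimal demonstration:

    xs = ["a", "b", "c"]
    for i, (n, xs) in enumerate(zip(range(3), xs)):  # rebinds xs every pass
        print(i, n, xs)  # 0 0 a / 1 1 b / 2 2 c -- zip already holds the original list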
@@ -372,7 +379,13 @@ class ResidualUpscaler(nn.Module):
             ) + local_shift
             local_v = self.v_downsample[i](residual)
 
-            reformed_residual.append(self.cross_attn[i](local_q, local_k, local_v))
+            local_rest = self.v_downsample[i](rest)
+
+            final_residual = torch.concat(
+                (local_rest, self.cross_attn[i](local_q, local_k, local_v)), dim=1
+            )
+
+            reformed_residual.append(final_residual)
 
         return reformed_residual
 
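Note: `rest` goes through the same `v_downsample[i]` projection as the attention values, so its feature width matches the cross-attention output before the two are concatenated back along the sequence axis. A shape-only sketch (the Linear here is a stand-in for the real module, and the sizes are illustrative):

    import torch
    import torch.nn as nn

    d_in, d_out = 32, 16
    v_downsample = nn.Linear(d_in, d_out)  # stand-in for self.v_downsample[i]

    rest = torch.randn(2, 5, d_in)         # CLS + register tokens
    attended = torch.randn(2, 256, d_out)  # stand-in for the cross_attn output
    final_residual = torch.cat((v_downsample(rest), attended), dim=1)  # (2, 261, 16)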
@@ -475,6 +488,15 @@ class UTransformer(nn.Module):
             ]
         )
         self.upsample = nn.ModuleList([nn.PixelShuffle(2) for _ in range(2)])
+        self.upsample_latent = nn.ModuleList(
+            [
+                nn.Linear(
+                    config.hidden_size // (4**depth),
+                    config.hidden_size // (4 ** (depth + 1)),
+                )
+                for depth in range(2)
+            ]
+        )
         self.rest_decoder = nn.ModuleList(
             [DinoDecoderLayer(config, 2) for _ in range(4)]
         )
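Note: the `upsample_latent` widths mirror what `nn.PixelShuffle(2)` does to the grid tokens: a 2x spatial upscale costs a 4x channel reduction, so the prefix tokens (which have no grid to shuffle) get the equivalent d -> d/4 step from a plain Linear. A quick check of that equivalence, with an illustrative hidden size:

    import torch
    import torch.nn as nn

    hidden, depth = 64, 0                         # illustrative sizes
    grid = torch.randn(2, hidden, 8, 8)           # (b, d, h, w) patch tokens as a grid
    print(nn.PixelShuffle(2)(grid).shape)         # torch.Size([2, 16, 16, 16]): d/4, 2h, 2w

    proj = nn.Linear(hidden // (4**depth), hidden // (4 ** (depth + 1)))
    print(proj(torch.randn(2, 5, hidden)).shape)  # torch.Size([2, 5, 16]): same d/4 step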
@@ -511,7 +533,7 @@ class UTransformer(nn.Module):
         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
             if i % self.scale_factor == 0:
-                residual.append(x[:, 1 + self.config.num_register_tokens :])
+                residual.append(x)
             layer_head_mask = head_mask[i] if head_mask is not None else None
             x = layer_module(
                 x,
@@ -521,7 +543,6 @@ class UTransformer(nn.Module):
             )
 
         x = self.encoder_norm(x)
-        x = x[:, 1 + self.config.num_register_tokens :]
         reversed_residual = self.residual_upscaler(pixel_values, residual[::-1])
 
         residual_idx = 0
@@ -538,21 +559,29 @@ class UTransformer(nn.Module):
                 ).chunk(2, dim=-1)
                 x = x * (1 + scale) + shift
                 residual_idx += 1
-            x = (
-                rearrange(
-                    self.upsample[depth](
-                        rearrange(
-                            x,
-                            "b (h w) d -> b d h w",
-                            h=pixel_values.shape[-2]
-                            // (self.config.patch_size)
-                            * (2**depth),
-                        )
-                    ),
-                    "b d h w -> b (h w) d",
-                )
-                if depth != 2
-                else x
+            x = torch.cat(
+                (
+                    rearrange(
+                        self.upsample[depth](
+                            rearrange(
+                                x[:, 1 + self.config.num_register_tokens :],
+                                "b (h w) d -> b d h w",
+                                h=pixel_values.shape[-2]
+                                // (self.config.patch_size)
+                                * (2**depth),
+                            )
+                        ),
+                        "b d h w -> b (h w) d",
+                    )
+                    if depth != 2
+                    else x[:, 1 + self.config.num_register_tokens :],
+                    self.upsample_latent[depth](
+                        x[:, : 1 + self.config.num_register_tokens]
+                    )
+                    if depth != 2
+                    else x[:, : 1 + self.config.num_register_tokens],
+                ),
+                dim=1,
             )
             position_embeddings = (
                 self.decode_ropes[depth](
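Note: each decode step now splits the sequence, upsamples the two halves by different means, and re-joins them: the patch tokens take the rearrange -> PixelShuffle -> rearrange path (4x more tokens, 4x narrower), the prefix takes `upsample_latent[depth]`, and per the cat order above the shuffled patches land first with the prefix behind them. A self-contained sketch of one such step, with illustrative sizes and an inline Linear standing in for the real module:

    import torch
    import torch.nn as nn
    from einops import rearrange

    prefix, h, w, d = 5, 8, 8, 64                # illustrative sizes
    x = torch.randn(2, prefix + h * w, d)        # prefix tokens in front going in

    grid = rearrange(x[:, prefix:], "b (h w) d -> b d h w", h=h)
    up = rearrange(nn.PixelShuffle(2)(grid), "b d h w -> b (h w) d")  # (2, 256, 16)
    lat = nn.Linear(d, d // 4)(x[:, :prefix])                         # (2, 5, 16)
    x = torch.cat((up, lat), dim=1)  # (2, 261, 16): patches first, then prefix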
@@ -577,7 +606,7 @@ class UTransformer(nn.Module):
                     attention_mask=None,
                     position_embeddings=position_embeddings,
                 )
-
+            x = x[:, 1 + self.config.num_register_tokens :]
         x = self.decoder_norm(x)
 
         return self.decoder(x, image_size=pixel_values.shape[-2:]), residual