another approach for final decoder

neulus
2025-09-29 23:20:27 +09:00
parent 12a165e461
commit 0ccf1ff42d
14 changed files with 88 additions and 24 deletions


@@ -137,40 +137,103 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         return hidden_states
+
+
+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+#         self.projection = nn.Linear(
+#             config.hidden_size,
+#             self.num_channels_out * config.patch_size * config.patch_size,
+#             bias=True,
+#         )
+#
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+#         num_special_tokens = 1 + self.config.num_register_tokens
+#         patch_tokens = x[:, num_special_tokens:, :]
+#         projected_tokens = self.projection(patch_tokens)
+#         p = self.config.patch_size
+#         c = self.num_channels_out
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+#         assert patch_tokens.shape[1] == h_grid * w_grid, (
+#             "Number of patches does not match image size."
+#         )
+#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
+#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
+#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+#         return reconstructed_image
+# let's try a conv decoder
 class DinoV3ViTDecoder(nn.Module):
     def __init__(self, config: DINOv3ViTConfig):
         super().__init__()
         self.config = config
         self.num_channels_out = config.num_channels
+        hidden_dim = config.hidden_size
+        patch_size = config.patch_size
-        self.projection = nn.Linear(
-            config.hidden_size,
-            self.num_channels_out * config.patch_size * config.patch_size,
-            bias=True,
-        )
+        self.projection = nn.Linear(hidden_dim, hidden_dim)
+        # the two bilinear stages upsample by 2 * final_upsample in total,
+        # which must equal patch_size so the output matches the input size
+        if patch_size == 14:
+            final_upsample = 7
+        elif patch_size == 16:
+            final_upsample = 8
+        elif patch_size == 8:
+            final_upsample = 4
+        else:
+            raise ValueError(f"unsupported patch_size: {patch_size}")
+        self.decoder = nn.Sequential(
+            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Upsample(
+                scale_factor=final_upsample, mode="bilinear", align_corners=False
+            ),
+            nn.Conv2d(128, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
+        )
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
-        num_special_tokens = 1 + self.config.num_register_tokens
-        patch_tokens = x[:, num_special_tokens:, :]
+        # skip the CLS token and register tokens; keep only patch tokens
+        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
         projected_tokens = self.projection(patch_tokens)
         p = self.config.patch_size
         c = self.num_channels_out
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
-        assert patch_tokens.shape[1] == h_grid * w_grid, (
-            "Number of patches does not match image size."
-        )
+        assert patch_tokens.shape[1] == h_grid * w_grid
+        # (batch, tokens, hidden) -> (batch, h_grid, w_grid, hidden)
+        x_spatial = projected_tokens.reshape(
+            batch_size, h_grid, w_grid, self.config.hidden_size
+        )
-        x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
-        x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
-        reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+        # -> (batch, hidden, h_grid, w_grid) for the conv stack
+        x_spatial = x_spatial.permute(0, 3, 1, 2)
+        reconstructed_image = self.decoder(x_spatial)
         return reconstructed_image
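
As a quick sanity check of the new decoder, here is a minimal sketch (not part of this commit; DummyConfig and the 384/16/4/3 values are illustrative stand-ins for a real DINOv3ViTConfig, and DinoV3ViTDecoder is assumed importable from the module above) verifying that the conv path brings the token grid back to the input resolution:

    import torch

    class DummyConfig:  # hypothetical stand-in for DINOv3ViTConfig
        hidden_size = 384
        patch_size = 16
        num_register_tokens = 4
        num_channels = 3

    config = DummyConfig()
    decoder = DinoV3ViTDecoder(config)  # class from the diff above

    batch, height, width = 2, 224, 224
    num_patches = (height // config.patch_size) * (width // config.patch_size)  # 196
    tokens = torch.randn(
        batch, 1 + config.num_register_tokens + num_patches, config.hidden_size
    )

    out = decoder(tokens, (height, width))
    # patch_size == 16 selects final_upsample = 8, so total upsampling is
    # 2 * 8 = 16 = patch_size and the output matches the input resolution
    assert out.shape == (batch, config.num_channels, height, width)

Compared with the commented-out linear unpatchify (one Linear mapping each token straight to a patch_size x patch_size pixel block), the conv path shares weights across spatial positions and lets the bilinear upsampling plus 3x3 convolutions smooth patch boundaries; the trade-off is that the intermediate channel widths (256 -> 128 -> 64 -> 32) are fixed regardless of hidden_size.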