commit c47d91a349 (parent 6bb6c09638)
neulus · 2025-10-10 15:55:35 +09:00
10 changed files with 1381 additions and 112 deletions

@@ -188,30 +188,83 @@ class DinoV3ViTDecoder(nn.Module):
         self.projection = nn.Linear(
             config.hidden_size,
-            self.num_channels_out * (self.patch_size**2),
+            config.num_channels * (self.patch_size**2),
             bias=True,
         )
         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
         nn.init.zeros_(self.projection.weight)
-        nn.init.zeros_(self.projection.bias)
+        nn.init.zeros_(
+            self.projection.bias
+        ) if self.projection.bias is not None else None

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
         x = x[:, 1 + self.config.num_register_tokens :, :]
-        x = self.projection(x)
         p = self.config.patch_size
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
         assert x.shape[1] == h_grid * w_grid
+        x = self.projection(x)
         x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
-        return self.pixel_shuffle(x)
+        x = self.pixel_shuffle(x)
+        return x
+# how about transposed conv decoder
+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+#         self.patch_size = config.patch_size
+#         intermediate_channels = config.hidden_size // 4
+#         self.decoder_block = nn.Sequential(
+#             nn.ConvTranspose2d(
+#                 in_channels=config.hidden_size,
+#                 out_channels=intermediate_channels,
+#                 kernel_size=self.patch_size,
+#                 stride=self.patch_size,
+#                 bias=True,
+#             ),
+#             nn.LayerNorm(intermediate_channels),
+#             nn.GELU(),
+#             nn.Conv2d(
+#                 in_channels=intermediate_channels,
+#                 out_channels=config.num_channels,
+#                 kernel_size=1,
+#                 bias=True,
+#             ),
+#         )
+#         nn.init.zeros_(self.decoder_block[-1].weight)
+#         nn.init.zeros_(self.decoder_block[-1].bias)
+#
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+#         x = x[:, 1 + self.config.num_register_tokens :, :]
+#         p = self.config.patch_size
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+#         assert x.shape[1] == h_grid * w_grid
+#         x = x.transpose(1, 2).reshape(
+#             batch_size, self.config.hidden_size, h_grid, w_grid
+#         )
+#         x = self.decoder_block(x)
+#         return x

 class UTransformer(nn.Module):
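Note: to make the shape flow of the pixel-shuffle decode path concrete, here is a minimal standalone sketch. The sizes (hidden_size=384, patch_size=16, num_channels=3, a 14×14 token grid) are hypothetical stand-ins, not values read from the config. Zero-initializing the projection, as the hunk does, makes the decoder output zeros at initialization, a common trick for a stable training start.

import torch
import torch.nn as nn

hidden_size, patch_size, channels = 384, 16, 3            # hypothetical sizes
h_grid = w_grid = 14                                      # 224 // 16

projection = nn.Linear(hidden_size, channels * patch_size**2)
pixel_shuffle = nn.PixelShuffle(patch_size)

tokens = torch.randn(2, h_grid * w_grid, hidden_size)     # (B, N, D) patch tokens
x = projection(tokens)                                    # (B, 196, 3 * 16**2) = (B, 196, 768)
x = x.reshape(2, h_grid, w_grid, -1).permute(0, 3, 1, 2)  # (B, 768, 14, 14), channels-first
out = pixel_shuffle(x)                                    # (B, C*r^2, H, W) -> (B, C, H*r, W*r)
print(out.shape)                                          # torch.Size([2, 3, 224, 224])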
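The commented-out ConvTranspose2d alternative has a latent shape bug worth flagging if it is ever revived: nn.LayerNorm(intermediate_channels) applied directly to an NCHW tensor normalizes over the trailing width dimension, not channels, and raises unless W happens to equal intermediate_channels. One fix is a channels-last wrapper, sketched below (not part of the commit; nn.GroupNorm(1, C) would also work as a drop-in):

import torch
import torch.nn as nn

class ChannelsLastLayerNorm(nn.Module):
    """LayerNorm over the channel dim of an NCHW tensor."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC, normalize over C, then back to NCHW
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

x = torch.randn(2, 96, 14, 14)
print(ChannelsLastLayerNorm(96)(x).shape)  # torch.Size([2, 96, 14, 14])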
@@ -261,6 +314,9 @@ class UTransformer(nn.Module):
         bool_masked_pos: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
     ):
+        if time.dim() == 0:
+            time = time.repeat(pixel_values.shape[0])
         pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
         position_embeddings = self.rope_embeddings(pixel_values)
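The UTransformer hunk lets callers pass a single 0-dim timestep tensor, which is broadcast to one entry per batch element before use. A minimal repro of the broadcast (tensor values are illustrative):

import torch

pixel_values = torch.randn(4, 3, 224, 224)
time = torch.tensor(0.5)                  # 0-dim scalar, as a sampler might pass it
if time.dim() == 0:
    time = time.repeat(pixel_values.shape[0])
print(time.shape)                         # torch.Size([4]): one timestep per sample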