things
This commit is contained in:
@@ -188,30 +188,83 @@ class DinoV3ViTDecoder(nn.Module):
|
||||
|
||||
self.projection = nn.Linear(
|
||||
config.hidden_size,
|
||||
self.num_channels_out * (self.patch_size**2),
|
||||
config.num_channels * (self.patch_size**2),
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
|
||||
|
||||
nn.init.zeros_(self.projection.weight)
|
||||
nn.init.zeros_(self.projection.bias)
|
||||
nn.init.zeros_(
|
||||
self.projection.bias
|
||||
) if self.projection.bias is not None else None
|
||||
|
||||
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
x = x[:, 1 + self.config.num_register_tokens :, :]
|
||||
|
||||
x = self.projection(x)
|
||||
|
||||
p = self.config.patch_size
|
||||
h_grid = image_size[0] // p
|
||||
w_grid = image_size[1] // p
|
||||
|
||||
assert x.shape[1] == h_grid * w_grid
|
||||
|
||||
x = self.projection(x)
|
||||
|
||||
x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
|
||||
|
||||
return self.pixel_shuffle(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# How about a transposed-conv decoder instead? (alternative DinoV3ViTDecoder, kept commented out below)
|
||||
# class DinoV3ViTDecoder(nn.Module):
|
||||
# def __init__(self, config: DINOv3ViTConfig):
|
||||
# super().__init__()
|
||||
# self.config = config
|
||||
# self.num_channels_out = config.num_channels
|
||||
# self.patch_size = config.patch_size
|
||||
|
||||
# intermediate_channels = config.hidden_size // 4
|
||||
|
||||
# self.decoder_block = nn.Sequential(
|
||||
# nn.ConvTranspose2d(
|
||||
# in_channels=config.hidden_size,
|
||||
# out_channels=intermediate_channels,
|
||||
# kernel_size=self.patch_size,
|
||||
# stride=self.patch_size,
|
||||
# bias=True,
|
||||
# ),
|
||||
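# nn.LayerNorm(intermediate_channels),
# NOTE(review): nn.LayerNorm normalizes over the trailing dimension(s); the
# ConvTranspose2d above emits (batch, channels, height, width), so this would
# act on the width axis, not channels. Presumably nn.GroupNorm(1, intermediate_channels)
# or a channels-last permute was intended - confirm before enabling this decoder.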
|
||||
# nn.GELU(),
|
||||
# nn.Conv2d(
|
||||
# in_channels=intermediate_channels,
|
||||
# out_channels=config.num_channels,
|
||||
# kernel_size=1,
|
||||
# bias=True,
|
||||
# ),
|
||||
# )
|
||||
|
||||
# nn.init.zeros_(self.decoder_block[-1].weight)
|
||||
# nn.init.zeros_(self.decoder_block[-1].bias)
|
||||
|
||||
# def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
||||
# batch_size = x.shape[0]
|
||||
|
||||
# x = x[:, 1 + self.config.num_register_tokens :, :]
|
||||
|
||||
# p = self.config.patch_size
|
||||
# h_grid = image_size[0] // p
|
||||
# w_grid = image_size[1] // p
|
||||
# assert x.shape[1] == h_grid * w_grid
|
||||
|
||||
# x = x.transpose(1, 2).reshape(
|
||||
# batch_size, self.config.hidden_size, h_grid, w_grid
|
||||
# )
|
||||
# x = self.decoder_block(x)
|
||||
|
||||
# return x
|
||||
|
||||
|
||||
class UTransformer(nn.Module):
|
||||
@@ -261,6 +314,9 @@ class UTransformer(nn.Module):
|
||||
bool_masked_pos: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if time.dim() == 0:
|
||||
time = time.repeat(pixel_values.shape[0])
|
||||
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user