pixel shuffle: replace the conv decoder with a linear + PixelShuffle head

Each patch token now predicts its own patch_size x patch_size pixel block through a single linear projection, rearranged into image space by nn.PixelShuffle; the patch-size-specific Upsample/Conv stack is removed. The final encoder LayerNorm is dropped as well, and its pretrained weights are skipped when loading the backbone.
@@ -182,60 +182,32 @@ class DinoV3ViTDecoder(nn.Module):
         super().__init__()
         self.config = config
         self.num_channels_out = config.num_channels
-        hidden_dim = config.hidden_size
-        patch_size = config.patch_size
+        self.patch_size = config.patch_size

-        self.projection = nn.Linear(hidden_dim, hidden_dim)
-
-        if patch_size == 14:
-            final_upsample = 7
-        elif patch_size == 16:
-            final_upsample = 8
-        elif patch_size == 8:
-            final_upsample = 4
-        else:
-            raise ValueError("invalid")
-
-        self.decoder = nn.Sequential(
-            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
-            nn.Conv2d(256, 128, kernel_size=3, padding=1),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.Upsample(
-                scale_factor=final_upsample, mode="bilinear", align_corners=False
-            ),
-            nn.Conv2d(128, 64, kernel_size=3, padding=1),
-            nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(64, 32, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
+        self.projection = nn.Linear(
+            config.hidden_size,
+            self.num_channels_out * (self.patch_size**2),
+            bias=True,
         )

+        self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
+
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]

-        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
+        x = x[:, 1 + self.config.num_register_tokens :, :]

-        projected_tokens = self.projection(patch_tokens)
+        x = self.projection(x)

         p = self.config.patch_size
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p

-        assert patch_tokens.shape[1] == h_grid * w_grid
+        assert x.shape[1] == h_grid * w_grid

-        x_spatial = projected_tokens.reshape(
-            batch_size, h_grid, w_grid, self.config.hidden_size
-        )
-        x_spatial = x_spatial.permute(0, 3, 1, 2)
-        reconstructed_image = self.decoder(x_spatial)
-
-        return reconstructed_image
+        x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
+
+        return self.pixel_shuffle(x)


 class UTransformer(nn.Module):
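Review note: the decoder is now a single per-token linear head followed by nn.PixelShuffle, so each patch token directly predicts its own patch_size x patch_size pixel block. A minimal shape check of the mechanism, with toy sizes standing in for the real config (all numbers below are assumptions, not values from this repo):

    import torch
    import torch.nn as nn

    # Toy stand-ins for the DINOv3 config; assumed values, not the repo's.
    hidden_size, patch_size, num_channels = 384, 16, 3
    batch, h_grid, w_grid = 2, 14, 14  # a 224x224 input gives a 14x14 patch grid

    projection = nn.Linear(hidden_size, num_channels * patch_size**2)
    pixel_shuffle = nn.PixelShuffle(patch_size)

    tokens = torch.randn(batch, h_grid * w_grid, hidden_size)     # patch tokens only
    x = projection(tokens)                                        # (B, N, C * p^2)
    x = x.reshape(batch, h_grid, w_grid, -1).permute(0, 3, 1, 2)  # (B, C * p^2, Hg, Wg)
    image = pixel_shuffle(x)                                      # (B, C, Hg * p, Wg * p)
    assert image.shape == (2, 3, 224, 224)

Unlike the old Upsample/Conv stack, this head does no cross-patch mixing: every output pixel is a function of exactly one token, so any smoothing across patch borders has to come from the transformer itself.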
@@ -256,7 +228,6 @@ class UTransformer(nn.Module):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

         self.decoder_layers = nn.ModuleList(
             [
@@ -269,7 +240,6 @@ class UTransformer(nn.Module):
         # freeze pretrained
         self.embeddings.requires_grad_(False)
         self.rope_embeddings.requires_grad_(False)
-        self.encoder_norm.requires_grad_(False)

     def forward(
         self,
@@ -298,7 +268,6 @@ class UTransformer(nn.Module):
                 attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
-        x = self.encoder_norm(x)

         for i, layer_module in enumerate(self.decoder_layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None
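Review note: the one-line removals in the three hunks above all target the same thing: the final encoder LayerNorm (encoder_norm) is deleted outright, so the decoder layers now consume the raw output of the last encoder layer. A throwaway sanity check that no reference survives (hypothetical snippet, not part of the repo; the hub id is a placeholder):

    # Hypothetical check: after this commit, encoder_norm should be gone entirely.
    model = UTransformer.from_pretrained_backbone("<hub-repo-id>")
    assert not hasattr(model, "encoder_norm")
    assert all("encoder_norm" not in name for name, _ in model.named_modules())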
@@ -315,16 +284,20 @@ class UTransformer(nn.Module):
     @staticmethod
     def from_pretrained_backbone(name: str):
         config = DINOv3ViTConfig.from_pretrained(name)
-        instance = UTransformer(config, 0).to("cuda:3")
+        instance = UTransformer(config, 0).to("cuda:2")

         weight_dict = {}
         with safe_open(
-            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:3"
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
         ) as f:
             for key in f.keys():
                 new_key = key.replace("layer.", "encoder_layers.").replace(
                     "norm.", "encoder_norm."
                 )
+
+                if key.startswith("norm."):
+                    continue

                 weight_dict[new_key] = f.get_tensor(key)

         instance.load_state_dict(weight_dict, strict=False)
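Review note: the loading path now does three things: it renames backbone keys onto the UTransformer layout (layer. -> encoder_layers.), it skips the checkpoint's final-norm weights since encoder_norm no longer exists, and it relies on strict=False so the decoder's randomly initialized keys are tolerated. The remapping pulled out as a pure function, for clarity (a sketch mirroring the diff, not code from the repo; the example keys are illustrative):

    def remap_backbone_key(key: str) -> str | None:
        # The final norm has no destination module anymore; drop its weights.
        if key.startswith("norm."):
            return None
        # Rename transformer blocks onto the UTransformer attribute names.
        return key.replace("layer.", "encoder_layers.").replace("norm.", "encoder_norm.")

    assert remap_backbone_key("layer.3.mlp.fc1.weight") == "encoder_layers.3.mlp.fc1.weight"
    assert remap_backbone_key("norm.weight") is None

With the early skip in place, the "norm." -> "encoder_norm." replace can only affect keys where "norm." appears mid-string.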