fix wrong average of PSNR
@@ -105,10 +105,11 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         conditioning_input: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        do_condition: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert position_embeddings is not None
-        assert conditioning_input is not None
+        assert conditioning_input is not None or not do_condition

         residual = hidden_states
         hidden_states = self.norm1(hidden_states)
@@ -120,13 +121,14 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         hidden_states = self.layer_scale1(hidden_states)
         hidden_states = self.drop_path(hidden_states) + residual

-        residual = hidden_states
-        hidden_states = self.norm_cond(hidden_states)
-        hidden_states, _ = self.cond(
-            hidden_states, conditioning_input, conditioning_input
-        )
-        hidden_states = self.layer_scale_cond(hidden_states)
-        hidden_states = self.drop_path(hidden_states) + residual
+        if do_condition:
+            residual = hidden_states
+            hidden_states = self.norm_cond(hidden_states)
+            hidden_states, _ = self.cond(
+                hidden_states, conditioning_input, conditioning_input
+            )
+            hidden_states = self.layer_scale_cond(hidden_states)
+            hidden_states = self.drop_path(hidden_states) + residual

         residual = hidden_states
         hidden_states = self.norm2(hidden_states)
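Note: the net effect of the two hunks above is that the conditioning sub-block (cross-attention against conditioning_input, plus layer scale and drop path) now runs only when do_condition is set, and the assert only demands a conditioning input in that case. A minimal standalone sketch of the same gating pattern follows; names like GatedCondBlock and cond_attn are illustrative, and layer scale / drop path are omitted for brevity:

    import torch
    import torch.nn as nn
    from typing import Optional

    class GatedCondBlock(nn.Module):
        """Residual cross-attention block that can be switched off per call."""

        def __init__(self, hidden_dim: int, num_heads: int = 8):
            super().__init__()
            self.norm = nn.LayerNorm(hidden_dim)
            # Cross-attention: queries from the tokens, keys/values from the condition.
            self.cond_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

        def forward(
            self,
            hidden_states: torch.Tensor,
            conditioning_input: Optional[torch.Tensor] = None,
            do_condition: bool = True,
        ) -> torch.Tensor:
            # Only require a condition when the block is actually enabled.
            assert conditioning_input is not None or not do_condition
            if do_condition:
                residual = hidden_states
                h = self.norm(hidden_states)
                h, _ = self.cond_attn(h, conditioning_input, conditioning_input)
                hidden_states = h + residual  # skipped entirely when do_condition=False
            return hidden_states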
@@ -191,6 +193,8 @@ class DinoV3ViTDecoder(nn.Module):
         )

         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
+        nn.init.zeros_(self.projection.weight)
+        nn.init.zeros_(self.projection.bias)

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
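Note: zero-initializing the projection makes the decoder output exactly zero at initialization, so the freshly added head cannot disturb the pretrained features at step zero. A rough sketch of how a PixelShuffle patch decoder of this shape typically fits together; hidden_size, patch_size, and the assumption that x carries patch tokens only are mine, not from the patch:

    import torch
    import torch.nn as nn

    class PatchPixelDecoder(nn.Module):
        """Project patch tokens to patch_size**2 * 3 channels, then unshuffle to pixels."""

        def __init__(self, hidden_size: int = 384, patch_size: int = 16):
            super().__init__()
            self.patch_size = patch_size
            self.projection = nn.Linear(hidden_size, 3 * patch_size**2)
            self.pixel_shuffle = nn.PixelShuffle(patch_size)
            # Start as a no-op: the decoder initially outputs zeros.
            nn.init.zeros_(self.projection.weight)
            nn.init.zeros_(self.projection.bias)

        def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
            # Assumes x holds patch tokens only: (batch, num_patches, hidden).
            b, n, _ = x.shape
            h, w = image_size[0] // self.patch_size, image_size[1] // self.patch_size
            x = self.projection(x)                      # (b, n, 3 * p^2)
            x = x.transpose(1, 2).reshape(b, -1, h, w)  # (b, 3 * p^2, h, w)
            return self.pixel_shuffle(x)                # (b, 3, h * p, w * p)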
@@ -211,9 +215,14 @@ class DinoV3ViTDecoder(nn.Module):


 class UTransformer(nn.Module):
-    def __init__(self, config: DINOv3ViTConfig, num_classes: int):
+    def __init__(
+        self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
+    ):
         super().__init__()
         self.config = config
+        self.scale_factor = scale_factor
+
+        assert config.num_hidden_layers % scale_factor == 0

         self.embeddings = DINOv3ViTEmbeddings(config)
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
@@ -228,18 +237,21 @@ class UTransformer(nn.Module):
                 for _ in range(config.num_hidden_layers)
             ]
         )
+        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

         self.decoder_layers = nn.ModuleList(
             [
                 DinoConditionedLayer(config, False)
-                for _ in range(config.num_hidden_layers // 2)
+                for _ in range(config.num_hidden_layers // scale_factor)
             ]
         )
+        self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.decoder = DinoV3ViTDecoder(config)

         # freeze pretrained
         self.embeddings.requires_grad_(False)
         self.rope_embeddings.requires_grad_(False)
+        self.encoder_norm.requires_grad_(False)

     def forward(
         self,
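Note: with the divisibility assert from the previous hunk, the bookkeeping works out exactly: storing a residual every scale_factor encoder layers yields num_hidden_layers // scale_factor skips, one per decoder layer. A quick sanity check with illustrative numbers:

    num_hidden_layers = 12  # illustrative depth, not from the patch
    scale_factor = 4

    assert num_hidden_layers % scale_factor == 0
    num_decoder_layers = num_hidden_layers // scale_factor  # 3
    num_skips = sum(1 for i in range(num_hidden_layers) if i % scale_factor == 0)  # 3
    assert num_skips == num_decoder_layers  # one skip per decoder layer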
@@ -260,7 +272,8 @@ class UTransformer(nn.Module):

         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
-            residual.append(x)
+            if i % self.scale_factor == 0:
+                residual.append(x)
             layer_head_mask = head_mask[i] if head_mask is not None else None
             x = layer_module(
                 x,
@@ -269,35 +282,71 @@ class UTransformer(nn.Module):
                 position_embeddings=position_embeddings,
             )

+        x = self.encoder_norm(x)
+
+        reversed_residual = residual[::-1]
         for i, layer_module in enumerate(self.decoder_layers):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-            x = x + residual.pop() + residual.pop()
             x = layer_module(
                 x,
-                conditioning_input=conditioning_input,
-                attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
+            x = x + reversed_residual[i]

-        return self.decoder(x, image_size=pixel_values.shape[-2:])
+        x = self.decoder_norm(x)
+
+        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
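Note: two behavioural changes land in this hunk: skips are consumed in reverse order by index (reversed_residual[i]) and added after each decoder layer, instead of popping two per layer before it, and forward now also returns the collected residuals alongside the decoded image. A compact sketch of the collect-then-reverse skip pattern, with nn.Identity standing in for the real layers:

    import torch
    import torch.nn as nn

    encoder_layers = nn.ModuleList([nn.Identity() for _ in range(12)])
    decoder_layers = nn.ModuleList([nn.Identity() for _ in range(3)])
    scale_factor = 4

    x = torch.randn(2, 16, 384)

    # Encoder: stash a skip every scale_factor layers (deepest skip stored last).
    residual = []
    for i, layer in enumerate(encoder_layers):
        if i % scale_factor == 0:
            residual.append(x)
        x = layer(x)

    # Decoder: walk the skips in reverse, so the shallowest skip reaches the last layer.
    reversed_residual = residual[::-1]
    for i, layer in enumerate(decoder_layers):
        x = layer(x)
        x = x + reversed_residual[i]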
+
+    def get_residual(
+        self,
+        pixel_values: torch.Tensor,
+        time: Optional[torch.Tensor],
+        do_condition: bool,
+    ):
+        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+        position_embeddings = self.rope_embeddings(pixel_values)
+
+        x = self.embeddings(pixel_values, bool_masked_pos=None)
+
+        if do_condition:
+            t = self.t_embedder(time).unsqueeze(1)
+            # y = self.y_embedder(cond, self.training).unsqueeze(1)
+            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+            conditioning_input = t.to(x.dtype)
+        else:
+            conditioning_input = None
+
+        residual = []
+        for i, layer_module in enumerate(self.encoder_layers):
+            if i % self.scale_factor == 0:
+                residual.append(x)
+
+            x = layer_module(
+                x,
+                conditioning_input=conditioning_input,
+                attention_mask=None,
+                position_embeddings=position_embeddings,
+                do_condition=do_condition,
+            )
+
+        return residual
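Note: get_residual replays only the encoder at the configured stride and returns the skip stack without decoding, optionally with conditioning disabled. A hypothetical usage sketch; the checkpoint name is a placeholder and, like the loader below, this assumes a machine where cuda:1 exists:

    import torch

    # Placeholder repo id, not from the patch.
    model = UTransformer.from_pretrained_backbone("<dinov3-checkpoint>")
    pixels = torch.randn(1, 3, 224, 224, device="cuda:1")

    with torch.no_grad():
        skips = model.get_residual(pixels, time=None, do_condition=False)

    # One entry per scale_factor-th encoder layer.
    print(len(skips), skips[0].shape)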

     @staticmethod
     def from_pretrained_backbone(name: str):
         config = DINOv3ViTConfig.from_pretrained(name)
-        instance = UTransformer(config, 0).to("cuda:2")
+        instance = UTransformer(config, 0).to("cuda:1")

         weight_dict = {}
         with safe_open(
-            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
         ) as f:
             for key in f.keys():
                 new_key = key.replace("layer.", "encoder_layers.").replace(
                     "norm.", "encoder_norm."
                 )

                 if key.startswith("norm."):
                     continue

                 weight_dict[new_key] = f.get_tensor(key)

         instance.load_state_dict(weight_dict, strict=False)
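Note: beyond the cuda:2 to cuda:1 device switch, the loader's key remapping is the part worth keeping in mind: backbone keys under layer.* are renamed to encoder_layers.*, keys starting with norm. are skipped, and strict=False lets the newly added decoder and conditioning modules keep their fresh initialization. A simplified, generic version of the same remap-and-load pattern (the function name and skip-rule comment are mine):

    from huggingface_hub import hf_hub_download
    from safetensors import safe_open

    def load_remapped(model, repo_id: str, device: str = "cuda:1"):
        path = hf_hub_download(repo_id, "model.safetensors")
        state = {}
        with safe_open(path, framework="pt", device=device) as f:
            for key in f.keys():
                if key.startswith("norm."):
                    continue  # final backbone norm is deliberately not loaded
                state[key.replace("layer.", "encoder_layers.")] = f.get_tensor(key)
        # strict=False: modules absent from the checkpoint keep their init.
        missing, unexpected = model.load_state_dict(state, strict=False)
        return missing, unexpected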