improved rf

This commit is contained in:
neulus
2025-10-01 18:44:26 +09:00
parent 49025c4d87
commit 29eb04d1a4
8 changed files with 150 additions and 27 deletions

View File

@@ -232,7 +232,7 @@ class UTransformer(nn.Module):
self.decoder_layers = nn.ModuleList(
[
DinoConditionedLayer(config, False)
- for _ in range(config.num_hidden_layers)
+ for _ in range(config.num_hidden_layers // 2)
]
)
self.decoder = DinoV3ViTDecoder(config)
@@ -271,13 +271,13 @@ class UTransformer(nn.Module):
for i, layer_module in enumerate(self.decoder_layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = x + residual.pop() + residual.pop()
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
x = x + residual.pop()
return self.decoder(x, image_size=pixel_values.shape[-2:])