improved rf
This commit is contained in:
@@ -232,7 +232,7 @@ class UTransformer(nn.Module):
|
||||
self.decoder_layers = nn.ModuleList(
|
||||
[
|
||||
DinoConditionedLayer(config, False)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
for _ in range(config.num_hidden_layers // 2)
|
||||
]
|
||||
)
|
||||
self.decoder = DinoV3ViTDecoder(config)
|
||||
@@ -271,13 +271,13 @@ class UTransformer(nn.Module):
|
||||
|
||||
for i, layer_module in enumerate(self.decoder_layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = x + residual.pop() + residual.pop()
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
x = x + residual.pop()
|
||||
|
||||
return self.decoder(x, image_size=pixel_values.shape[-2:])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user