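things: add residual_merger (shift/scale modulation of skip connections) and rest_decoder layers to UTransformer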
This commit is contained in:
@@ -78,6 +78,7 @@ class LabelEmbedder(nn.Module):
|
||||
class DinoConditionedLayer(DINOv3ViTLayer):
|
||||
def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
|
||||
super().__init__(config)
|
||||
self.is_encoder = is_encoder
|
||||
|
||||
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.cond = nn.MultiheadAttention(
|
||||
@@ -298,6 +299,17 @@ class UTransformer(nn.Module):
|
||||
for _ in range(config.num_hidden_layers // scale_factor)
|
||||
]
|
||||
)
|
||||
self.residual_merger = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
|
||||
)
|
||||
for _ in range(config.num_hidden_layers // scale_factor)
|
||||
]
|
||||
)
|
||||
self.rest_decoder = nn.ModuleList(
|
||||
[DinoConditionedLayer(config, False) for _ in range(4)]
|
||||
)
|
||||
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.decoder = DinoV3ViTDecoder(config)
|
||||
|
||||
@@ -348,8 +360,22 @@ class UTransformer(nn.Module):
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
|
||||
2, dim=-1
|
||||
)
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
for i, layer_module in enumerate(self.rest_decoder):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
x = x + reversed_residual[i]
|
||||
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user