Author: neulus
Date: 2025-10-13 23:14:44 +09:00
Parent: c47d91a349
Commit: 3b03453e5d
28 changed files with 700 additions and 208 deletions


@@ -78,6 +78,7 @@ class LabelEmbedder(nn.Module):
class DinoConditionedLayer(DINOv3ViTLayer):
    def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
        super().__init__(config)
        self.is_encoder = is_encoder
        self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cond = nn.MultiheadAttention(
@@ -298,6 +299,17 @@ class UTransformer(nn.Module):
                for _ in range(config.num_hidden_layers // scale_factor)
            ]
        )
        self.residual_merger = nn.ModuleList(
            [
                nn.Sequential(
                    nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
                )
                for _ in range(config.num_hidden_layers // scale_factor)
            ]
        )
        self.rest_decoder = nn.ModuleList(
            [DinoConditionedLayer(config, False) for _ in range(4)]
        )
        self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = DinoV3ViTDecoder(config)
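
For reference, a minimal standalone sketch of what one residual_merger entry computes: a SiLU + Linear block whose doubled output is split into a shift and a scale. The hidden size and tensor shapes below are illustrative assumptions, not values read from the actual DINOv3ViTConfig.

import torch
import torch.nn as nn

hidden_size = 768  # assumed for illustration; the real value comes from the config

# Same structure as one residual_merger entry: SiLU followed by a Linear that
# doubles the feature dimension so the output can be chunked into shift and scale.
merger = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size))

residual = torch.randn(2, 196, hidden_size)  # (batch, tokens, hidden) -- assumed shape
shift, scale = merger(residual).chunk(2, dim=-1)
print(shift.shape, scale.shape)  # both torch.Size([2, 196, 768])
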
@@ -348,8 +360,22 @@ class UTransformer(nn.Module):
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
            )
            shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
                2, dim=-1
            )
            x = x * (1 + scale) + shift
        for i, layer_module in enumerate(self.rest_decoder):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
                do_condition=False,
            )
            x = x + reversed_residual[i]
        x = self.decoder_norm(x)
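
The forward-pass additions apply the merged residual as an affine modulation, x * (1 + scale) + shift, before the remaining rest_decoder layers consume the skip connections by plain addition. A hedged, self-contained sketch of that modulation step in isolation (shapes assumed for illustration):

import torch

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Affine (AdaLN/FiLM-style) modulation as in the hunk above:
    # zero shift and zero scale leave x unchanged.
    return x * (1 + scale) + shift

x = torch.randn(2, 196, 768)    # (batch, tokens, hidden) -- assumed shape
shift = torch.zeros_like(x)
scale = torch.zeros_like(x)
assert torch.allclose(modulate(x, shift, scale), x)  # identity at zero, the usual AdaLN-Zero-style starting point
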