Removed dinounet and added dino; noticed that the dinounet variant here is suboptimal
360
src/model/dino.py
Normal file
@@ -0,0 +1,360 @@
import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig


def get_patches_center_coordinates(
    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
    coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
    coords_h = coords_h / num_patches_h
    coords_w = coords_w / num_patches_w
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
    coords = coords.flatten(0, 1)
    coords = 2.0 * coords - 1.0
    return coords
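
# Example (sketch): for a 2x2 patch grid the normalized centers are
#   get_patches_center_coordinates(2, 2, torch.float32, torch.device("cpu"))
#   -> tensor([[-0.5, -0.5], [-0.5, 0.5], [0.5, -0.5], [0.5, 0.5]])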


def augment_patches_center_coordinates(
    coords: torch.Tensor,
    shift: Optional[float] = None,
    jitter: Optional[float] = None,
    rescale: Optional[float] = None,
) -> torch.Tensor:
    if shift is not None:
        shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
        shift_hw = shift_hw.uniform_(-shift, shift)
        coords = coords + shift_hw

    if jitter is not None:
        jitter_range = np.log(jitter)
        jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
        jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
        coords = coords * jitter_hw

    if rescale is not None:
        rescale_range = np.log(rescale)
        rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
        rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
        coords = coords * rescale_hw

    return coords
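
# Note: jitter and rescale are sampled log-uniformly, so e.g. jitter=2.0 draws per-axis
# factors in [0.5, 2.0] with halving and doubling equally likely; shift adds a uniform
# offset in [-shift, shift] to every patch coordinate.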


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
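
# Example (sketch): rotate_half(torch.tensor([1., 2., 3., 4.])) -> tensor([-3., -4., 1., 2.]);
# the halves (x1, x2) become (-x2, x1), the 90-degree rotation used by RoPE.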


def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    num_tokens = q.shape[-2]
    num_patches = sin.shape[-2]
    num_prefix_tokens = num_tokens - num_patches

    # CLS and register tokens carry no spatial position, so RoPE is applied to patch tokens only.
    q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
    k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)

    q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
    k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)

    q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
    k = torch.cat((k_prefix_tokens, k_patches), dim=-2)

    return q, k


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()
    output = input.div(keep_prob) * random_tensor
    return output
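
# Stochastic depth (sketch): each sample's residual branch is zeroed with probability drop_prob
# and survivors are scaled by 1 / keep_prob, so the expected value of the output matches the input.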


class DINOv3ViTEmbeddings(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size))
        self.patch_embeddings = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
        )

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embeddings.weight.dtype

        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, num_patches, hidden)
        patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
        patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)

        if bool_masked_pos is not None:
            mask_token = self.mask_token.to(patch_embeddings.dtype)
            patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        register_tokens = self.register_tokens.expand(batch_size, -1, -1)
        embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)

        return embeddings


class DINOv3ViTRopePositionEmbedding(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.base = config.rope_theta
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_patches_h = config.image_size // config.patch_size
        self.num_patches_w = config.image_size // config.patch_size

        inv_freq = 1.0 / (self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.config.patch_size
        num_patches_w = width // self.config.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

        # Coordinates and angles are computed in float32 with autocast disabled for stability.
        with torch.autocast(device_type=device_type, enabled=False):
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=torch.float32, device=device
            )
            if self.training:
                patch_coords = augment_patches_center_coordinates(
                    patch_coords,
                    shift=self.config.pos_embed_shift,
                    jitter=self.config.pos_embed_jitter,
                    rescale=self.config.pos_embed_rescale,
                )

            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]  # type: ignore
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        dtype = pixel_values.dtype
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
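
# Shape note (sketch): patch_coords is (num_patches, 2), inv_freq has head_dim // 4 entries,
# so angles flatten to (num_patches, head_dim // 2) and tile(2) doubles the last dim;
# cos/sin come out as (num_patches, head_dim), broadcastable against (batch, heads, tokens, head_dim).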


class DINOv3ViTAttention(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert position_embeddings is not None

        batch_size, patches, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (B, tokens, embed_dim) -> (B, heads, tokens, head_dim)
        query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # F.dropout is already a no-op in eval mode, so no extra self.training guard is needed.
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        if attention_mask is not None:
            attn_weights = attn_weights * attention_mask

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
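
# Equivalent fused path (sketch, assuming no attention_mask and attn_weights unused):
#   attn_output = F.scaled_dot_product_attention(
#       query_states, key_states, value_states,
#       dropout_p=self.dropout if self.training else 0.0,
#   )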


class DINOv3ViTLayerScale(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return hidden_state * self.lambda1


class DINOv3ViTDropPath(nn.Module):
    def __init__(self, drop_prob: float):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)


class DINOv3ViTMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)

        if config.hidden_act == "gelu":
            self.act_fn = F.gelu
        elif config.hidden_act == "relu":
            self.act_fn = F.relu
        elif config.hidden_act == "silu":
            self.act_fn = F.silu
        else:
            self.act_fn = F.gelu

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


class DINOv3ViTGatedMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)

        if config.hidden_act == "gelu":
            self.act_fn = F.gelu
        elif config.hidden_act == "relu":
            self.act_fn = F.relu
        elif config.hidden_act == "silu":
            self.act_fn = F.silu
        else:
            self.act_fn = F.gelu

    def forward(self, x):
        # SwiGLU-style gating: act(gate(x)) * up(x), then project back down.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class DINOv3ViTLayer(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = DINOv3ViTAttention(config)
        self.layer_scale1 = DINOv3ViTLayerScale(config)
        self.drop_path = DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        if config.use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(config)
        else:
            self.mlp = DINOv3ViTMLP(config)
        self.layer_scale2 = DINOv3ViTLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        assert position_embeddings is not None

        # Pre-norm attention block with LayerScale and stochastic depth on the residual branch.
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        return hidden_states


class DINOv3ViTModel(nn.Module):
    def __init__(self, config: DINOv3ViTConfig):
        super().__init__()
        self.config = config
        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
        self.layers = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                # Truncated-normal init in float32, cast back to the parameter's dtype.
                module.weight.data = nn.init.trunc_normal_(
                    module.weight.data.to(torch.float32),
                    mean=0.0,
                    std=self.config.initializer_range,
                ).to(module.weight.dtype)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            elif isinstance(module, DINOv3ViTEmbeddings):
                module.cls_token.data = nn.init.trunc_normal_(
                    module.cls_token.data.to(torch.float32),
                    mean=0.0,
                    std=self.config.initializer_range,
                ).to(module.cls_token.dtype)
                if module.config.num_register_tokens > 0:
                    module.register_tokens.data = nn.init.trunc_normal_(
                        module.register_tokens.data.to(torch.float32),
                        mean=0.0,
                        std=self.config.initializer_range,
                    ).to(module.register_tokens.dtype)
                module.mask_token.data.zero_()
            elif isinstance(module, DINOv3ViTLayerScale):
                module.lambda1.data.fill_(self.config.layerscale_value)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]  # CLS token

        return {
            "last_hidden_state": sequence_output,
            "pooler_output": pooled_output,
        }
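

# Minimal smoke test (sketch): assumes DINOv3ViTConfig accepts these size overrides;
# not a verified checkpoint configuration.
if __name__ == "__main__":
    config = DINOv3ViTConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        image_size=32,
        patch_size=16,
    )
    model = DINOv3ViTModel(config).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 32, 32))
    # tokens = 1 CLS + config.num_register_tokens + (32 // 16) ** 2 patches
    print(out["last_hidden_state"].shape, out["pooler_output"].shape)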
@@ -1,201 +0,0 @@
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import DINOv3ViTModel


@dataclass
class DinoUNETConfig:
    model_name: str = "facebook/dinov2-small"
    num_classes: int = 2
    features_per_stage: List[int] = [32, 64, 128, 256]
    n_conv_per_stage_decoder: List[int] = [2, 2, 2]
    deep_supervision: bool = False
    rank: int = 256


class SqueezeExcitation(nn.Module):
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, max(1, channels // reduction), 1),
            nn.ReLU(True),
            nn.Conv2d(max(1, channels // reduction), channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor):
        return x * self.fc(self.pool(x))


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, stride: int = 1, padding: int = 1):
        super().__init__()

        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(True)

    def forward(self, x: torch.Tensor):
        return self.act(self.bn(self.pointwise(self.depthwise(x))))


class LearnableUpsample(nn.Module):
    def __init__(self, channels: int):
        super().__init__()

        self.up2 = nn.ConvTranspose2d(channels, channels, 2, 2)

    def forward(self, x: torch.Tensor, target_size: tuple[int, int]):
        h, w = x.shape[2:]
        out = x
        while h * 2 <= target_size[0] and w * 2 <= target_size[1]:
            out = self.up2(out)
            h, w = out.shape[2:]
        if (h, w) != target_size:
            out = F.interpolate(out, target_size, mode='bilinear', align_corners=False)
        return out


class FAPM(nn.Module):
    def __init__(self, in_ch: int, rank: int, out_ch_list: list[int]):
        super().__init__()

        self.shared_basis = nn.Conv2d(in_ch, rank, 1)
        self.specific_bases = nn.ModuleList([nn.Conv2d(in_ch, rank, 1) for _ in out_ch_list])
        self.film_generators = nn.ModuleList([nn.Conv2d(rank, rank * 2, 1) for _ in out_ch_list])
        self.refinement_blocks = nn.ModuleList()
        self.shortcuts = nn.ModuleList()

        for oc in out_ch_list:
            self.refinement_blocks.append(nn.Sequential(
                nn.Conv2d(rank, oc, 1),
                nn.BatchNorm2d(oc),
                nn.ReLU(True),
                DepthwiseSeparableConv(oc, oc),
                nn.Conv2d(oc, oc, 1),
                SqueezeExcitation(oc)
            ))
            self.shortcuts.append(nn.Conv2d(rank, oc, 1) if rank != oc else nn.Identity())

    def forward(self, x_list: list[torch.Tensor]):
        out = []
        for i, x in enumerate(x_list):
            z_shared = self.shared_basis(x)
            z_specific = self.specific_bases[i](x)
            gamma, beta = torch.chunk(self.film_generators[i](z_shared), 2, dim=1)
            z_modulated = gamma * z_specific + beta
            refined = self.refinement_blocks[i](z_modulated)
            shortcut = self.shortcuts[i](z_modulated)
            out.append(refined + shortcut)
        return out


class DINOv3Encoder(nn.Module):
    def __init__(self, config: DinoUNETConfig):
        super().__init__()

        self.output_channels = config.features_per_stage
        self.strides = [(2, 2)] * len(config.features_per_stage)

        self.backbone = DINOv3ViTModel.from_pretrained(config.model_name)
        self.fapm = FAPM(self.backbone.config.hidden_size, config.rank, config.features_per_stage)
        self.ups = nn.ModuleList([LearnableUpsample(oc) for oc in config.features_per_stage])

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        B, C, H, W = x.shape

        if C == 1:
            x = x.repeat(1, 3, 1, 1)
        elif C != 3:
            x = x[:, :3] if C > 3 else F.pad(x, (0, 0, 0, 0, 0, 3 - C))[:, :3]

        outputs = self.backbone(x, output_hidden_states=True, return_dict=True)
        hidden_states = outputs.hidden_states

        h, w = H // self.backbone.config.patch_size, W // self.backbone.config.patch_size
        features = []
        indices = [3, 6, 9, 12] if len(hidden_states) > 12 else [2, 4, 6, -1]

        for idx in indices:
            feat = hidden_states[idx][:, 1:].transpose(1, 2).reshape(B, -1, h, w)
            features.append(feat)

        features = self.fapm(features)
        skips = []
        for i, feat in enumerate(features):
            target_size = (H // (2 ** i), W // (2 ** i))
            skips.append(self.ups[i](feat, target_size))

        return skips


class ConvBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, stride: int = 1):
        super().__init__()

        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, kernel_size // 2)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class StackedConvBlocks(nn.Module):
    def __init__(self, n_blocks: int, in_ch: int, out_ch: int, kernel_size: int = 3):
        super().__init__()

        self.blocks = nn.Sequential(*(
            [ConvBlock(in_ch, out_ch, kernel_size)] +
            [ConvBlock(out_ch, out_ch, kernel_size) for _ in range(n_blocks - 1)]
        ))

    def forward(self, x):
        return self.blocks(x)


class UNetDecoder(nn.Module):
    def __init__(self, encoder: DINOv3Encoder, config: DinoUNETConfig):
        super().__init__()
        self.deep_supervision = config.deep_supervision
        self.encoder = encoder

        self.stages = nn.ModuleList()
        self.transpconvs = nn.ModuleList()
        self.seg_layers = nn.ModuleList()

        for s in range(1, len(encoder.output_channels)):
            in_below = encoder.output_channels[-s]
            in_skip = encoder.output_channels[-(s + 1)]
            stride = encoder.strides[-s]

            self.transpconvs.append(nn.ConvTranspose2d(in_below, in_skip, stride, stride))
            self.stages.append(StackedConvBlocks(
                config.n_conv_per_stage_decoder[s - 1],
                2 * in_skip,
                in_skip
            ))
            self.seg_layers.append(nn.Conv2d(in_skip, config.num_classes, 1))

    def forward(self, skips: list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]:  # a list only with deep supervision
        lres_input = skips[-1]
        seg_outputs = []

        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s + 2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision or s == len(self.stages) - 1:
                seg_outputs.append(self.seg_layers[s](x))
            lres_input = x

        seg_outputs = seg_outputs[::-1]
        return seg_outputs if self.deep_supervision else seg_outputs[0]


class DinoUNet(nn.Module):
    def __init__(self, config: DinoUNETConfig):
        super().__init__()
        self.encoder = DINOv3Encoder(config)
        self.decoder = UNetDecoder(self.encoder, config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skips = self.encoder(x)
        return self.decoder(skips)