pixel shuffle
1
.gitignore
vendored
@@ -11,3 +11,4 @@ wheels/
|
||||
artifact
|
||||
.zed
|
||||
wandb
|
||||
datasets
|
||||
|
||||
6
main.py
@@ -11,7 +11,7 @@ from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
device = "cuda:0"
|
||||
device = "cuda:2"
|
||||
|
||||
model = UTransformer.from_pretrained_backbone(
|
||||
"facebook/dinov3-vitl16-pretrain-sat493m"
|
||||
@@ -21,7 +21,7 @@ optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
||||
|
||||
train_dataset, test_dataset = get_dataset()
|
||||
|
||||
wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
|
||||
wandb.init(project="cloud-removal-kmu", id="icy-field-12", resume="allow")
|
||||
|
||||
if not (wandb.run and wandb.run.name):
|
||||
raise Exception("nope")
|
||||
@@ -37,7 +37,7 @@ if os.path.exists(checkpoint_path):
|
||||
start_epoch = checkpoint["epoch"] + 1
|
||||
|
||||
batch_size = 4
|
||||
accumulation_steps = 4
|
||||
accumulation_steps = 8
|
||||
total_epoch = 1000
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
lossbin = {i: 0 for i in range(10)}
|
||||
|
||||
@@ -12,6 +12,7 @@ dependencies = [
|
||||
"ruff>=0.13.2",
|
||||
"safetensors>=0.6.2",
|
||||
"torch>=2.8.0",
|
||||
"torcheval>=0.0.7",
|
||||
"torchmetrics>=1.8.2",
|
||||
"torchvision>=0.23.0",
|
||||
"tqdm>=4.67.1",
|
||||
|
||||
@@ -10,14 +10,14 @@ from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
checkpoint_path = "artifact/wild-wave-3/checkpoint_epoch_100.pt"
|
||||
device = "cuda:0"
|
||||
checkpoint_path = "artifact/icy-field-12/checkpoint_epoch_260.pt"
|
||||
device = "cuda:2"
|
||||
save_dir = "test_images"
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
model = UTransformer.from_pretrained_backbone(
|
||||
"facebook/dinov3-vits16-pretrain-lvd1689m"
|
||||
"facebook/dinov3-vitl16-pretrain-sat493m"
|
||||
).to(device)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
@@ -28,7 +28,7 @@ rf.model.eval()
|
||||
|
||||
_, test_dataset = get_dataset()
|
||||
|
||||
batch_size = 32
|
||||
batch_size = 1
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
|
||||
@@ -4,7 +4,6 @@ from torchmetrics.image import (
|
||||
StructuralSimilarityIndexMeasure,
|
||||
)
|
||||
|
||||
psnr = PeakSignalNoiseRatio(1.0, reduction="none")
|
||||
ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(
|
||||
net_type="alex", reduction="none", normalize=True
|
||||
@@ -12,4 +11,5 @@ lpips = LearnedPerceptualImagePatchSimilarity(
|
||||
|
||||
|
||||
def benchmark(image1, image2):
|
||||
psnr = PeakSignalNoiseRatio(1.0, reduction="none")
|
||||
return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, DatasetDict, Image
|
||||
|
||||
from src.dataset.preprocess import make_transform
|
||||
|
||||
transform = make_transform(512)
|
||||
|
||||
|
||||
def get_dataset() -> tuple[Dataset, Dataset]:
|
||||
if os.path.exists("datasets/CUHK-CR1"):
|
||||
dataset = DatasetDict.load_from_disk("datasets/CUHK-CR1")
|
||||
return dataset["train"], dataset["test"]
|
||||
|
||||
data_dir = Path("/data2/C-CUHK/CUHK-CR1")
|
||||
|
||||
train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
|
||||
@@ -35,24 +39,16 @@ def get_dataset() -> tuple[Dataset, Dataset]:
|
||||
.cast_column("label", Image()),
|
||||
}
|
||||
)
|
||||
train_dataset = dataset["train"]
|
||||
train_dataset = train_dataset.map(
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
batch_size=32,
|
||||
remove_columns=train_dataset.column_names,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
)
|
||||
train_dataset.set_format(type="torch", columns=["x0", "x1"])
|
||||
test_dataset = dataset["test"]
|
||||
test_dataset = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
batch_size=32,
|
||||
remove_columns=test_dataset.column_names,
|
||||
)
|
||||
test_dataset.set_format(type="torch", columns=["x0", "x1"])
|
||||
dataset.set_format(type="torch", columns=["x0", "x1"])
|
||||
dataset.save_to_disk("datasets/CUHK-CR1")
|
||||
|
||||
return train_dataset, test_dataset
|
||||
return dataset["train"], dataset["test"]
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -7,7 +8,9 @@ import torch.nn.functional as F
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
|
||||
def get_patches_center_coordinates(num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
||||
def get_patches_center_coordinates(
|
||||
num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
@@ -18,8 +21,12 @@ def get_patches_center_coordinates(num_patches_h: int, num_patches_w: int, dtype
|
||||
return coords
|
||||
|
||||
|
||||
def augment_patches_center_coordinates(coords: torch.Tensor, shift: Optional[float] = None,
|
||||
jitter: Optional[float] = None, rescale: Optional[float] = None) -> torch.Tensor:
|
||||
def augment_patches_center_coordinates(
|
||||
coords: torch.Tensor,
|
||||
shift: Optional[float] = None,
|
||||
jitter: Optional[float] = None,
|
||||
rescale: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
if shift is not None:
|
||||
shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
|
||||
shift_hw = shift_hw.uniform_(-shift, shift)
|
||||
@@ -46,7 +53,9 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
@@ -63,12 +72,16 @@ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, si
|
||||
return q, k
|
||||
|
||||
|
||||
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
||||
def drop_path(
|
||||
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
|
||||
) -> torch.Tensor:
|
||||
if drop_prob == 0.0 or not training:
|
||||
return input
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||
random_tensor = keep_prob + torch.rand(
|
||||
shape, dtype=input.dtype, device=input.device
|
||||
)
|
||||
random_tensor.floor_()
|
||||
output = input.div(keep_prob) * random_tensor
|
||||
return output
|
||||
@@ -80,12 +93,19 @@ class DINOv3ViTEmbeddings(nn.Module):
|
||||
self.config = config
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size))
|
||||
self.register_tokens = nn.Parameter(
|
||||
torch.empty(1, config.num_register_tokens, config.hidden_size)
|
||||
)
|
||||
self.patch_embeddings = nn.Conv2d(
|
||||
config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
|
||||
config.num_channels,
|
||||
config.hidden_size,
|
||||
kernel_size=config.patch_size,
|
||||
stride=config.patch_size,
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embeddings.weight.dtype
|
||||
|
||||
@@ -94,7 +114,9 @@ class DINOv3ViTEmbeddings(nn.Module):
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = self.mask_token.to(patch_embeddings.dtype)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
patch_embeddings = torch.where(
|
||||
bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings
|
||||
)
|
||||
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
|
||||
@@ -112,7 +134,9 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
self.num_patches_h = config.image_size // config.patch_size
|
||||
self.num_patches_w = config.image_size // config.patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32)
|
||||
inv_freq = 1 / self.base ** torch.arange(
|
||||
0, 1, 4 / self.head_dim, dtype=torch.float32
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -121,7 +145,11 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
num_patches_w = width // self.config.patch_size
|
||||
|
||||
device = pixel_values.device
|
||||
device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
|
||||
device_type = (
|
||||
device.type
|
||||
if isinstance(device.type, str) and device.type != "mps"
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
patch_coords = get_patches_center_coordinates(
|
||||
@@ -135,7 +163,9 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
rescale=self.config.pos_embed_rescale,
|
||||
)
|
||||
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] # type: ignore
|
||||
angles = (
|
||||
2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] # type: ignore
|
||||
)
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
|
||||
@@ -161,8 +191,12 @@ class DINOv3ViTAttention(nn.Module):
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
|
||||
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert position_embeddings is not None
|
||||
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
@@ -171,18 +205,32 @@ class DINOv3ViTAttention(nn.Module):
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states.view(
|
||||
batch_size, patches, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
batch_size, patches, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
batch_size, patches, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = (
|
||||
torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
|
||||
)
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
query_states.dtype
|
||||
)
|
||||
|
||||
if self.training:
|
||||
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_weights = F.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
@@ -198,7 +246,9 @@ class DINOv3ViTAttention(nn.Module):
|
||||
class DINOv3ViTLayerScale(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
|
||||
self.lambda1 = nn.Parameter(
|
||||
config.layerscale_value * torch.ones(config.hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
return hidden_state * self.lambda1
|
||||
@@ -219,8 +269,12 @@ class DINOv3ViTMLP(nn.Module):
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
|
||||
)
|
||||
self.down_proj = nn.Linear(
|
||||
self.intermediate_size, self.hidden_size, bias=config.mlp_bias
|
||||
)
|
||||
|
||||
if config.hidden_act == "gelu":
|
||||
self.act_fn = F.gelu
|
||||
@@ -241,9 +295,15 @@ class DINOv3ViTGatedMLP(nn.Module):
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.gate_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
|
||||
)
|
||||
self.up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
|
||||
)
|
||||
self.down_proj = nn.Linear(
|
||||
self.intermediate_size, self.hidden_size, bias=config.mlp_bias
|
||||
)
|
||||
|
||||
if config.hidden_act == "gelu":
|
||||
self.act_fn = F.gelu
|
||||
@@ -264,7 +324,11 @@ class DINOv3ViTLayer(nn.Module):
|
||||
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attention = DINOv3ViTAttention(config)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(config)
|
||||
self.drop_path = DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||||
self.drop_path = (
|
||||
DINOv3ViTDropPath(config.drop_path_rate)
|
||||
if config.drop_path_rate > 0.0
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
@@ -274,8 +338,14 @@ class DINOv3ViTLayer(nn.Module):
|
||||
self.mlp = DINOv3ViTMLP(config)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *, attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert position_embeddings is not None
|
||||
|
||||
residual = hidden_states
|
||||
@@ -303,7 +373,9 @@ class DINOv3ViTModel(nn.Module):
|
||||
self.config = config
|
||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||
self.layers = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self._init_weights()
|
||||
@@ -337,8 +409,12 @@ class DINOv3ViTModel(nn.Module):
|
||||
elif isinstance(module, DINOv3ViTLayerScale):
|
||||
module.lambda1.data.fill_(self.config.layerscale_value)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
@@ -182,60 +182,32 @@ class DinoV3ViTDecoder(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_channels_out = config.num_channels
|
||||
hidden_dim = config.hidden_size
|
||||
patch_size = config.patch_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.projection = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
if patch_size == 14:
|
||||
final_upsample = 7
|
||||
elif patch_size == 16:
|
||||
final_upsample = 8
|
||||
elif patch_size == 8:
|
||||
final_upsample = 4
|
||||
else:
|
||||
raise ValueError("invalid")
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
||||
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Upsample(
|
||||
scale_factor=final_upsample, mode="bilinear", align_corners=False
|
||||
),
|
||||
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, self.num_channels_out, kernel_size=1),
|
||||
self.projection = nn.Linear(
|
||||
config.hidden_size,
|
||||
self.num_channels_out * (self.patch_size**2),
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
|
||||
x = x[:, 1 + self.config.num_register_tokens :, :]
|
||||
|
||||
projected_tokens = self.projection(patch_tokens)
|
||||
x = self.projection(x)
|
||||
|
||||
p = self.config.patch_size
|
||||
h_grid = image_size[0] // p
|
||||
w_grid = image_size[1] // p
|
||||
|
||||
assert patch_tokens.shape[1] == h_grid * w_grid
|
||||
assert x.shape[1] == h_grid * w_grid
|
||||
|
||||
x_spatial = projected_tokens.reshape(
|
||||
batch_size, h_grid, w_grid, self.config.hidden_size
|
||||
)
|
||||
x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
|
||||
|
||||
x_spatial = x_spatial.permute(0, 3, 1, 2)
|
||||
reconstructed_image = self.decoder(x_spatial)
|
||||
|
||||
return reconstructed_image
|
||||
return self.pixel_shuffle(x)
|
||||
|
||||
|
||||
class UTransformer(nn.Module):
|
||||
@@ -256,7 +228,6 @@ class UTransformer(nn.Module):
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder_layers = nn.ModuleList(
|
||||
[
|
||||
@@ -269,7 +240,6 @@ class UTransformer(nn.Module):
|
||||
# freeze pretrained
|
||||
self.embeddings.requires_grad_(False)
|
||||
self.rope_embeddings.requires_grad_(False)
|
||||
self.encoder_norm.requires_grad_(False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -298,7 +268,6 @@ class UTransformer(nn.Module):
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
x = self.encoder_norm(x)
|
||||
|
||||
for i, layer_module in enumerate(self.decoder_layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
@@ -315,16 +284,20 @@ class UTransformer(nn.Module):
|
||||
@staticmethod
|
||||
def from_pretrained_backbone(name: str):
|
||||
config = DINOv3ViTConfig.from_pretrained(name)
|
||||
instance = UTransformer(config, 0).to("cuda:3")
|
||||
instance = UTransformer(config, 0).to("cuda:2")
|
||||
|
||||
weight_dict = {}
|
||||
with safe_open(
|
||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:3"
|
||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
new_key = key.replace("layer.", "encoder_layers.").replace(
|
||||
"norm.", "encoder_norm."
|
||||
)
|
||||
|
||||
if key.startswith("norm."):
|
||||
continue
|
||||
|
||||
weight_dict[new_key] = f.get_tensor(key)
|
||||
|
||||
instance.load_state_dict(weight_dict, strict=False)
|
||||
|
||||
|
Before Width: | Height: | Size: 489 KiB After Width: | Height: | Size: 325 KiB |
|
Before Width: | Height: | Size: 537 KiB After Width: | Height: | Size: 351 KiB |
|
Before Width: | Height: | Size: 482 KiB After Width: | Height: | Size: 301 KiB |
|
Before Width: | Height: | Size: 477 KiB After Width: | Height: | Size: 305 KiB |
|
Before Width: | Height: | Size: 508 KiB After Width: | Height: | Size: 349 KiB |
|
Before Width: | Height: | Size: 502 KiB After Width: | Height: | Size: 304 KiB |
|
Before Width: | Height: | Size: 541 KiB After Width: | Height: | Size: 371 KiB |
|
Before Width: | Height: | Size: 551 KiB After Width: | Height: | Size: 384 KiB |
|
Before Width: | Height: | Size: 541 KiB After Width: | Height: | Size: 366 KiB |
|
Before Width: | Height: | Size: 488 KiB After Width: | Height: | Size: 315 KiB |
14
uv.lock
generated
@@ -197,6 +197,7 @@ dependencies = [
|
||||
{ name = "ruff" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "torch" },
|
||||
{ name = "torcheval" },
|
||||
{ name = "torchmetrics" },
|
||||
{ name = "torchvision" },
|
||||
{ name = "tqdm" },
|
||||
@@ -213,6 +214,7 @@ requires-dist = [
|
||||
{ name = "ruff", specifier = ">=0.13.2" },
|
||||
{ name = "safetensors", specifier = ">=0.6.2" },
|
||||
{ name = "torch", specifier = ">=2.8.0" },
|
||||
{ name = "torcheval", specifier = ">=0.0.7" },
|
||||
{ name = "torchmetrics", specifier = ">=1.8.2" },
|
||||
{ name = "torchvision", specifier = ">=0.23.0" },
|
||||
{ name = "tqdm", specifier = ">=4.67.1" },
|
||||
@@ -1487,6 +1489,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torcheval"
|
||||
version = "0.0.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/65/8f895a132d385c3bc60023a45501637aff63b98629cf45ce47a2035c0cc3/torcheval-0.0.7.tar.gz", hash = "sha256:a498dec34137bc66c9cf1adc7353a46c604dd62255884c72dcb4e2e4fc2cd7e9", size = 100912, upload-time = "2023-08-24T22:12:44.683Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e4/de/e7abc784b00de9d05999657d29187f1f7a3406ed10ecaf164de06482608f/torcheval-0.0.7-py3-none-any.whl", hash = "sha256:20cc34dac7aa9b32f942c8a9f014d1d02098631b6cd0b102c078600577017956", size = 179200, upload-time = "2023-08-24T22:12:42.874Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torchmetrics"
|
||||
version = "1.8.2"
|
||||
|
||||