pixel shuffle

neulus
2025-10-01 16:30:05 +09:00
parent 8966cafb8f
commit 49025c4d87
19 changed files with 162 additions and 101 deletions

.gitignore

@@ -11,3 +11,4 @@ wheels/
 artifact
 .zed
 wandb
+datasets


@@ -11,7 +11,7 @@ from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-device = "cuda:0"
+device = "cuda:2"
 model = UTransformer.from_pretrained_backbone(
     "facebook/dinov3-vitl16-pretrain-sat493m"
@@ -21,7 +21,7 @@ optimizer = optim.AdamW(model.parameters(), lr=1e-4)
 train_dataset, test_dataset = get_dataset()
 
-wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
+wandb.init(project="cloud-removal-kmu", id="icy-field-12", resume="allow")
 if not (wandb.run and wandb.run.name):
     raise Exception("nope")
@@ -37,7 +37,7 @@ if os.path.exists(checkpoint_path):
     start_epoch = checkpoint["epoch"] + 1
 
 batch_size = 4
-accumulation_steps = 4
+accumulation_steps = 8
 total_epoch = 1000
 for epoch in range(start_epoch, total_epoch):
     lossbin = {i: 0 for i in range(10)}
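The accumulation change doubles the effective batch: with batch_size = 4 and accumulation_steps = 8, gradients from 4 * 8 = 32 samples are accumulated before every optimizer step. The training loop itself is not part of this diff, so the snippet below is only a generic, self-contained sketch of that pattern with a dummy model and dummy data (all names illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

batch_size, accumulation_steps = 4, 8          # values from this commit
model = nn.Linear(16, 1)                       # stand-in for the real model
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

for step in range(accumulation_steps * 2):     # enough micro-batches for two updates
    x = torch.randn(batch_size, 16)
    loss = model(x).pow(2).mean() / accumulation_steps  # scale so accumulated grads average
    loss.backward()                            # gradients add up across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                       # one update per 4 * 8 = 32 samples
        optimizer.zero_grad()
```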


@@ -12,6 +12,7 @@ dependencies = [
"ruff>=0.13.2", "ruff>=0.13.2",
"safetensors>=0.6.2", "safetensors>=0.6.2",
"torch>=2.8.0", "torch>=2.8.0",
"torcheval>=0.0.7",
"torchmetrics>=1.8.2", "torchmetrics>=1.8.2",
"torchvision>=0.23.0", "torchvision>=0.23.0",
"tqdm>=4.67.1", "tqdm>=4.67.1",


@@ -10,14 +10,14 @@ from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-checkpoint_path = "artifact/wild-wave-3/checkpoint_epoch_100.pt"
-device = "cuda:0"
+checkpoint_path = "artifact/icy-field-12/checkpoint_epoch_260.pt"
+device = "cuda:2"
 save_dir = "test_images"
 os.makedirs(save_dir, exist_ok=True)
 
 model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vits16-pretrain-lvd1689m"
+    "facebook/dinov3-vitl16-pretrain-sat493m"
 ).to(device)
 checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -28,7 +28,7 @@ rf.model.eval()
 _, test_dataset = get_dataset()
 
-batch_size = 32
+batch_size = 1
 psnr_sum = 0
 ssim_sum = 0
 lpips_sum = 0


@@ -4,7 +4,6 @@ from torchmetrics.image import (
     StructuralSimilarityIndexMeasure,
 )
 
-psnr = PeakSignalNoiseRatio(1.0, reduction="none")
 ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
 lpips = LearnedPerceptualImagePatchSimilarity(
     net_type="alex", reduction="none", normalize=True
@@ -12,4 +11,5 @@ lpips = LearnedPerceptualImagePatchSimilarity(
 
 
 def benchmark(image1, image2):
+    psnr = PeakSignalNoiseRatio(1.0, reduction="none")
     return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)


@@ -1,13 +1,17 @@
+import os
 from pathlib import Path
 
 from datasets import Dataset, DatasetDict, Image
 
 from src.dataset.preprocess import make_transform
 
 transform = make_transform(512)
 
 
 def get_dataset() -> tuple[Dataset, Dataset]:
+    if os.path.exists("datasets/CUHK-CR1"):
+        dataset = DatasetDict.load_from_disk("datasets/CUHK-CR1")
+        return dataset["train"], dataset["test"]
+
     data_dir = Path("/data2/C-CUHK/CUHK-CR1")
     train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
@@ -35,24 +39,16 @@ def get_dataset() -> tuple[Dataset, Dataset]:
             .cast_column("label", Image()),
         }
     )
-    train_dataset = dataset["train"]
-    train_dataset = train_dataset.map(
-        preprocess_function,
-        batched=True,
-        batch_size=32,
-        remove_columns=train_dataset.column_names,
-    )
-    train_dataset.set_format(type="torch", columns=["x0", "x1"])
-    test_dataset = dataset["test"]
-    test_dataset = test_dataset.map(
-        preprocess_function,
-        batched=True,
-        batch_size=32,
-        remove_columns=test_dataset.column_names,
-    )
-    test_dataset.set_format(type="torch", columns=["x0", "x1"])
-
-    return train_dataset, test_dataset
+    dataset = dataset.map(
+        preprocess_function,
+        batched=True,
+        batch_size=32,
+        remove_columns=dataset["train"].column_names,
+    )
+    dataset.set_format(type="torch", columns=["x0", "x1"])
+    dataset.save_to_disk("datasets/CUHK-CR1")
+
+    return dataset["train"], dataset["test"]
 
 
 def preprocess_function(examples):
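With this change, get_dataset() memoizes the preprocessed splits on disk (which is also why datasets/ was added to .gitignore above): the first call maps preprocess_function over both splits and saves the result to datasets/CUHK-CR1, and every later call short-circuits to the cached copy. A hypothetical quick check of that cache, assuming it has already been written once:

```python
from datasets import load_from_disk

# Hypothetical inspection of the cache written by get_dataset(); loading the
# saved DatasetDict restores both splits without re-running the map step.
cached = load_from_disk("datasets/CUHK-CR1")
print(cached)                     # DatasetDict with "train" and "test" splits
print(cached["train"][0].keys())  # the preprocessed "x0" / "x1" columns
```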


@@ -1,5 +1,6 @@
 import math
 from typing import Optional, Tuple
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -7,7 +8,9 @@ import torch.nn.functional as F
 from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
 
-def get_patches_center_coordinates(num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
+def get_patches_center_coordinates(
+    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
     coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
     coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
     coords_h = coords_h / num_patches_h
@@ -18,8 +21,12 @@ def get_patches_center_coordinates(num_patches_h: int, num_patches_w: int, dtype
     return coords
 
-def augment_patches_center_coordinates(coords: torch.Tensor, shift: Optional[float] = None,
-                                       jitter: Optional[float] = None, rescale: Optional[float] = None) -> torch.Tensor:
+def augment_patches_center_coordinates(
+    coords: torch.Tensor,
+    shift: Optional[float] = None,
+    jitter: Optional[float] = None,
+    rescale: Optional[float] = None,
+) -> torch.Tensor:
     if shift is not None:
         shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
         shift_hw = shift_hw.uniform_(-shift, shift)
@@ -46,7 +53,9 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
-def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def apply_rotary_pos_emb(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = q.shape[-2]
     num_patches = sin.shape[-2]
     num_prefix_tokens = num_tokens - num_patches
@@ -63,12 +72,16 @@ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, si
     return q, k
 
-def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+def drop_path(
+    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
+) -> torch.Tensor:
     if drop_prob == 0.0 or not training:
         return input
     keep_prob = 1 - drop_prob
     shape = (input.shape[0],) + (1,) * (input.ndim - 1)
-    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor = keep_prob + torch.rand(
+        shape, dtype=input.dtype, device=input.device
+    )
     random_tensor.floor_()
     output = input.div(keep_prob) * random_tensor
     return output
@@ -80,12 +93,19 @@ class DINOv3ViTEmbeddings(nn.Module):
         self.config = config
         self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
-        self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size))
+        self.register_tokens = nn.Parameter(
+            torch.empty(1, config.num_register_tokens, config.hidden_size)
+        )
         self.patch_embeddings = nn.Conv2d(
-            config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
+            config.num_channels,
+            config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
         )
 
-    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embeddings.weight.dtype
@@ -94,7 +114,9 @@ class DINOv3ViTEmbeddings(nn.Module):
         if bool_masked_pos is not None:
             mask_token = self.mask_token.to(patch_embeddings.dtype)
-            patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
+            patch_embeddings = torch.where(
+                bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings
+            )
 
         cls_token = self.cls_token.expand(batch_size, -1, -1)
         register_tokens = self.register_tokens.expand(batch_size, -1, -1)
@@ -112,7 +134,9 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         self.num_patches_h = config.image_size // config.patch_size
         self.num_patches_w = config.image_size // config.patch_size
-        inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32)
+        inv_freq = 1 / self.base ** torch.arange(
+            0, 1, 4 / self.head_dim, dtype=torch.float32
+        )
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
     def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -121,7 +145,11 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         num_patches_w = width // self.config.patch_size
         device = pixel_values.device
 
-        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
+        device_type = (
+            device.type
+            if isinstance(device.type, str) and device.type != "mps"
+            else "cpu"
+        )
         with torch.autocast(device_type=device_type, enabled=False):
             patch_coords = get_patches_center_coordinates(
@@ -135,7 +163,9 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
                 rescale=self.config.pos_embed_rescale,
             )
-            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]  # type: ignore
+            angles = (
+                2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]  # type: ignore
+            )
             angles = angles.flatten(1, 2)
             angles = angles.tile(2)
@@ -161,8 +191,12 @@ class DINOv3ViTAttention(nn.Module):
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
         self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)
 
-    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
-                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert position_embeddings is not None
         batch_size, patches, _ = hidden_states.size()
@@ -171,18 +205,32 @@ class DINOv3ViTAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(
+            batch_size, patches, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            batch_size, patches, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            batch_size, patches, self.num_heads, self.head_dim
+        ).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = (
+            torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling
+        )
+        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+            query_states.dtype
+        )
 
         if self.training:
-            attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
+            attn_weights = F.dropout(
+                attn_weights, p=self.dropout, training=self.training
+            )
 
         if attention_mask is not None:
             attn_weights = attn_weights * attention_mask
@@ -198,7 +246,9 @@ class DINOv3ViTAttention(nn.Module):
 class DINOv3ViTLayerScale(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+        self.lambda1 = nn.Parameter(
+            config.layerscale_value * torch.ones(config.hidden_size)
+        )
 
     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
         return hidden_state * self.lambda1
@@ -219,8 +269,12 @@ class DINOv3ViTMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+        )
+        self.down_proj = nn.Linear(
+            self.intermediate_size, self.hidden_size, bias=config.mlp_bias
+        )
 
         if config.hidden_act == "gelu":
             self.act_fn = F.gelu
@@ -241,9 +295,15 @@ class DINOv3ViTGatedMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        self.gate_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+        )
+        self.up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+        )
+        self.down_proj = nn.Linear(
+            self.intermediate_size, self.hidden_size, bias=config.mlp_bias
+        )
 
         if config.hidden_act == "gelu":
             self.act_fn = F.gelu
@@ -264,7 +324,11 @@ class DINOv3ViTLayer(nn.Module):
         self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = DINOv3ViTAttention(config)
         self.layer_scale1 = DINOv3ViTLayerScale(config)
-        self.drop_path = DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+        self.drop_path = (
+            DINOv3ViTDropPath(config.drop_path_rate)
+            if config.drop_path_rate > 0.0
+            else nn.Identity()
+        )
 
         self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -274,8 +338,14 @@ class DINOv3ViTLayer(nn.Module):
             self.mlp = DINOv3ViTMLP(config)
         self.layer_scale2 = DINOv3ViTLayerScale(config)
 
-    def forward(self, hidden_states: torch.Tensor, *, attention_mask: Optional[torch.Tensor] = None,
-                position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
         assert position_embeddings is not None
         residual = hidden_states
@@ -303,7 +373,9 @@ class DINOv3ViTModel(nn.Module):
         self.config = config
         self.embeddings = DINOv3ViTEmbeddings(config)
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
-        self.layers = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList(
+            [DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)]
+        )
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self._init_weights()
@@ -337,8 +409,12 @@ class DINOv3ViTModel(nn.Module):
         elif isinstance(module, DINOv3ViTLayerScale):
             module.lambda1.data.fill_(self.config.layerscale_value)
 
-    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
-                head_mask: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+    ):
         pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
         hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
         position_embeddings = self.rope_embeddings(pixel_values)


@@ -182,60 +182,32 @@ class DinoV3ViTDecoder(nn.Module):
         super().__init__()
         self.config = config
         self.num_channels_out = config.num_channels
-        hidden_dim = config.hidden_size
-        patch_size = config.patch_size
-
-        self.projection = nn.Linear(hidden_dim, hidden_dim)
-
-        if patch_size == 14:
-            final_upsample = 7
-        elif patch_size == 16:
-            final_upsample = 8
-        elif patch_size == 8:
-            final_upsample = 4
-        else:
-            raise ValueError("invalid")
-
-        self.decoder = nn.Sequential(
-            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
-            nn.Conv2d(256, 128, kernel_size=3, padding=1),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-            nn.Upsample(
-                scale_factor=final_upsample, mode="bilinear", align_corners=False
-            ),
-            nn.Conv2d(128, 64, kernel_size=3, padding=1),
-            nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(64, 32, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
-        )
+        self.patch_size = config.patch_size
+
+        self.projection = nn.Linear(
+            config.hidden_size,
+            self.num_channels_out * (self.patch_size**2),
+            bias=True,
+        )
+        self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
-        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
-        projected_tokens = self.projection(patch_tokens)
+        x = x[:, 1 + self.config.num_register_tokens :, :]
+        x = self.projection(x)
 
         p = self.config.patch_size
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
-        assert patch_tokens.shape[1] == h_grid * w_grid
+        assert x.shape[1] == h_grid * w_grid
 
-        x_spatial = projected_tokens.reshape(
-            batch_size, h_grid, w_grid, self.config.hidden_size
-        )
-        x_spatial = x_spatial.permute(0, 3, 1, 2)
-
-        reconstructed_image = self.decoder(x_spatial)
-        return reconstructed_image
+        x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
+
+        return self.pixel_shuffle(x)
 
 
 class UTransformer(nn.Module):
@@ -256,7 +228,6 @@ class UTransformer(nn.Module):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.decoder_layers = nn.ModuleList(
             [
@@ -269,7 +240,6 @@ class UTransformer(nn.Module):
         # freeze pretrained
         self.embeddings.requires_grad_(False)
         self.rope_embeddings.requires_grad_(False)
-        self.encoder_norm.requires_grad_(False)
 
     def forward(
         self,
@@ -298,7 +268,6 @@ class UTransformer(nn.Module):
                 attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
-        x = self.encoder_norm(x)
 
         for i, layer_module in enumerate(self.decoder_layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None
@@ -315,16 +284,20 @@ class UTransformer(nn.Module):
     @staticmethod
     def from_pretrained_backbone(name: str):
         config = DINOv3ViTConfig.from_pretrained(name)
-        instance = UTransformer(config, 0).to("cuda:3")
+        instance = UTransformer(config, 0).to("cuda:2")
 
         weight_dict = {}
         with safe_open(
-            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:3"
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
         ) as f:
             for key in f.keys():
                 new_key = key.replace("layer.", "encoder_layers.").replace(
                     "norm.", "encoder_norm."
                 )
+                if key.startswith("norm."):
+                    continue
                 weight_dict[new_key] = f.get_tensor(key)
 
         instance.load_state_dict(weight_dict, strict=False)
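For reference, the shape flow through the new pixel-shuffle decoder can be sketched in isolation. This is a minimal standalone sketch assuming ViT-L/16 values (hidden size 1024, patch size 16), 3 output channels, and a 512x512 input; the variable names are illustrative, not taken from the repository:

```python
import torch
import torch.nn as nn

# Minimal sketch of the new decoding path: project each patch token to
# C * p^2 values, fold the token grid back into a feature map, and let
# PixelShuffle rearrange the channel dimension into p x p pixel blocks.
hidden_size, patch_size, channels = 1024, 16, 3        # assumed ViT-L/16, RGB
h_grid = w_grid = 512 // patch_size                    # 32 x 32 patch grid

projection = nn.Linear(hidden_size, channels * patch_size**2)
pixel_shuffle = nn.PixelShuffle(patch_size)

tokens = torch.randn(2, h_grid * w_grid, hidden_size)      # (B, N, D) patch tokens
x = projection(tokens)                                      # (B, N, C * p^2)
x = x.reshape(2, h_grid, w_grid, -1).permute(0, 3, 1, 2)    # (B, C * p^2, H/p, W/p)
image = pixel_shuffle(x)                                     # (B, C, H, W)
print(image.shape)  # torch.Size([2, 3, 512, 512])
```

Compared with the removed convolutional decoder (Conv/BatchNorm/ReLU stacks plus two bilinear upsamples), each output pixel is now a direct linear function of its patch token, and the upsampling factor is tied exactly to the patch size.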

Binary file not shown. (Before: 489 KiB, After: 325 KiB)
Binary file not shown. (Before: 537 KiB, After: 351 KiB)
Binary file not shown. (Before: 482 KiB, After: 301 KiB)
Binary file not shown. (Before: 477 KiB, After: 305 KiB)
Binary file not shown. (Before: 508 KiB, After: 349 KiB)
Binary file not shown. (Before: 502 KiB, After: 304 KiB)
Binary file not shown. (Before: 541 KiB, After: 371 KiB)
Binary file not shown. (Before: 551 KiB, After: 384 KiB)
Binary file not shown. (Before: 541 KiB, After: 366 KiB)
Binary file not shown. (Before: 488 KiB, After: 315 KiB)

uv.lock

@@ -197,6 +197,7 @@ dependencies = [
{ name = "ruff" }, { name = "ruff" },
{ name = "safetensors" }, { name = "safetensors" },
{ name = "torch" }, { name = "torch" },
{ name = "torcheval" },
{ name = "torchmetrics" }, { name = "torchmetrics" },
{ name = "torchvision" }, { name = "torchvision" },
{ name = "tqdm" }, { name = "tqdm" },
@@ -213,6 +214,7 @@ requires-dist = [
{ name = "ruff", specifier = ">=0.13.2" }, { name = "ruff", specifier = ">=0.13.2" },
{ name = "safetensors", specifier = ">=0.6.2" }, { name = "safetensors", specifier = ">=0.6.2" },
{ name = "torch", specifier = ">=2.8.0" }, { name = "torch", specifier = ">=2.8.0" },
{ name = "torcheval", specifier = ">=0.0.7" },
{ name = "torchmetrics", specifier = ">=1.8.2" }, { name = "torchmetrics", specifier = ">=1.8.2" },
{ name = "torchvision", specifier = ">=0.23.0" }, { name = "torchvision", specifier = ">=0.23.0" },
{ name = "tqdm", specifier = ">=4.67.1" }, { name = "tqdm", specifier = ">=4.67.1" },
@@ -1487,6 +1489,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" }, { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" },
] ]
[[package]]
name = "torcheval"
version = "0.0.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f2/65/8f895a132d385c3bc60023a45501637aff63b98629cf45ce47a2035c0cc3/torcheval-0.0.7.tar.gz", hash = "sha256:a498dec34137bc66c9cf1adc7353a46c604dd62255884c72dcb4e2e4fc2cd7e9", size = 100912, upload-time = "2023-08-24T22:12:44.683Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e4/de/e7abc784b00de9d05999657d29187f1f7a3406ed10ecaf164de06482608f/torcheval-0.0.7-py3-none-any.whl", hash = "sha256:20cc34dac7aa9b32f942c8a9f014d1d02098631b6cd0b102c078600577017956", size = 179200, upload-time = "2023-08-24T22:12:42.874Z" },
]
[[package]] [[package]]
name = "torchmetrics" name = "torchmetrics"
version = "1.8.2" version = "1.8.2"