fix wrong average of psnr

neulus
2025-10-02 19:40:00 +09:00
parent a601dc6095
commit 6bb6c09638
17 changed files with 221 additions and 39 deletions


@@ -4,6 +4,7 @@ from torchmetrics.image import (
     StructuralSimilarityIndexMeasure,
 )
 
+psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
 ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
 lpips = LearnedPerceptualImagePatchSimilarity(
     net_type="alex", reduction="none", normalize=True
@@ -11,5 +12,4 @@ lpips = LearnedPerceptualImagePatchSimilarity(
 
 
 def benchmark(image1, image2):
-    psnr = PeakSignalNoiseRatio(1.0, reduction="none")
     return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
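Note on the fix (editor's reading, not part of the commit): in torchmetrics, `PeakSignalNoiseRatio` with no `dim` pools the squared error over the entire batch before taking the log, so it returns the PSNR of the pooled MSE rather than the mean of per-image PSNRs; because the log is applied after versus before averaging, the two generally disagree. Passing `reduction="none"` together with `dim=(1, 2, 3)` yields one PSNR per image, which can then be averaged correctly downstream. A minimal sketch of the discrepancy:

import torch

def psnr_per_image(pred, target, data_range=1.0):
    # one MSE per image -> one PSNR per image; the benchmark mean should average these
    mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))
    return 10 * torch.log10(data_range**2 / mse)

def psnr_pooled(pred, target, data_range=1.0):
    # one MSE over the whole batch -> a single PSNR; this is the "wrong average"
    mse = ((pred - target) ** 2).mean()
    return 10 * torch.log10(data_range**2 / mse)

pred, target = torch.rand(8, 3, 64, 64), torch.rand(8, 3, 64, 64)
print(psnr_per_image(pred, target).mean())  # generally differs from the line below
print(psnr_pooled(pred, target))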


@@ -45,7 +45,7 @@ def get_dataset() -> tuple[Dataset, Dataset]:
         batch_size=32,
         remove_columns=dataset["train"].column_names,
     )
-    dataset.set_format(type="torch", columns=["x0", "x1"])
+    dataset.set_format(type="torch", columns=["cloud", "gt"])
     dataset.save_to_disk("datasets/CUHK-CR1")
     return dataset["train"], dataset["test"]

src/dataset/cuhk_cr2.py (new file)

@@ -0,0 +1,62 @@
+import os
+from pathlib import Path
+
+from datasets import Dataset, DatasetDict, Image
+
+from src.dataset.preprocess import make_transform
+
+transform = make_transform(512)
+
+
+def get_dataset() -> tuple[Dataset, Dataset]:
+    if os.path.exists("datasets/CUHK-CR2"):
+        dataset = DatasetDict.load_from_disk("datasets/CUHK-CR2")
+        return dataset["train"], dataset["test"]
+
+    data_dir = Path("/data2/C-CUHK/CUHK-CR2")
+    train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
+    train_no_cloud = sorted((data_dir / "train/label").glob("*.png"))
+    test_cloud = sorted((data_dir / "test/cloud").glob("*.png"))
+    test_no_cloud = sorted((data_dir / "test/label").glob("*.png"))
+
+    dataset = DatasetDict(
+        {
+            "train": Dataset.from_dict(
+                {
+                    "cloud": [str(p) for p in train_cloud],
+                    "label": [str(p) for p in train_no_cloud],
+                }
+            )
+            .cast_column("cloud", Image())
+            .cast_column("label", Image()),
+            "test": Dataset.from_dict(
+                {
+                    "cloud": [str(p) for p in test_cloud],
+                    "label": [str(p) for p in test_no_cloud],
+                }
+            )
+            .cast_column("cloud", Image())
+            .cast_column("label", Image()),
+        }
+    )
+    dataset = dataset.map(
+        preprocess_function,
+        batched=True,
+        batch_size=32,
+        remove_columns=dataset["train"].column_names,
+    )
+    dataset.set_format(type="torch", columns=["cloud", "gt"])
+    dataset.save_to_disk("datasets/CUHK-CR2")
+    return dataset["train"], dataset["test"]
+
+
+def preprocess_function(examples):
+    x0_list = []
+    x1_list = []
+    for x0_img, x1_img in zip(examples["cloud"], examples["label"]):
+        x0_transformed = transform(x0_img)
+        x1_transformed = transform(x1_img)
+        x0_list.append(x0_transformed)
+        x1_list.append(x1_transformed)
+    return {"cloud": x0_list, "gt": x1_list}


@@ -105,10 +105,11 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         conditioning_input: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        do_condition: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert position_embeddings is not None
-        assert conditioning_input is not None
+        assert conditioning_input is not None or not do_condition
 
         residual = hidden_states
         hidden_states = self.norm1(hidden_states)
@@ -120,13 +121,14 @@
         hidden_states = self.layer_scale1(hidden_states)
         hidden_states = self.drop_path(hidden_states) + residual
 
-        residual = hidden_states
-        hidden_states = self.norm_cond(hidden_states)
-        hidden_states, _ = self.cond(
-            hidden_states, conditioning_input, conditioning_input
-        )
-        hidden_states = self.layer_scale_cond(hidden_states)
-        hidden_states = self.drop_path(hidden_states) + residual
+        if do_condition:
+            residual = hidden_states
+            hidden_states = self.norm_cond(hidden_states)
+            hidden_states, _ = self.cond(
+                hidden_states, conditioning_input, conditioning_input
+            )
+            hidden_states = self.layer_scale_cond(hidden_states)
+            hidden_states = self.drop_path(hidden_states) + residual
 
         residual = hidden_states
         hidden_states = self.norm2(hidden_states)
@@ -191,6 +193,8 @@ class DinoV3ViTDecoder(nn.Module):
         )
         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
+        nn.init.zeros_(self.projection.weight)
+        nn.init.zeros_(self.projection.bias)
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
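Editor's note: zeroing both the weight and the bias of the decoder's output projection means the freshly attached decoder emits exactly zero at initialization, a common stabilization trick in DiT/ControlNet-style setups so training starts from the unmodified pretrained signal path. A toy illustration with assumed layer sizes:

import torch
import torch.nn as nn

proj = nn.Linear(768, 16 * 16 * 3)  # hypothetical hidden size and per-patch output
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
print(proj(torch.randn(2, 768)).abs().max())  # tensor(0.): the head starts silent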
@@ -211,9 +215,14 @@
 class UTransformer(nn.Module):
-    def __init__(self, config: DINOv3ViTConfig, num_classes: int):
+    def __init__(
+        self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
+    ):
         super().__init__()
         self.config = config
+        self.scale_factor = scale_factor
+        assert config.num_hidden_layers % scale_factor == 0
 
         self.embeddings = DINOv3ViTEmbeddings(config)
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
@@ -228,18 +237,21 @@
                 for _ in range(config.num_hidden_layers)
             ]
         )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.decoder_layers = nn.ModuleList(
             [
                 DinoConditionedLayer(config, False)
-                for _ in range(config.num_hidden_layers // 2)
+                for _ in range(config.num_hidden_layers // scale_factor)
             ]
         )
+        self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.decoder = DinoV3ViTDecoder(config)
 
         # freeze pretrained
         self.embeddings.requires_grad_(False)
         self.rope_embeddings.requires_grad_(False)
+        self.encoder_norm.requires_grad_(False)
 
     def forward(
         self,
@@ -260,7 +272,8 @@
         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
-            residual.append(x)
+            if i % self.scale_factor == 0:
+                residual.append(x)
             layer_head_mask = head_mask[i] if head_mask is not None else None
             x = layer_module(
                 x,
@@ -269,35 +282,71 @@
                 position_embeddings=position_embeddings,
             )
 
         x = self.encoder_norm(x)
+        reversed_residual = residual[::-1]
         for i, layer_module in enumerate(self.decoder_layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            x = x + residual.pop() + residual.pop()
             x = layer_module(
                 x,
                 conditioning_input=conditioning_input,
                 attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
+            x = x + reversed_residual[i]
 
-        return self.decoder(x, image_size=pixel_values.shape[-2:])
+        x = self.decoder_norm(x)
+        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
+
+    def get_residual(
+        self,
+        pixel_values: torch.Tensor,
+        time: Optional[torch.Tensor],
+        do_condition: bool,
+    ):
+        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+        position_embeddings = self.rope_embeddings(pixel_values)
+        x = self.embeddings(pixel_values, bool_masked_pos=None)
+
+        if do_condition:
+            t = self.t_embedder(time).unsqueeze(1)
+            # y = self.y_embedder(cond, self.training).unsqueeze(1)
+            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+            conditioning_input = t.to(x.dtype)
+        else:
+            conditioning_input = None
+
+        residual = []
+        for i, layer_module in enumerate(self.encoder_layers):
+            if i % self.scale_factor == 0:
+                residual.append(x)
+            x = layer_module(
+                x,
+                conditioning_input=conditioning_input,
+                attention_mask=None,
+                position_embeddings=position_embeddings,
+                do_condition=do_condition,
+            )
+        return residual
 
     @staticmethod
     def from_pretrained_backbone(name: str):
         config = DINOv3ViTConfig.from_pretrained(name)
-        instance = UTransformer(config, 0).to("cuda:2")
+        instance = UTransformer(config, 0).to("cuda:1")
         weight_dict = {}
         with safe_open(
-            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
         ) as f:
             for key in f.keys():
                 new_key = key.replace("layer.", "encoder_layers.").replace(
                     "norm.", "encoder_norm."
                 )
-                if key.startswith("norm."):
-                    continue
                 weight_dict[new_key] = f.get_tensor(key)
         instance.load_state_dict(weight_dict, strict=False)
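As I read the architectural change, the earlier scheme stored a residual after every encoder layer and let each of the `num_hidden_layers // 2` decoder layers consume two of them via `residual.pop()`; the new scheme stores one residual every `scale_factor` layers and pairs each of the `num_hidden_layers // scale_factor` decoder layers with one stash, deepest first, U-Net style. The untouched list is also returned for the feature-matching loss in the file below. A toy sketch of the new indexing under assumed sizes:

# assuming 12 encoder layers and the default scale_factor=4
num_layers, scale_factor = 12, 4
residual = [
    f"input_to_encoder_layer_{i}" for i in range(num_layers) if i % scale_factor == 0
]
reversed_residual = residual[::-1]  # deepest stash is added back first
for i in range(num_layers // scale_factor):  # 3 decoder layers, one skip each
    print(f"decoder layer {i} adds {reversed_residual[i]}")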


@@ -13,17 +13,13 @@ def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
 
 class RF:
-    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
+    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_mse_enhanced"):
         self.model = model
         self.ln = ln
         self.ushaped = ushaped
         self.loss_fn = loss_fn
-        self.lpips = (
-            lpips.LPIPS(net="vgg").to(model.device)
-            if loss_fn == "lpips_huber"
-            else None
-        )
+        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1") if "lpips" in loss_fn else None
 
     def forward(self, x0, z1):
         # x0 is gt / z is noise
@@ -41,7 +37,7 @@
         texp = t.view([b, *([1] * len(x0.shape[1:]))])
         zt = (1 - texp) * x0 + texp * z1
-        vtheta = self.model(zt, t)
+        vtheta, residual = self.model(zt, t)
 
         if self.loss_fn == "lpips_huber":
             # https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
@@ -57,6 +53,41 @@
             weight = t.view(-1)
             loss = (1 - weight) * huber + lpips
+        elif self.loss_fn == "lpips_mse":
+            if not self.lpips:
+                raise Exception
+            lpips = self.lpips(
+                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+            )
+            loss = ((z1 - x0 - vtheta) ** 2).mean(
+                dim=list(range(1, len(x0.shape)))
+            ) + 2.0 * lpips
+        elif self.loss_fn == "lpips_mse_enhanced":
+            if not self.lpips:
+                raise Exception
+            lpips = self.lpips(
+                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+            )
+            dino_loss = torch.stack(
+                [
+                    (
+                        1
+                        - torch.nn.functional.cosine_similarity(
+                            v_residual, x0_residual, dim=-1
+                        )
+                    ).mean(dim=-1)
+                    for v_residual, x0_residual in zip(
+                        residual, self.model.get_residual(x0, None, False)
+                    )
+                ]
+            ).mean(dim=0) * (2 - t.view(-1))
+            loss = (
+                ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
+                + 2.0 * lpips
+                + dino_loss
+            )
         else:
             loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
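Editor's reading of the new `lpips_mse_enhanced` branch: with `zt - texp * vtheta` as the one-step estimate of the clean image, the per-sample loss is roughly MSE(vtheta, z1 - x0) + 2.0 * LPIPS(x0, x0_hat) + (2 - t) * D_dino, where D_dino averages one minus the token-wise cosine similarity between the encoder residuals of the noisy path and those of the clean image (via `get_residual`), so nearly-clean samples (small t) are weighted more. A shape-level sketch of the DINO term with hypothetical tensors:

import torch
import torch.nn.functional as F

# hypothetical: 3 stash points, each (batch=2, tokens=1024, dim=768)
pred_feats = [torch.randn(2, 1024, 768) for _ in range(3)]
gt_feats = [torch.randn(2, 1024, 768) for _ in range(3)]
t = torch.rand(2)  # flow time per sample; t=1 is pure noise in this codebase

dino_loss = torch.stack(
    [
        (1 - F.cosine_similarity(p, g, dim=-1)).mean(dim=-1)  # mean over tokens
        for p, g in zip(pred_feats, gt_feats)
    ]
).mean(dim=0) * (2 - t)  # mean over stash points; larger weight near t=0
print(dino_loss.shape)  # torch.Size([2]), one value per sample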
@@ -66,7 +97,7 @@
         return loss.mean(), ttloss
 
     @torch.no_grad()
-    def sample(self, z1, sample_steps=50):
+    def sample(self, z1, sample_steps=5):
         b = z1.size(0)
         dt = 1.0 / sample_steps
         dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
@@ -76,8 +107,36 @@
             t = i / sample_steps
             t = torch.tensor([t] * b).to(z.device)
-            vc = self.model(z, t)
+            vc, _ = self.model(z, t)
             z = z - dt * vc
             images.append(z)
         return images
+
+    @torch.no_grad()
+    def sample_heun(self, z1, sample_steps=50):
+        b = z1.size(0)
+        dt = 1.0 / sample_steps
+        images = [z1]
+        z = z1
+        for i in range(sample_steps, 0, -1):
+            t_current = i / sample_steps
+            t_next = (i - 1) / sample_steps
+
+            t_current_tensor = torch.tensor([t_current] * b, device=z.device)
+            v_current, _ = self.model(z, t_current_tensor)
+
+            z_pred = z - dt * v_current
+
+            t_next_tensor = torch.tensor([t_next] * b, device=z.device)
+            v_next, _ = self.model(z_pred, t_next_tensor)
+
+            v_avg = 0.5 * (v_current + v_next)
+            z = z - dt * v_avg
+
+            images.append(z)
+        return images
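The new `sample_heun` is a standard Heun (improved Euler) integrator for the flow ODE: take an Euler trial step, re-evaluate the velocity at the predicted point, then step with the averaged slope, z = z - dt * (v(z, t) + v(z_pred, t - dt)) / 2. It costs two model evaluations per step but is second-order accurate; the 5-step default on plain Euler `sample` presumably leans instead on rectified flow's near-straight trajectories. A toy comparison on an ODE with a known solution:

import math

# dy/dt = y with y(0) = 1, integrated to t = 1; the exact answer is e
def euler(steps):
    y, dt = 1.0, 1.0 / steps
    for _ in range(steps):
        y += dt * y
    return y

def heun(steps):
    y, dt = 1.0, 1.0 / steps
    for _ in range(steps):
        y_pred = y + dt * y           # Euler trial step
        y += dt * 0.5 * (y + y_pred)  # average the slopes at both ends
    return y

print(abs(euler(10) - math.e))  # ~0.125
print(abs(heun(10) - math.e))   # ~0.004, far closer at the same step count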