This commit is contained in:
neulus
2025-09-29 22:51:54 +09:00
parent 02ac62fb1d
commit 12a165e461
38 changed files with 436 additions and 30 deletions

13
src/benchmark/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from torchmetrics.image import (
LearnedPerceptualImagePatchSimilarity,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
psnr = PeakSignalNoiseRatio(255.0, reduction="none")
ssim = StructuralSimilarityIndexMeasure(reduction="none")
lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", reduction="none")
def benchmark(image1, image2):
return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)

66
src/dataset/cuhk_cr1.py Normal file
View File

@@ -0,0 +1,66 @@
from pathlib import Path
from datasets import Dataset, DatasetDict, Image
from src.dataset.preprocess import make_transform
transform = make_transform(512)
def get_dataset() -> tuple[Dataset, Dataset]:
data_dir = Path("/data2/C-CUHK/CUHK-CR1")
train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
train_no_cloud = sorted((data_dir / "train/label").glob("*.png"))
test_cloud = sorted((data_dir / "test/cloud").glob("*.png"))
test_no_cloud = sorted((data_dir / "test/label").glob("*.png"))
dataset = DatasetDict(
{
"train": Dataset.from_dict(
{
"cloud": [str(p) for p in train_cloud],
"label": [str(p) for p in train_no_cloud],
}
)
.cast_column("cloud", Image())
.cast_column("label", Image()),
"test": Dataset.from_dict(
{
"cloud": [str(p) for p in test_cloud],
"label": [str(p) for p in test_no_cloud],
}
)
.cast_column("cloud", Image())
.cast_column("label", Image()),
}
)
train_dataset = dataset["train"]
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
batch_size=32,
remove_columns=train_dataset.column_names,
)
train_dataset.set_format(type="torch", columns=["x0", "x1"])
test_dataset = dataset["test"]
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
batch_size=32,
remove_columns=test_dataset.column_names,
)
test_dataset.set_format(type="torch", columns=["x0", "x1"])
return train_dataset, test_dataset
def preprocess_function(examples):
x0_list = []
x1_list = []
for x0_img, x1_img in zip(examples["cloud"], examples["label"]):
x0_transformed = transform(x0_img)
x1_transformed = transform(x1_img)
x0_list.append(x0_transformed)
x1_list.append(x1_transformed)
return {"x0": x0_list, "x1": x1_list}

View File

@@ -12,3 +12,9 @@ def make_transform(resize_size: int = 256):
std=(0.229, 0.224, 0.225),
)
return v2.Compose([to_tensor, resize, to_float, normalize])
def denormalize(tensor):
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
return tensor * std + mean