test
13 src/benchmark/__init__.py Normal file
@@ -0,0 +1,13 @@
from torchmetrics.image import (
    LearnedPerceptualImagePatchSimilarity,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

psnr = PeakSignalNoiseRatio(255.0, reduction="none")
ssim = StructuralSimilarityIndexMeasure(reduction="none")
lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", reduction="none")


def benchmark(image1, image2):
    return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
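Reviewer note, not part of the commit: the three metrics above are configured with different input expectations. PSNR with a data_range of 255 implies 0-255 tensors, while torchmetrics' LPIPS expects inputs in [-1, 1] (or [0, 1] when normalize=True), and recent torchmetrics releases typically accept only "mean" or "sum" as the LPIPS reduction. A minimal sketch of one way to score all three metrics on a common [0, 1] float batch follows; the data_range/normalize choices, tensor shapes, and the names pred/target are assumptions, not taken from this repository.

import torch
from torchmetrics.image import (
    LearnedPerceptualImagePatchSimilarity,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

# Assumed configuration: all three metrics scored on float batches in [0, 1].
psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True)

pred = torch.rand(4, 3, 256, 256)    # hypothetical restored batch (N, 3, H, W)
target = torch.rand(4, 3, 256, 256)  # hypothetical ground-truth batch
print(psnr(pred, target), ssim(pred, target), lpips(pred, target))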
66 src/dataset/cuhk_cr1.py Normal file
@@ -0,0 +1,66 @@
from pathlib import Path

from datasets import Dataset, DatasetDict, Image

from src.dataset.preprocess import make_transform

transform = make_transform(512)


def get_dataset() -> tuple[Dataset, Dataset]:
    data_dir = Path("/data2/C-CUHK/CUHK-CR1")

    train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
    train_no_cloud = sorted((data_dir / "train/label").glob("*.png"))
    test_cloud = sorted((data_dir / "test/cloud").glob("*.png"))
    test_no_cloud = sorted((data_dir / "test/label").glob("*.png"))

    dataset = DatasetDict(
        {
            "train": Dataset.from_dict(
                {
                    "cloud": [str(p) for p in train_cloud],
                    "label": [str(p) for p in train_no_cloud],
                }
            )
            .cast_column("cloud", Image())
            .cast_column("label", Image()),
            "test": Dataset.from_dict(
                {
                    "cloud": [str(p) for p in test_cloud],
                    "label": [str(p) for p in test_no_cloud],
                }
            )
            .cast_column("cloud", Image())
            .cast_column("label", Image()),
        }
    )
    train_dataset = dataset["train"]
    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        batch_size=32,
        remove_columns=train_dataset.column_names,
    )
    train_dataset.set_format(type="torch", columns=["x0", "x1"])
    test_dataset = dataset["test"]
    test_dataset = test_dataset.map(
        preprocess_function,
        batched=True,
        batch_size=32,
        remove_columns=test_dataset.column_names,
    )
    test_dataset.set_format(type="torch", columns=["x0", "x1"])

    return train_dataset, test_dataset


def preprocess_function(examples):
    x0_list = []
    x1_list = []
    for x0_img, x1_img in zip(examples["cloud"], examples["label"]):
        x0_transformed = transform(x0_img)
        x1_transformed = transform(x1_img)
        x0_list.append(x0_transformed)
        x1_list.append(x1_transformed)
    return {"x0": x0_list, "x1": x1_list}
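For context, not part of the commit: since both splits are formatted with set_format(type="torch"), they can be wrapped directly in a torch DataLoader. A sketch of how the returned splits might be consumed; the batch size, shuffling, and variable names below are assumptions.

from torch.utils.data import DataLoader

from src.dataset.cuhk_cr1 import get_dataset

train_ds, test_ds = get_dataset()
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)  # assumed settings

for batch in train_loader:
    # "x0" holds the cloudy input and "x1" the cloud-free label, each presumably
    # a (batch, 3, 512, 512) float tensor after make_transform(512).
    x0, x1 = batch["x0"], batch["x1"]
    break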
src/dataset/preprocess.py
@@ -12,3 +12,9 @@ def make_transform(resize_size: int = 256):
        std=(0.229, 0.224, 0.225),
    )
    return v2.Compose([to_tensor, resize, to_float, normalize])


def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
    return tensor * std + mean
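For context, not part of the commit: denormalize inverts the ImageNet-statistics normalization applied in make_transform (this assumes torch is already imported at the top of preprocess.py, which the hunk does not show). A small self-contained round-trip check with made-up tensors:

import torch

from src.dataset.preprocess import denormalize

x = torch.rand(3, 512, 512)                          # pretend image in [0, 1]
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
normalized = (x - mean) / std                        # what the transform would produce
restored = denormalize(normalized)                   # undo the normalization
assert torch.allclose(restored, x, atol=1e-6)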