This commit is contained in:
neulus
2025-09-30 10:27:41 +09:00
parent 0ccf1ff42d
commit 8966cafb8f
2 changed files with 46 additions and 19 deletions

View File

@@ -2,19 +2,19 @@ import torch
from torchvision.transforms import v2
# note that its LVD-1689M (not SAT)
# note that its SAT
def make_transform(resize_size: int = 256):
to_tensor = v2.ToImage()
resize = v2.Resize((resize_size, resize_size), antialias=True)
to_float = v2.ToDtype(torch.float32, scale=True)
normalize = v2.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
mean=(0.430, 0.411, 0.296),
std=(0.213, 0.156, 0.143),
)
return v2.Compose([to_tensor, resize, to_float, normalize])
def denormalize(tensor):
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
return tensor * std + mean