27 lines
927 B
Python
27 lines
927 B
Python
import torch
|
|
from torchvision.transforms import v2
|
|
|
|
|
|
# NOTE: these normalization stats are for the SAT dataset (not the usual ImageNet values) — verify against the training data
|
|
def make_transform(resize_size: int = 256):
    """Build the image preprocessing pipeline.

    The pipeline converts a PIL image / ndarray to a tensor image, resizes it
    to a square of side ``resize_size``, scales pixel values to float32 in
    [0, 1], and normalizes with the dataset's channel statistics (same
    mean/std used by :func:`normalize` / :func:`denormalize` below).

    Args:
        resize_size: Target side length of the square output image.

    Returns:
        A ``v2.Compose`` callable applying the four steps in order.
    """
    return v2.Compose(
        [
            v2.ToImage(),
            v2.Resize((resize_size, resize_size), antialias=True),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(
                mean=(0.430, 0.411, 0.296),
                std=(0.213, 0.156, 0.143),
            ),
        ]
    )
|
|
|
|
|
|
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Invert the channel normalization applied by :func:`make_transform`.

    Args:
        tensor: Normalized image tensor with 3 channels leading the last two
            spatial dims (broadcast as ``(3, 1, 1)``).

    Returns:
        Tensor mapped back to the original value range via ``x * std + mean``.
    """
    # Channel statistics, shaped for broadcasting over (C, H, W).
    channel_mean = torch.tensor([0.430, 0.411, 0.296], device=tensor.device).view(3, 1, 1)
    channel_std = torch.tensor([0.213, 0.156, 0.143], device=tensor.device).view(3, 1, 1)
    return tensor.mul(channel_std).add(channel_mean)
|
|
|
|
|
|
def normalize(tensor: torch.Tensor) -> torch.Tensor:
    """Apply the dataset channel normalization (inverse of :func:`denormalize`).

    Args:
        tensor: Image tensor with 3 channels leading the last two spatial
            dims (broadcast as ``(3, 1, 1)``).

    Returns:
        Tensor standardized via ``(x - mean) / std``.
    """
    # Channel statistics, shaped for broadcasting over (C, H, W).
    channel_mean = torch.tensor([0.430, 0.411, 0.296], device=tensor.device).view(3, 1, 1)
    channel_std = torch.tensor([0.213, 0.156, 0.143], device=tensor.device).view(3, 1, 1)
    return tensor.sub(channel_mean).div(channel_std)
|