import torch from torchvision.transforms import v2 # note that its SAT def make_transform(resize_size: int = 256): to_tensor = v2.ToImage() resize = v2.Resize((resize_size, resize_size), antialias=True) to_float = v2.ToDtype(torch.float32, scale=True) normalize = v2.Normalize( mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), ) return v2.Compose([to_tensor, resize, to_float, normalize]) def denormalize(tensor: torch.Tensor) -> torch.Tensor: mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device) std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device) return tensor * std + mean def normalize(tensor: torch.Tensor) -> torch.Tensor: mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device) std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device) return (tensor - mean) / std