training code
This commit is contained in:
14
src/dataset/preprocess.py
Normal file
14
src/dataset/preprocess.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
|
||||
# note that its LVD-1689M (not SAT)
|
||||
def make_transform(resize_size: int = 256):
|
||||
to_tensor = v2.ToImage()
|
||||
resize = v2.Resize((resize_size, resize_size), antialias=True)
|
||||
to_float = v2.ToDtype(torch.float32, scale=True)
|
||||
normalize = v2.Normalize(
|
||||
mean=(0.485, 0.456, 0.406),
|
||||
std=(0.229, 0.224, 0.225),
|
||||
)
|
||||
return v2.Compose([to_tensor, resize, to_float, normalize])
|
||||
Reference in New Issue
Block a user