training code

This commit is contained in:
neulus
2025-09-29 16:13:18 +09:00
parent d3793890a7
commit 1b445444cf
7 changed files with 406 additions and 22 deletions

0
src/dataset/__init__.py Normal file
View File

14
src/dataset/preprocess.py Normal file
View File

@@ -0,0 +1,14 @@
import torch
from torchvision.transforms import v2
# note that its LVD-1689M (not SAT)
def make_transform(resize_size: int = 256):
to_tensor = v2.ToImage()
resize = v2.Resize((resize_size, resize_size), antialias=True)
to_float = v2.ToDtype(torch.float32, scale=True)
normalize = v2.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
)
return v2.Compose([to_tensor, resize, to_float, normalize])

41
src/rf.py Normal file
View File

@@ -0,0 +1,41 @@
import torch
class RF:
def __init__(self, model, ln=True):
self.model = model
self.ln = ln
def forward(self, x0, x1):
b = x0.size(0)
if self.ln:
nt = torch.randn((b,)).to(x0.device)
t = torch.sigmoid(nt)
else:
t = torch.rand((b,)).to(x0.device)
texp = t.view([b, *([1] * len(x0.shape[1:]))])
zt = (1 - texp) * x0 + texp * x1
vtheta = self.model(zt, t)
batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
dim=list(range(1, len(x0.shape)))
)
tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
return batchwise_mse.mean(), ttloss
@torch.no_grad()
def sample(self, x0, sample_steps=50):
b = x0.size(0)
dt = 1.0 / sample_steps
dt = torch.tensor([dt] * b).to(x0.device).view([b, *([1] * len(x0.shape[1:]))])
images = [x0]
z = x0
for i in range(sample_steps):
t = i / sample_steps
t = torch.tensor([t] * b).to(z.device)
vc = self.model(z, t)
z = z + dt * vc
images.append(z)
return images