training code
This commit is contained in:
0
src/dataset/__init__.py
Normal file
0
src/dataset/__init__.py
Normal file
14
src/dataset/preprocess.py
Normal file
14
src/dataset/preprocess.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
|
||||
# note that its LVD-1689M (not SAT)
|
||||
def make_transform(resize_size: int = 256):
|
||||
to_tensor = v2.ToImage()
|
||||
resize = v2.Resize((resize_size, resize_size), antialias=True)
|
||||
to_float = v2.ToDtype(torch.float32, scale=True)
|
||||
normalize = v2.Normalize(
|
||||
mean=(0.485, 0.456, 0.406),
|
||||
std=(0.229, 0.224, 0.225),
|
||||
)
|
||||
return v2.Compose([to_tensor, resize, to_float, normalize])
|
||||
41
src/rf.py
Normal file
41
src/rf.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
|
||||
|
||||
class RF:
|
||||
def __init__(self, model, ln=True):
|
||||
self.model = model
|
||||
self.ln = ln
|
||||
|
||||
def forward(self, x0, x1):
|
||||
b = x0.size(0)
|
||||
if self.ln:
|
||||
nt = torch.randn((b,)).to(x0.device)
|
||||
t = torch.sigmoid(nt)
|
||||
else:
|
||||
t = torch.rand((b,)).to(x0.device)
|
||||
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
||||
zt = (1 - texp) * x0 + texp * x1
|
||||
vtheta = self.model(zt, t)
|
||||
batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
|
||||
dim=list(range(1, len(x0.shape)))
|
||||
)
|
||||
tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
|
||||
ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
|
||||
return batchwise_mse.mean(), ttloss
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, x0, sample_steps=50):
|
||||
b = x0.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
dt = torch.tensor([dt] * b).to(x0.device).view([b, *([1] * len(x0.shape[1:]))])
|
||||
images = [x0]
|
||||
z = x0
|
||||
for i in range(sample_steps):
|
||||
t = i / sample_steps
|
||||
t = torch.tensor([t] * b).to(z.device)
|
||||
|
||||
vc = self.model(z, t)
|
||||
|
||||
z = z + dt * vc
|
||||
images.append(z)
|
||||
return images
|
||||
Reference in New Issue
Block a user