add rf
This commit is contained in:
43
rf.py
Normal file
43
rf.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
|
||||
|
||||
class RF:
|
||||
def __init__(self, model, ln=True):
|
||||
self.model = model
|
||||
self.ln = ln
|
||||
|
||||
def forward(self, x0, x1, cond):
|
||||
b = x0.size(0)
|
||||
if self.ln:
|
||||
nt = torch.randn((b,)).to(x0.device)
|
||||
t = torch.sigmoid(nt)
|
||||
else:
|
||||
t = torch.rand((b,)).to(x0.device)
|
||||
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
||||
zt = (1 - texp) * x0 + texp * x1
|
||||
vtheta = self.model(zt, t, cond)
|
||||
batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
|
||||
dim=list(range(1, len(x0.shape)))
|
||||
)
|
||||
tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
|
||||
ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
|
||||
return batchwise_mse.mean(), ttloss
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0):
|
||||
b = z.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))])
|
||||
images = [z]
|
||||
for i in range(sample_steps, 0, -1):
|
||||
t = i / sample_steps
|
||||
t = torch.tensor([t] * b).to(z.device)
|
||||
|
||||
vc = self.model(z, t, cond)
|
||||
if null_cond is not None:
|
||||
vu = self.model(z, t, null_cond)
|
||||
vc = vu + cfg * (vc - vu)
|
||||
|
||||
z = z - dt * vc
|
||||
images.append(z)
|
||||
return images
|
||||
Reference in New Issue
Block a user