import torch


class RF:
    """Rectified-flow style wrapper around a velocity-predicting model.

    ``forward`` builds the straight-line interpolation
    ``zt = (1 - t) * x0 + t * x1`` and regresses the model output onto
    ``x1 - x0``. ``sample`` integrates the predicted velocity with fixed-size
    Euler steps from ``t = 1`` down towards ``t = 0``.

    The wrapped ``model`` is called as ``model(z, t, cond)``; its output is
    combined elementwise with tensors shaped like ``z`` (presumably it returns
    a tensor of the same shape as ``z`` -- confirm against the model).
    """

    def __init__(self, model, ln=True):
        """Store the model and the timestep-sampling mode.

        Args:
            model: Callable invoked as ``model(z, t, cond)``.
            ln: If true, training timesteps are ``sigmoid`` of standard-normal
                draws; otherwise they are drawn uniformly from ``[0, 1)``.
        """
        self.model = model
        self.ln = ln

    def forward(self, x0, x1, cond):
        """Compute the velocity-matching loss for one batch.

        Args:
            x0: Tensor whose leading dimension is the batch; one endpoint of
                the interpolation (weight ``1 - t``).
            x1: Tensor combined elementwise with ``x0``; the other endpoint
                (weight ``t``).
            cond: Passed through unchanged to the model.

        Returns:
            A pair ``(loss, ttloss)`` where ``loss`` is the mean over the batch
            of the per-sample mean squared error, and ``ttloss`` is a list with
            one ``(t_i, loss_i)`` tuple per sample: ``t_i`` is the 0-d tensor
            element of the sampled timesteps and ``loss_i`` a Python float.
        """
        b = x0.size(0)
        # NOTE: the draw is made with the default (CPU) generator and then
        # moved, so seeded runs give the same timesteps on every device.
        if self.ln:
            t = torch.sigmoid(torch.randn((b,)).to(x0.device))
        else:
            t = torch.rand((b,)).to(x0.device)

        # Reshape t to (b, 1, ..., 1) so it broadcasts over non-batch dims.
        texp = t.view([b, *([1] * len(x0.shape[1:]))])
        zt = (1 - texp) * x0 + texp * x1
        vtheta = self.model(zt, t, cond)

        # Per-sample MSE: average over every dimension except the batch one.
        batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
            dim=list(range(1, len(x0.shape)))
        )

        # Detached per-sample losses paired with their timesteps (for logging
        # or binning by the caller); this does not affect the returned loss.
        tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
        ttloss = list(zip(t, tlist))
        return batchwise_mse.mean(), ttloss

    @torch.no_grad()
    def sample(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0):
        """Integrate the model velocity with Euler steps, starting from ``z``.

        Args:
            z: Starting tensor; its leading dimension is the batch.
            cond: Conditioning passed to the model.
            null_cond: Optional second conditioning. When given, the model is
                evaluated a second time with it and the two predictions are
                mixed as ``vu + cfg * (vc - vu)``.
            sample_steps: Number of Euler steps; the step size is
                ``1 / sample_steps``.
            cfg: Mixing scale, used only when ``null_cond`` is not ``None``.

        Returns:
            A list of ``sample_steps + 1`` tensors: the initial ``z`` followed
            by the state after each step.
        """
        b = z.size(0)
        # Step size as a (b, 1, ..., 1) tensor created once, directly on the
        # target device. It is kept as a default-dtype tensor (not a Python
        # scalar) so type promotion in the update below is unchanged.
        dt = torch.full(
            [b, *([1] * len(z.shape[1:]))],
            1.0 / sample_steps,
            device=z.device,
        )

        images = [z]
        for i in range(sample_steps, 0, -1):
            # Timesteps run sample_steps/sample_steps, ..., 1/sample_steps.
            t = torch.full((b,), i / sample_steps, device=z.device)

            vc = self.model(z, t, cond)
            if null_cond is not None:
                vu = self.model(z, t, null_cond)
                vc = vu + cfg * (vc - vu)

            z = z - dt * vc
            images.append(z)
        return images