improved rf
This commit is contained in:
75
src/rf.py
75
src/rf.py
@@ -1,41 +1,84 @@
|
||||
import math
|
||||
|
||||
import lpips
|
||||
import torch
|
||||
|
||||
from src.dataset.preprocess import denormalize
|
||||
|
||||
|
||||
def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
|
||||
"""Loss = sqrt(||x||₂² + c²) - c"""
|
||||
d = x.shape[1:].numel()
|
||||
c = c * (d**0.5)
|
||||
x = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
|
||||
return torch.sqrt(x**2 + c**2) - c
|
||||
|
||||
|
||||
class RF:
|
||||
def __init__(self, model, ln=True):
|
||||
def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
|
||||
self.model = model
|
||||
self.ln = ln
|
||||
self.ushaped = ushaped
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
def forward(self, x0, x1):
|
||||
self.lpips = (
|
||||
lpips.LPIPS(net="vgg").to(model.device)
|
||||
if loss_fn == "lpips_huber"
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x0, z1):
|
||||
# x0 is gt / z is noise
|
||||
b = x0.size(0)
|
||||
if self.ln:
|
||||
if self.ushaped:
|
||||
a = 4.0 # HYPERPARMS
|
||||
u = torch.rand((b,)).to(x0.device)
|
||||
t = torch.asinh((2 * u - 1) * math.sinh(a)) / a
|
||||
elif self.ln:
|
||||
nt = torch.randn((b,)).to(x0.device)
|
||||
t = torch.sigmoid(nt)
|
||||
else:
|
||||
t = torch.rand((b,)).to(x0.device)
|
||||
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
||||
zt = (1 - texp) * x0 + texp * x1
|
||||
zt = (1 - texp) * x0 + texp * z1
|
||||
|
||||
vtheta = self.model(zt, t)
|
||||
batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
|
||||
dim=list(range(1, len(x0.shape)))
|
||||
)
|
||||
tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
|
||||
|
||||
if self.loss_fn == "lpips_huber":
|
||||
# https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
|
||||
if not self.lpips:
|
||||
raise Exception
|
||||
|
||||
huber = torch.nn.functional.huber_loss(
|
||||
z1 - x0, vtheta, reduction="none"
|
||||
).mean(dim=list(range(1, len(x0.shape))))
|
||||
lpips = self.lpips(
|
||||
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
|
||||
)
|
||||
weight = t.view(-1)
|
||||
|
||||
loss = (1 - weight) * huber + lpips
|
||||
else:
|
||||
loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
||||
|
||||
tlist = loss.detach().cpu().reshape(-1).tolist()
|
||||
ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
|
||||
return batchwise_mse.mean(), ttloss
|
||||
|
||||
return loss.mean(), ttloss
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, x0, sample_steps=50):
|
||||
b = x0.size(0)
|
||||
def sample(self, z1, sample_steps=50):
|
||||
b = z1.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
dt = torch.tensor([dt] * b).to(x0.device).view([b, *([1] * len(x0.shape[1:]))])
|
||||
images = [x0]
|
||||
z = x0
|
||||
for i in range(sample_steps):
|
||||
dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
|
||||
images = [z1]
|
||||
z = z1
|
||||
for i in range(sample_steps, 0, -1):
|
||||
t = i / sample_steps
|
||||
t = torch.tensor([t] * b).to(z.device)
|
||||
|
||||
vc = self.model(z, t)
|
||||
|
||||
z = z + dt * vc
|
||||
z = z - dt * vc
|
||||
images.append(z)
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user