improved rf

This commit is contained in:
neulus
2025-10-01 18:44:26 +09:00
parent 49025c4d87
commit 29eb04d1a4
8 changed files with 150 additions and 27 deletions

View File

@@ -1,41 +1,84 @@
import math
import lpips
import torch
from src.dataset.preprocess import denormalize
def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
    """Per-sample pseudo-Huber loss: sqrt(||x||₂² + c²) - c.

    Every sample (index along dim 0) is flattened and its L2 norm taken.
    The constant ``c`` is scaled by sqrt(d), where d is the number of
    elements in one sample, before it enters the formula.

    Args:
        x: Tensor whose first dimension is the batch; must have at least
            two dimensions because of the ``flatten(1)`` call.
        c: Base smoothing constant (before the sqrt(d) scaling).

    Returns:
        Tensor of shape ``(batch,)`` with one loss value per sample.
    """
    elems_per_sample = x.shape[1:].numel()
    scaled_c = c * (elems_per_sample**0.5)
    sample_norm = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
    return torch.sqrt(sample_norm**2 + scaled_c**2) - scaled_c
class RF:
    """Rectified-flow wrapper around a velocity-prediction model.

    Training interpolates ``zt = (1 - t) * x0 + t * z1`` between data ``x0``
    and noise ``z1`` and regresses the model output towards ``z1 - x0``.
    Sampling integrates the predicted velocity from t = 1 down towards
    t = 0 with fixed Euler steps.

    Args:
        model: Callable invoked as ``model(zt, t)``; ``t`` has shape ``(b,)``.
        ln: When ``ushaped`` is false, draw ``t = sigmoid(N(0, 1))``
            instead of a uniform ``t``.
        ushaped: Draw ``t`` from a U-shaped density on [0, 1]; takes
            precedence over ``ln``.
        loss_fn: ``"lpips_huber"`` selects the Huber + LPIPS loss; any other
            value selects the plain per-sample MSE.
    """

    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
        self.model = model
        self.ln = ln
        self.ushaped = ushaped
        self.loss_fn = loss_fn
        # NOTE(review): assumes `model` exposes a `.device` attribute (a plain
        # torch.nn.Module does not) -- confirm against the model class.
        self.lpips = (
            lpips.LPIPS(net="vgg").to(model.device)
            if loss_fn == "lpips_huber"
            else None
        )

    def _sample_t(self, b, device):
        """Draw ``b`` training timesteps in [0, 1] on ``device``."""
        if self.ushaped:
            a = 4.0  # hyperparameter: larger values push more mass to the ends
            u = torch.rand((b,)).to(device)
            # Inverse-CDF sample of a density proportional to cosh(a * s) on
            # s in [-1, 1] ...
            s = torch.asinh((2 * u - 1) * math.sinh(a)) / a
            # ... which must be rescaled to [0, 1]: a negative t would
            # extrapolate the x0/z1 interpolation and make the (1 - t) loss
            # weight exceed 1. The clamp only guards float round-off.
            return ((s + 1) / 2).clamp(0.0, 1.0)
        if self.ln:
            return torch.sigmoid(torch.randn((b,)).to(device))
        return torch.rand((b,)).to(device)

    def forward(self, x0, z1):
        """Compute the training loss for one batch.

        Args:
            x0: Ground-truth batch.
            z1: Noise batch with the same shape as ``x0``.

        Returns:
            ``(loss, ttloss)``: the batch-mean loss and a list of
            ``(t_i, loss_i)`` pairs, one per sample.

        Raises:
            RuntimeError: ``loss_fn`` is ``"lpips_huber"`` but no LPIPS
                network is available.
        """
        # x0 is gt / z1 is noise
        b = x0.size(0)
        sample_dims = list(range(1, len(x0.shape)))
        t = self._sample_t(b, x0.device)
        texp = t.view([b, *([1] * len(x0.shape[1:]))])
        zt = (1 - texp) * x0 + texp * z1
        vtheta = self.model(zt, t)

        if self.loss_fn == "lpips_huber":
            # https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
            if self.lpips is None:
                raise RuntimeError(
                    "loss_fn='lpips_huber' requires the LPIPS network, "
                    "but self.lpips is None"
                )
            huber = torch.nn.functional.huber_loss(
                z1 - x0, vtheta, reduction="none"
            ).mean(dim=sample_dims)
            # With v ~= z1 - x0, zt - t * v is the model's estimate of x0.
            x0_pred = zt - texp * vtheta
            # presumably denormalize() maps to [0, 1], so `* 2 - 1` yields the
            # [-1, 1] range LPIPS expects -- confirm against preprocess.py.
            perceptual = self.lpips(
                denormalize(x0) * 2 - 1, denormalize(x0_pred) * 2 - 1
            )
            # LPIPS returns (b, 1, 1, 1); flatten to (b,) so the sum below
            # stays per-sample instead of broadcasting to (b, 1, 1, b).
            perceptual = perceptual.reshape(b)
            loss = (1 - t.view(-1)) * huber + perceptual
        else:
            loss = ((z1 - x0 - vtheta) ** 2).mean(dim=sample_dims)

        tlist = loss.detach().cpu().reshape(-1).tolist()
        ttloss = list(zip(t, tlist))
        return loss.mean(), ttloss

    @torch.no_grad()
    def sample(self, z1, sample_steps=50):
        """Integrate from noise towards data with fixed Euler steps.

        Args:
            z1: Starting noise batch.
            sample_steps: Number of Euler steps; t runs from 1 down to
                ``1 / sample_steps``.

        Returns:
            List of ``sample_steps + 1`` tensors: ``z1`` followed by the
            state after every step.
        """
        b = z1.size(0)
        dt = 1.0 / sample_steps
        dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
        images = [z1]
        z = z1
        for i in range(sample_steps, 0, -1):
            t = torch.tensor([i / sample_steps] * b).to(z.device)
            vc = self.model(z, t)
            # The velocity points from data to noise, so step against it.
            z = z - dt * vc
            images.append(z)
        return images