fix wrong average of psnr
This commit is contained in:
77
src/rf.py
77
src/rf.py
@@ -13,17 +13,13 @@ def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
|
||||
|
||||
|
||||
class RF:
|
||||
def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
|
||||
def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_mse_enhanced"):
|
||||
self.model = model
|
||||
self.ln = ln
|
||||
self.ushaped = ushaped
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
self.lpips = (
|
||||
lpips.LPIPS(net="vgg").to(model.device)
|
||||
if loss_fn == "lpips_huber"
|
||||
else None
|
||||
)
|
||||
self.lpips = lpips.LPIPS(net="vgg").to("cuda:1") if "lpips" in loss_fn else None
|
||||
|
||||
def forward(self, x0, z1):
|
||||
# x0 is gt / z is noise
|
||||
@@ -41,7 +37,7 @@ class RF:
|
||||
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
||||
zt = (1 - texp) * x0 + texp * z1
|
||||
|
||||
vtheta = self.model(zt, t)
|
||||
vtheta, residual = self.model(zt, t)
|
||||
|
||||
if self.loss_fn == "lpips_huber":
|
||||
# https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
|
||||
@@ -57,6 +53,41 @@ class RF:
|
||||
weight = t.view(-1)
|
||||
|
||||
loss = (1 - weight) * huber + lpips
|
||||
elif self.loss_fn == "lpips_mse":
|
||||
if not self.lpips:
|
||||
raise Exception
|
||||
|
||||
lpips = self.lpips(
|
||||
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
|
||||
)
|
||||
loss = ((z1 - x0 - vtheta) ** 2).mean(
|
||||
dim=list(range(1, len(x0.shape)))
|
||||
) + 2.0 * lpips
|
||||
elif self.loss_fn == "lpips_mse_enhanced":
|
||||
if not self.lpips:
|
||||
raise Exception
|
||||
|
||||
lpips = self.lpips(
|
||||
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
|
||||
)
|
||||
dino_loss = torch.stack(
|
||||
[
|
||||
(
|
||||
1
|
||||
- torch.nn.functional.cosine_similarity(
|
||||
v_residual, x0_residual, dim=-1
|
||||
)
|
||||
).mean(dim=-1)
|
||||
for v_residual, x0_residual in zip(
|
||||
residual, self.model.get_residual(x0, None, False)
|
||||
)
|
||||
]
|
||||
).mean(dim=0) * (2 - t.view(-1))
|
||||
loss = (
|
||||
((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
||||
+ 2.0 * lpips
|
||||
+ dino_loss
|
||||
)
|
||||
else:
|
||||
loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
||||
|
||||
@@ -66,7 +97,7 @@ class RF:
|
||||
return loss.mean(), ttloss
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, z1, sample_steps=50):
|
||||
def sample(self, z1, sample_steps=5):
|
||||
b = z1.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
|
||||
@@ -76,8 +107,36 @@ class RF:
|
||||
t = i / sample_steps
|
||||
t = torch.tensor([t] * b).to(z.device)
|
||||
|
||||
vc = self.model(z, t)
|
||||
vc, _ = self.model(z, t)
|
||||
|
||||
z = z - dt * vc
|
||||
images.append(z)
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(self, z1, sample_steps=50):
|
||||
b = z1.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
|
||||
images = [z1]
|
||||
z = z1
|
||||
|
||||
for i in range(sample_steps, 0, -1):
|
||||
t_current = i / sample_steps
|
||||
t_next = (i - 1) / sample_steps
|
||||
|
||||
t_current_tensor = torch.tensor([t_current] * b, device=z.device)
|
||||
|
||||
v_current, _ = self.model(z, t_current_tensor)
|
||||
z_pred = z - dt * v_current
|
||||
|
||||
t_next_tensor = torch.tensor([t_next] * b, device=z.device)
|
||||
v_next, _ = self.model(z_pred, t_next_tensor)
|
||||
|
||||
v_avg = 0.5 * (v_current + v_next)
|
||||
|
||||
z = z - dt * v_avg
|
||||
|
||||
images.append(z)
|
||||
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user