This commit is contained in:
neulus
2025-10-10 15:55:35 +09:00
parent 6bb6c09638
commit c47d91a349
10 changed files with 1381 additions and 112 deletions

src/rf.py (230 changed lines)

@@ -1,117 +1,175 @@
import lpips
import torch
import torch.optim as optim
from pytorch_msssim import ms_ssim
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchdiffeq import odeint
import wandb
from src.dataset.preprocess import denormalize
from src.gan import PatchDiscriminator, gan_disc_loss


def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
    """Loss = sqrt(||x||₂² + c²) - c, quadratic near zero and linear for large residuals."""
    d = x.shape[1:].numel()
    c = c * (d**0.5)  # scale c with the square root of the per-sample dimensionality
    x = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
    return torch.sqrt(x**2 + c**2) - c
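
Scaling c by √d keeps the quadratic-to-linear transition point comparable across input sizes; 0.00054 is the constant suggested in the improved consistency-model training recipe. A quick illustration of the two regimes (shapes below are arbitrary, for demonstration only):

# Illustrative check, not part of this commit: tiny residuals sit in the
# quadratic regime (loss near zero), large residuals in the linear regime
# (loss ≈ ||x||_2 - c').
small = torch.full((2, 3, 8, 8), 1e-6)
large = torch.full((2, 3, 8, 8), 1.0)
print(pseudo_huber_loss(small), pseudo_huber_loss(large))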


# LeCam regularization state (Tseng et al., CVPR 2021): EMA anchors of the
# discriminator logits, used below to penalize logits that drift too far apart.
lecam_loss_weight = 0.1
lecam_anchor_real_logits = 0.0
lecam_anchor_fake_logits = 0.0
lecam_beta = 0.9
use_lecam = True


class RF:
    def __init__(self, model, fm="otcfm", loss="mse"):
        self.model = model
        self.loss = loss
        # Select the conditional flow matcher; sigma=0 gives deterministic
        # interpolation paths.
        sigma = 0.0
        if fm == "otcfm":
            self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
        elif fm == "icfm":
            self.fm = ConditionalFlowMatcher(sigma=sigma)
        elif fm == "fm":
            self.fm = TargetConditionalFlowMatcher(sigma=sigma)
        elif fm == "si":
            self.fm = VariancePreservingConditionalFlowMatcher(sigma=sigma)
        else:
            raise NotImplementedError(
                f"Unknown flow matcher {fm}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
            )
        # Perceptual metrics and a patch discriminator for the GAN-augmented losses.
        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1")
        self.lpips2 = lpips.LPIPS(net="alex").to("cuda:1")
        discriminator = PatchDiscriminator().to("cuda:1")
        discriminator.requires_grad_(True)
        self.discriminator = discriminator
        self.optimizer_D = optim.AdamW(
            discriminator.parameters(),
            lr=2e-4,
            weight_decay=1e-3,
            betas=(0.9, 0.95),
        )
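
As a usage sketch (the model class below is illustrative, not from this repo): RF expects a model whose forward(x, t) returns a (velocity, residuals) tuple, since both forward and sample unpack two values.

import torch.nn as nn

class TinyVelocityNet(nn.Module):
    # Stand-in model for illustration only: RF unpacks `vt, _ = self.model(xt, t)`,
    # so the forward pass must return a (velocity, residuals) pair.
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.net(x), []

rf = RF(TinyVelocityNet().to("cuda:1"), fm="otcfm", loss="mse")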

    def gan_loss(self, real, fake):
        global lecam_anchor_real_logits, lecam_anchor_fake_logits
        real_preds = self.discriminator(real)
        fake_preds = self.discriminator(fake.detach())
        d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss(
            real_preds, fake_preds, "hinge"
        )
        # Update the EMA anchors of the real/fake logits.
        lecam_anchor_real_logits = (
            lecam_beta * lecam_anchor_real_logits + (1 - lecam_beta) * avg_real_logits
        )
        lecam_anchor_fake_logits = (
            lecam_beta * lecam_anchor_fake_logits + (1 - lecam_beta) * avg_fake_logits
        )
        total_d_loss = d_loss.mean()
        d_loss_item = total_d_loss.item()
        if use_lecam:
            # Pull the real logits toward the fake anchor and the fake logits
            # toward the real anchor, which bounds the discriminator's outputs.
            lecam_loss = (real_preds - lecam_anchor_fake_logits).pow(2).mean() + (
                fake_preds - lecam_anchor_real_logits
            ).pow(2).mean()
            lecam_loss_item = lecam_loss.item()
            total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight
            wandb.log(
                {
                    "gan/lecam_loss": lecam_loss_item,
                    "gan/lecam_anchor_real_logits": lecam_anchor_real_logits,
                    "gan/lecam_anchor_fake_logits": lecam_anchor_fake_logits,
                }
            )
        wandb.log(
            {
                "gan/discriminator_loss": d_loss_item,
                "gan/discriminator_accuracy": disc_acc,
            }
        )
        self.optimizer_D.zero_grad()
        total_d_loss.backward(retain_graph=True)
        self.optimizer_D.step()
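
src/gan.py is not part of this hunk; judging from the call site, gan_disc_loss with the "hinge" variant returns the per-element loss, the average real and fake logits, and the discriminator accuracy. A hypothetical implementation consistent with that contract:

# Hypothetical reconstruction of src/gan.gan_disc_loss for reference only;
# the signature and return values are inferred from the call site above.
def gan_disc_loss(real_preds, fake_preds, mode="hinge"):
    if mode != "hinge":
        raise NotImplementedError(mode)
    # Hinge loss: push real logits above +1 and fake logits below -1.
    loss = torch.relu(1 - real_preds) + torch.relu(1 + fake_preds)
    avg_real = real_preds.mean().item()
    avg_fake = fake_preds.mean().item()
    acc = ((real_preds > 0).float().mean() + (fake_preds < 0).float().mean()) / 2
    return loss, avg_real, avg_fake, acc.item()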

    def forward(self, gt, cloud):
        # t ~ U[0,1]; xt is the interpolation between cloud (x0) and gt (x1);
        # ut is the conditional target velocity from the flow matcher.
        t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt)  # type: ignore
        vt, _ = self.model(xt, t)
        # One-step estimate of the clean image implied by the current velocity:
        # x1_hat = xt + (1 - t) * vt under the linear interpolation path.
        x1_hat = xt + (1 - t[:, None, None, None]) * vt
        if self.loss == "mse":
            loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
            loss_list = {"train/mse": loss.mean().item()}
        elif self.loss == "lpips_mse":
            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
            lpips_loss = self.lpips(
                denormalize(gt) * 2 - 1,
                denormalize(x1_hat) * 2 - 1,
            )
            loss_list = {
                "train/mse": mse.mean().item(),
                "train/lpips": lpips_loss.mean().item(),
            }
            loss = mse + lpips_loss * 2.0
        elif self.loss == "gan_lpips_mse":
            # Take a discriminator step first, then score the generator output.
            self.gan_loss(denormalize(gt), denormalize(x1_hat))
            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
            lpips_loss = self.lpips(
                denormalize(gt) * 2 - 1,
                denormalize(x1_hat) * 2 - 1,
            )
            alexlpips = self.lpips2(
                denormalize(gt) * 2 - 1,
                denormalize(x1_hat) * 2 - 1,
            )
            gan = -self.discriminator(denormalize(x1_hat)).mean(-1) * 0.01
            ssim = 1 - ms_ssim(
                denormalize(gt),
                denormalize(x1_hat),
                data_range=1.0,
                size_average=False,
            )
            loss_list = {
                "train/mse": mse.mean().item(),
                "train/lpips": lpips_loss.mean().item(),
                "train/alexlpips": alexlpips.mean().item(),
                "train/gan": gan.mean().item(),
                "train/ssim": ssim.mean().item(),
            }
            loss = mse + lpips_loss * 4.0 + gan + alexlpips + ssim
        else:
            raise NotImplementedError(f"Unknown loss {self.loss}")
        tlist = loss.detach().cpu().reshape(-1).tolist()
        ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
        return loss.mean(), ttloss, loss_list
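
forward returns a scalar loss, per-sample (t, loss) pairs, and a dict of logged scalars, so a training step looks roughly like this (the optimizer and dataloader are placeholders, not code from this commit):

# Illustrative training step; `optimizer` and the (gt, cloud) batch source
# are placeholders.
for gt, cloud in dataloader:
    gt, cloud = gt.to("cuda:1"), cloud.to("cuda:1")
    loss, ttloss, logs = rf.forward(gt, cloud)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    wandb.log(logs | {"train/loss": loss.item()})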
    @torch.no_grad()
    def sample(self, cloud, tol=1e-5, integration="dopri5") -> list[torch.Tensor]:
        # Integrate dx/dt = v_theta(x, t) from t=0 (input) to t=1 with an
        # adaptive ODE solver; t_span has two points, so only the endpoints
        # of the trajectory are stored.
        t_span = torch.linspace(0, 1, 2, device=cloud.device)
        traj = odeint(
            lambda t, x: self.model(x, t)[0],
            cloud,
            t_span,
            rtol=tol,
            atol=tol,
            method=integration,
        )
        return [traj[i] for i in range(traj.shape[0])]  # type: ignore
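
The new sampler delegates integration to torchdiffeq, so the step schedule is chosen by the solver rather than a fixed Euler loop. A usage sketch (`cloud_batch` is a placeholder input):

# Illustrative only: `cloud_batch` stands in for a batch of inputs.
imgs = rf.sample(cloud_batch, tol=1e-5, integration="dopri5")
restored = imgs[-1]  # imgs[0] is the state at t=0, imgs[-1] the prediction
# Any torchdiffeq method name can be passed as `integration`, e.g. the
# fixed-step "euler" or "rk4" when a constant cost per sample is preferred.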
    @torch.no_grad()
    def sample_heun(self, z1, sample_steps=50):