Files
cloud-removal/src/rf.py
2025-10-10 15:55:35 +09:00

201 lines
6.7 KiB
Python

import lpips
import torch
import torch.optim as optim
from pytorch_msssim import ms_ssim
from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
TargetConditionalFlowMatcher,
VariancePreservingConditionalFlowMatcher,
)
from torchdiffeq import odeint
import wandb
from src.dataset.preprocess import denormalize
from src.gan import PatchDiscriminator, gan_disc_loss
# Module-level LeCam regularizer settings. RF.gan_loss reads these and rebinds
# the two anchors through a `global` statement, so the state is shared by every
# RF instance living in the same process.

# Factor the LeCam term is multiplied by before being added to the
# discriminator loss.
lecam_loss_weight = 0.1
# Exponential moving averages of the avg_real_logits / avg_fake_logits values
# returned by gan_disc_loss; both start at 0.0 and are updated on every
# RF.gan_loss call.
lecam_anchor_real_logits = 0.0
lecam_anchor_fake_logits = 0.0
# EMA decay: new_anchor = beta * old_anchor + (1 - beta) * current_average.
lecam_beta = 0.9
# When True, RF.gan_loss adds the LeCam term to the discriminator loss and
# logs it to wandb.
use_lecam = True
class RF:
    """Flow-matching wrapper around a velocity-prediction model.

    Pairs ``model`` with a torchcfm conditional flow matcher for training
    (``forward``) and offers two samplers (``sample`` and ``sample_heun``).
    It also owns the perceptual (LPIPS) networks and the patch discriminator
    with its optimizer, which the composite training losses use.

    ``model`` is called as ``model(x, t)`` and must return a tuple whose first
    element is the predicted velocity; the second element is ignored here.
    """

    def __init__(self, model, fm="otcfm", loss="mse", device="cuda:1"):
        """Build the flow matcher and the auxiliary loss networks.

        Args:
            model: Callable ``model(x, t) -> (velocity, extra)``.
            fm: Flow matcher variant, one of ``"otcfm"``, ``"icfm"``,
                ``"fm"`` or ``"si"``.
            loss: Loss recipe used by ``forward``: ``"mse"``,
                ``"lpips_mse"`` or ``"gan_lpips_mse"``.
            device: Device the LPIPS networks and the discriminator are
                moved to. Defaults to ``"cuda:1"``, the value that was
                previously hard-coded.

        Raises:
            NotImplementedError: If ``fm`` is not a supported variant.
        """
        self.model = model
        self.loss = loss
        self.device = device

        sigma = 0.0
        if fm == "otcfm":
            self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
        elif fm == "icfm":
            self.fm = ConditionalFlowMatcher(sigma=sigma)
        elif fm == "fm":
            self.fm = TargetConditionalFlowMatcher(sigma=sigma)
        elif fm == "si":
            self.fm = VariancePreservingConditionalFlowMatcher(sigma=sigma)
        else:
            raise NotImplementedError(
                f"Unknown model {fm}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
            )

        # These are always created, whichever loss recipe is selected, so the
        # attributes exist for any caller that reaches for them.
        self.lpips = lpips.LPIPS(net="vgg").to(device)
        self.lpips2 = lpips.LPIPS(net="alex").to(device)

        discriminator = PatchDiscriminator().to(device)
        discriminator.requires_grad_(True)
        self.discriminator = discriminator
        self.optimizer_D = optim.AdamW(
            discriminator.parameters(),
            lr=2e-4,
            weight_decay=1e-3,
            betas=(0.9, 0.95),
        )

    def gan_loss(self, real, fake):
        """Run one optimizer step on the discriminator.

        Computes the hinge discriminator loss on ``real`` and on a detached
        copy of ``fake``, optionally adds the LeCam regularizer, logs the
        metrics to wandb and steps ``self.optimizer_D``. Nothing is returned.

        The LeCam anchors are module-level globals, so they are shared by all
        ``RF`` instances in the process.

        Args:
            real: Batch of reference images fed to the discriminator.
            fake: Batch of generated images; detached before use so no
                gradient reaches the generator from this step.
        """
        global lecam_beta, lecam_anchor_real_logits, lecam_anchor_fake_logits, use_lecam

        real_preds = self.discriminator(real)
        fake_preds = self.discriminator(fake.detach())
        d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss(
            real_preds, fake_preds, "hinge"
        )

        # Exponential moving averages of the average real / fake logits.
        lecam_anchor_real_logits = (
            lecam_beta * lecam_anchor_real_logits + (1 - lecam_beta) * avg_real_logits
        )
        lecam_anchor_fake_logits = (
            lecam_beta * lecam_anchor_fake_logits + (1 - lecam_beta) * avg_fake_logits
        )

        total_d_loss = d_loss.mean()
        d_loss_item = total_d_loss.item()

        if use_lecam:
            # Penalize the real logits towards the fake anchor and the fake
            # logits towards the real anchor.
            lecam_loss = (real_preds - lecam_anchor_fake_logits).pow(2).mean() + (
                fake_preds - lecam_anchor_real_logits
            ).pow(2).mean()
            lecam_loss_item = lecam_loss.item()
            total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight
            wandb.log(
                {
                    "gan/lecam_loss": lecam_loss_item,
                    "gan/lecam_anchor_real_logits": lecam_anchor_real_logits,
                    "gan/lecam_anchor_fake_logits": lecam_anchor_fake_logits,
                }
            )

        wandb.log(
            {
                "gan/discriminator_loss": d_loss_item,
                "gan/discriminator_accuracy": disc_acc,
            }
        )

        self.optimizer_D.zero_grad()
        # NOTE(review): retain_graph=True is kept as in the original. It is only
        # required if gan_disc_loss returns graph-attached tensors for the
        # average logits (the anchors would then carry old graphs) - confirm
        # against src.gan before removing it.
        total_d_loss.backward(retain_graph=True)
        self.optimizer_D.step()

    def _denormalized_pair(self, gt, xt, t, vt):
        """Return ``(denormalize(gt), denormalize(x1_estimate))``.

        ``x1_estimate`` extrapolates ``xt`` along the predicted velocity to
        ``t = 1``: ``xt + (1 - t) * vt``. ``t`` is broadcast over three
        trailing dimensions, so 4-D inputs are expected.
        """
        x1_estimate = xt + (1 - t[:, None, None, None]) * vt
        return denormalize(gt), denormalize(x1_estimate)

    def forward(self, gt, cloud):
        """Compute the training loss for one batch.

        Samples ``(t, xt, ut)`` from the flow matcher, called with ``cloud``
        as its first argument and ``gt`` as its second, predicts the velocity
        at ``(xt, t)`` and evaluates the loss recipe selected by
        ``self.loss``. With ``"gan_lpips_mse"`` the discriminator is updated
        first (see ``gan_loss``), then used for the generator term.

        Args:
            gt: Target batch.
            cloud: Source batch the flow starts from.

        Returns:
            A tuple ``(loss, ttloss, loss_list)``: the scalar mean loss, a
            list of ``(t_i, loss_i)`` pairs built from ``t`` and the flattened
            loss values, and a dict of scalar metrics keyed ``"train/..."``.

        Raises:
            NotImplementedError: If ``self.loss`` is not a supported recipe.
        """
        t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt)  # type: ignore
        vt, _ = self.model(xt, t)

        # Velocity regression error, averaged over every non-batch dimension.
        mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))

        if self.loss == "mse":
            loss = mse
            loss_list = {"train/mse": loss.mean().item()}
        elif self.loss == "lpips_mse":
            # The images are denormalized once and reused; LPIPS is fed
            # ``img * 2 - 1``.
            gt_img, pred_img = self._denormalized_pair(gt, xt, t, vt)
            lpips_vgg = self.lpips(gt_img * 2 - 1, pred_img * 2 - 1)
            loss_list = {
                "train/mse": mse.mean().item(),
                "train/lpips": lpips_vgg.mean().item(),
            }
            loss = mse + lpips_vgg * 2.0
        elif self.loss == "gan_lpips_mse":
            gt_img, pred_img = self._denormalized_pair(gt, xt, t, vt)

            # Discriminator step first; the generator term below therefore
            # sees the freshly updated discriminator.
            self.gan_loss(gt_img, pred_img)

            gt_scaled = gt_img * 2 - 1
            pred_scaled = pred_img * 2 - 1
            lpips_vgg = self.lpips(gt_scaled, pred_scaled)
            lpips_alex = self.lpips2(gt_scaled, pred_scaled)
            gan = -self.discriminator(pred_img).mean(-1) * 0.01
            ssim = 1 - ms_ssim(
                gt_img,
                pred_img,
                data_range=1.0,
                size_average=False,
            )
            loss_list = {
                "train/mse": mse.mean().item(),
                "train/lpips": lpips_vgg.mean().item(),
                "train/alexlpips": lpips_alex.mean().item(),
                "train/gan": gan.mean().item(),
                "train/ssim": ssim.mean().item(),
            }
            loss = mse + lpips_vgg * 4.0 + gan + lpips_alex + ssim
        else:
            raise NotImplementedError(
                f"Unknown loss {self.loss}, must be one of "
                "['mse', 'lpips_mse', 'gan_lpips_mse']"
            )

        tlist = loss.detach().cpu().reshape(-1).tolist()
        ttloss = list(zip(t, tlist))
        return loss.mean(), ttloss, loss_list

    @torch.no_grad()
    def sample(self, cloud, tol=1e-5, integration="dopri5") -> list[torch.Tensor]:
        """Integrate the learned ODE from ``t = 0`` to ``t = 1``.

        Args:
            cloud: Initial state at ``t = 0``.
            tol: Used for both ``rtol`` and ``atol`` of the solver.
            integration: ``torchdiffeq.odeint`` method name.

        Returns:
            One tensor per requested time point (here two: ``t = 0`` and
            ``t = 1``), in time order.
        """
        t_span = torch.linspace(0, 1, 2, device=cloud.device)
        traj = odeint(
            lambda t, x: self.model(x, t)[0],
            cloud,
            t_span,
            rtol=tol,
            atol=tol,
            method=integration,
        )
        return list(traj.unbind(0))  # type: ignore

    @torch.no_grad()
    def sample_heun(self, z1, sample_steps=50):
        """Integrate with Heun's method using ``sample_steps`` fixed steps.

        NOTE(review): this walks time from ``t = 1`` down to ``t = 0`` and
        subtracts the velocity, i.e. the opposite direction of ``sample``
        (``t = 0`` to ``t = 1``). Confirm this is the intended convention for
        the caller's ``z1``.

        Args:
            z1: State at ``t = 1``; its first dimension is the batch size.
            sample_steps: Number of integration steps.

        Returns:
            List of ``sample_steps + 1`` tensors: ``z1`` followed by the state
            after every step.
        """
        batch = z1.size(0)
        dt = 1.0 / sample_steps
        images = [z1]
        z = z1
        for i in range(sample_steps, 0, -1):
            t_current = torch.full((batch,), i / sample_steps, device=z.device)
            t_next = torch.full((batch,), (i - 1) / sample_steps, device=z.device)

            # Predictor: explicit Euler step with the current velocity.
            v_current, _ = self.model(z, t_current)
            z_pred = z - dt * v_current

            # Corrector: average the velocities at both ends of the step.
            v_next, _ = self.model(z_pred, t_next)
            z = z - dt * 0.5 * (v_current + v_next)
            images.append(z)
        return images