import lpips
import torch
import torch.optim as optim
from pytorch_msssim import ms_ssim
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchdiffeq import odeint

import wandb
from src.dataset.preprocess import denormalize
from src.gan import PatchDiscriminator, gan_disc_loss

# LeCam regularization state: EMA anchors of the discriminator's average real
# and fake logits, updated once per discriminator step.
lecam_loss_weight = 0.1
lecam_anchor_real_logits = 0.0
lecam_anchor_fake_logits = 0.0
lecam_beta = 0.9
use_lecam = True


class RF:
    """Rectified-flow / conditional-flow-matching trainer and sampler."""

    def __init__(self, model, fm="otcfm", loss="mse"):
        self.model = model
        self.loss = loss
        sigma = 0.0
        if fm == "otcfm":
            self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
        elif fm == "icfm":
            self.fm = ConditionalFlowMatcher(sigma=sigma)
        elif fm == "fm":
            self.fm = TargetConditionalFlowMatcher(sigma=sigma)
        elif fm == "si":
            self.fm = VariancePreservingConditionalFlowMatcher(sigma=sigma)
        else:
            raise NotImplementedError(
                f"Unknown flow matcher {fm}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
            )
        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1")
        self.lpips2 = lpips.LPIPS(net="alex").to("cuda:1")
        discriminator = PatchDiscriminator().to("cuda:1")
        discriminator.requires_grad_(True)
        self.discriminator = discriminator
        self.optimizer_D = optim.AdamW(
            discriminator.parameters(),
            lr=2e-4,
            weight_decay=1e-3,
            betas=(0.9, 0.95),
        )

    def gan_loss(self, real, fake):
        """Run one discriminator update on a (real, fake) batch pair."""
        global lecam_anchor_real_logits, lecam_anchor_fake_logits
        real_preds = self.discriminator(real)
        fake_preds = self.discriminator(fake.detach())
        d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss(
            real_preds, fake_preds, "hinge"
        )
        # Update the EMA anchors of the average real/fake logits.
        lecam_anchor_real_logits = (
            lecam_beta * lecam_anchor_real_logits + (1 - lecam_beta) * avg_real_logits
        )
        lecam_anchor_fake_logits = (
            lecam_beta * lecam_anchor_fake_logits + (1 - lecam_beta) * avg_fake_logits
        )
        total_d_loss = d_loss.mean()
        d_loss_item = total_d_loss.item()
        if use_lecam:
            # Penalize the real logits toward the fake anchor and the fake
            # logits toward the real anchor.
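            # This matches LeCam regularization (Tseng et al., CVPR 2021,
            # "Regularizing Generative Adversarial Networks under Limited Data"):
            # pulling each side's logits toward the EMA anchor of the opposite
            # side bounds the discriminator's output range and stabilizes
            # adversarial training.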
            lecam_loss = (real_preds - lecam_anchor_fake_logits).pow(2).mean() + (
                fake_preds - lecam_anchor_real_logits
            ).pow(2).mean()
            lecam_loss_item = lecam_loss.item()
            total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight
            wandb.log(
                {
                    "gan/lecam_loss": lecam_loss_item,
                    "gan/lecam_anchor_real_logits": lecam_anchor_real_logits,
                    "gan/lecam_anchor_fake_logits": lecam_anchor_fake_logits,
                }
            )
        wandb.log(
            {
                "gan/discriminator_loss": d_loss_item,
                "gan/discriminator_accuracy": disc_acc,
            }
        )
        self.optimizer_D.zero_grad()
        total_d_loss.backward(retain_graph=True)
        self.optimizer_D.step()

    def forward(self, gt, cloud):
        """Compute the flow-matching training loss for a (target, source) batch."""
        # x0 = cloud, x1 = gt: sample t, the interpolant xt, and the target
        # velocity ut.
        t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt)  # type: ignore
        vt, _ = self.model(xt, t)
        if self.loss == "mse":
            loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
            loss_list = {"train/mse": loss.mean().item()}
        elif self.loss == "lpips_mse":
            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
            # One-step prediction of x1 under the linear flow: x1 ≈ xt + (1 - t) * vt.
            pred = denormalize(xt + (1 - t[:, None, None, None]) * vt)
            # LPIPS expects inputs in [-1, 1] and returns (B, 1, 1, 1); flatten
            # to (B,) so it broadcasts correctly against the per-sample MSE.
            lpips_vgg = self.lpips(denormalize(gt) * 2 - 1, pred * 2 - 1).reshape(-1)
            loss_list = {
                "train/mse": mse.mean().item(),
                "train/lpips": lpips_vgg.mean().item(),
            }
            loss = mse + lpips_vgg * 2.0
        elif self.loss == "gan_lpips_mse":
            real = denormalize(gt)
            pred = denormalize(xt + (1 - t[:, None, None, None]) * vt)
            self.gan_loss(real, pred)
            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
            # Flatten LPIPS outputs from (B, 1, 1, 1) to (B,) as above.
            lpips_vgg = self.lpips(real * 2 - 1, pred * 2 - 1).reshape(-1)
            alexlpips = self.lpips2(real * 2 - 1, pred * 2 - 1).reshape(-1)
            # Generator-side hinge term; reduce the patch logits to one value
            # per sample so it matches the other per-sample losses.
            fake_logits = self.discriminator(pred)
            gan = -fake_logits.reshape(fake_logits.size(0), -1).mean(-1) * 0.01
            ssim = 1 - ms_ssim(real, pred, data_range=1.0, size_average=False)
            loss_list = {
                "train/mse": mse.mean().item(),
                "train/lpips": lpips_vgg.mean().item(),
                "train/alexlpips": alexlpips.mean().item(),
                "train/gan": gan.mean().item(),
                "train/ssim": ssim.mean().item(),
            }
            loss = mse + lpips_vgg * 4.0 + gan + alexlpips + ssim
        else:
            raise NotImplementedError(
                f"Unknown loss {self.loss}, must be one of ['mse', 'lpips_mse', 'gan_lpips_mse']"
            )
        # Pair each sample's timestep with its loss for per-t diagnostics.
        tlist = loss.detach().cpu().reshape(-1).tolist()
        ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
        return loss.mean(), ttloss, loss_list

    @torch.no_grad()
    def sample(self, cloud, tol=1e-5, integration="dopri5") -> list[torch.Tensor]:
        """Integrate the learned ODE from t=0 (cloud) to t=1 with an adaptive solver."""
        # Only the endpoints are requested; dopri5 chooses interior steps itself.
        t_span = torch.linspace(0, 1, 2, device=cloud.device)
        traj = odeint(
            lambda t, x: self.model(x, t)[0],
            cloud,
            t_span,
            rtol=tol,
            atol=tol,
            method=integration,
        )
        return [traj[i] for i in range(traj.shape[0])]

    @torch.no_grad()
    def sample_heun(self, z1, sample_steps=50):
        """Second-order Heun sampler, stepping from t=1 down to t=0."""
        b = z1.size(0)
        dt = 1.0 / sample_steps
        images = [z1]
        z = z1
        for i in range(sample_steps, 0, -1):
            t_current = i / sample_steps
            t_next = (i - 1) / sample_steps
            t_current_tensor = torch.tensor([t_current] * b, device=z.device)
            # Predictor: an Euler step with the velocity at the current time.
            v_current, _ = self.model(z, t_current_tensor)
            z_pred = z - dt * v_current
            t_next_tensor = torch.tensor([t_next] * b, device=z.device)
            # Corrector: re-evaluate at the predicted point and average slopes.
            v_next, _ = self.model(z_pred, t_next_tensor)
            v_avg = 0.5 * (v_current + v_next)
            z = z - dt * v_avg
            images.append(z)
        return images
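
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# Assumptions: a machine exposing the 'cuda:1' device this module hard-codes
# (the LPIPS nets and discriminator are built in __init__ even for the plain
# MSE loss), and a velocity model with the signature model(x, t) -> (v, aux).
# `ToyVelocityModel` and the tensor shapes below are hypothetical stand-ins.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class ToyVelocityModel(nn.Module):
        """Tiny stand-in that maps (x, t) to a velocity field of x's shape."""

        def __init__(self, channels=3):
            super().__init__()
            self.net = nn.Conv2d(channels + 1, channels, kernel_size=3, padding=1)

        def forward(self, x, t):
            # Broadcast the timestep to a feature plane and concatenate it.
            t_plane = t.view(-1, 1, 1, 1).expand(x.size(0), 1, *x.shape[2:])
            return self.net(torch.cat([x, t_plane], dim=1)), None

    device = "cuda:0"
    model = ToyVelocityModel().to(device)
    rf = RF(model, fm="icfm", loss="mse")

    gt = torch.randn(4, 3, 32, 32, device=device)     # normalized targets
    cloud = torch.randn(4, 3, 32, 32, device=device)  # normalized source batch
    loss, ttloss, loss_list = rf.forward(gt, cloud)
    loss.backward()
    print(loss_list)

    # Draw samples by integrating the learned flow with the Heun sampler.
    samples = rf.sample_heun(cloud, sample_steps=10)
    print(len(samples), samples[-1].shape)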