Commit: things
@@ -4,7 +4,7 @@ from pathlib import Path
 from datasets import Dataset, DatasetDict, Image
 from src.dataset.preprocess import make_transform

-transform = make_transform(512)
+transform = make_transform(256)


 def get_dataset() -> tuple[Dataset, Dataset]:
@@ -18,3 +18,9 @@ def denormalize(tensor: torch.Tensor) -> torch.Tensor:
     mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
     std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
     return tensor * std + mean
+
+
+def normalize(tensor: torch.Tensor) -> torch.Tensor:
+    mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
+    std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
+    return (tensor - mean) / std
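The added normalize is the exact inverse of the existing denormalize. A quick self-contained round-trip check (a sketch; only torch required, with the dataset statistics copied from the hunk above):

import torch

mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1)
std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1)

x = torch.rand(3, 8, 8)                 # fake image in [0, 1]
normalized = (x - mean) / std           # what normalize() computes
restored = normalized * std + mean      # what denormalize() computes
assert torch.allclose(restored, x, atol=1e-6)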
src/gan.py (new file, 154 lines)
@@ -0,0 +1,154 @@
from collections import namedtuple

import torch
import torch.nn as nn
from torchvision import models


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        # LPIPS-style shift/scale that maps [0, 1] inputs into VGG's expected range.
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        # Split the VGG16 feature stack at the relu1_2 / relu2_2 / relu3_3 /
        # relu4_3 / relu5_3 boundaries.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.scaling_layer = ScalingLayer()

        _vgg = models.vgg16(pretrained=True)

        self.slice1 = nn.Sequential(_vgg.features[:4])
        self.slice2 = nn.Sequential(_vgg.features[4:9])
        self.slice3 = nn.Sequential(_vgg.features[9:16])
        self.slice4 = nn.Sequential(_vgg.features[16:23])
        self.slice5 = nn.Sequential(_vgg.features[23:30])

        # One zero-initialized logit head per feature scale; each head reduces its
        # feature map to the same 16x16 patch grid (for 256x256 inputs).
        self.binary_classifier1 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=4, stride=4, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=4, stride=4, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier1[-1].weight)

        self.binary_classifier2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=4, stride=4, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier2[-1].weight)

        self.binary_classifier3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier3[-1].weight)

        self.binary_classifier4 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier4[-1].weight)

        self.binary_classifier5 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier5[-1].weight)

    def forward(self, x):
        x = self.scaling_layer(x)
        features1 = self.slice1(x)
        features2 = self.slice2(features1)
        features3 = self.slice3(features2)
        features4 = self.slice4(features3)
        features5 = self.slice5(features4)

        # torch.Size([1, 64, 256, 256]) torch.Size([1, 128, 128, 128]) torch.Size([1, 256, 64, 64]) torch.Size([1, 512, 32, 32]) torch.Size([1, 512, 16, 16])

        bc1 = self.binary_classifier1(features1).flatten(1)
        bc2 = self.binary_classifier2(features2).flatten(1)
        bc3 = self.binary_classifier3(features3).flatten(1)
        bc4 = self.binary_classifier4(features4).flatten(1)
        bc5 = self.binary_classifier5(features5).flatten(1)

        return bc1 + bc2 + bc3 + bc4 + bc5
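All five heads land on the same 16x16 grid, so the sum is one logit per image patch. A minimal shape check (a sketch; instantiating the class downloads the torchvision VGG16 weights):

import torch

disc = PatchDiscriminator()
with torch.no_grad():
    logits = disc(torch.rand(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 256]) -- one logit per cell of the 16x16 patch grid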


def gan_disc_loss(real_preds, fake_preds, disc_type="bce"):
    if disc_type == "bce":
        real_loss = nn.functional.binary_cross_entropy_with_logits(
            real_preds, torch.ones_like(real_preds)
        )
        fake_loss = nn.functional.binary_cross_entropy_with_logits(
            fake_preds, torch.zeros_like(fake_preds)
        )
        # eval its online performance
        avg_real_preds = real_preds.mean().item()
        avg_fake_preds = fake_preds.mean().item()

        with torch.no_grad():
            acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item()
            acc = acc / (real_preds.numel() + fake_preds.numel())

    if disc_type == "hinge":
        real_loss = nn.functional.relu(1 - real_preds).mean()
        fake_loss = nn.functional.relu(1 + fake_preds).mean()

        with torch.no_grad():
            acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item()
            acc = acc / (real_preds.numel() + fake_preds.numel())

        avg_real_preds = real_preds.mean().item()
        avg_fake_preds = fake_preds.mean().item()

    return (real_loss + fake_loss) * 0.5, avg_real_preds, avg_fake_preds, acc  # type: ignore
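The hinge variant only penalizes real logits below 1 and fake logits above -1. Its behaviour on toy logits (a self-contained sketch, not from the commit):

import torch
import torch.nn as nn

real_preds = torch.tensor([2.0, 0.5, -0.5])   # discriminator logits on real images
fake_preds = torch.tensor([-2.0, -0.5, 0.5])  # logits on generated images

real_loss = nn.functional.relu(1 - real_preds).mean()  # [0.0, 0.5, 1.5] -> ~0.667
fake_loss = nn.functional.relu(1 + fake_preds).mean()  # [0.0, 0.5, 1.5] -> ~0.667
acc = ((real_preds > 0).sum() + (fake_preds < 0).sum()).item() / 6
print(real_loss.item(), fake_loss.item(), acc)  # ~0.667 ~0.667 ~0.667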
@@ -419,6 +419,7 @@ class DINOv3ViTModel(nn.Module):
         hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
         position_embeddings = self.rope_embeddings(pixel_values)

+        latents = []
         for i, layer_module in enumerate(self.layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             hidden_states = layer_module(
@@ -426,11 +427,13 @@ class DINOv3ViTModel(nn.Module):
                 attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
+            latents.append(hidden_states)

         sequence_output = self.norm(hidden_states)
         pooled_output = sequence_output[:, 0, :]

         return {
             "last_hidden_state": sequence_output,
             "latents": latents,
+            "pooler_output": pooled_output,
         }
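The only functional change here is that every layer's hidden state is now collected and returned. A usage sketch (model and pixel_values are placeholders for a DINOv3ViTModel instance and an image batch; the calls mirror the forward above):

import torch

outputs = model(pixel_values)
per_layer = outputs["latents"]              # one tensor per transformer layer
assert len(per_layer) == len(model.layers)
# The last latent is the pre-norm final state:
assert torch.allclose(model.norm(per_layer[-1]), outputs["last_hidden_state"])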

src/model/hourglass.py (new file, 144 lines)
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.dino import (
    DINOv3ViTEmbeddings,
    DINOv3ViTRopePositionEmbedding,
)
from src.model.utransformer import DinoConditionedLayer, TimestepEmbedder


class Hourgrass(nn.Module):
    def __init__(
        self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
    ):
        super().__init__()
        self.config = config
        self.scale_factor = scale_factor

        assert config.num_hidden_layers % scale_factor == 0

        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
        self.t_embedder = TimestepEmbedder(config.hidden_size)

        self.encoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, True)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # freeze pretrained
        self.embeddings.requires_grad_(False)
        self.rope_embeddings.requires_grad_(False)
        self.encoder_norm.requires_grad_(False)

    def forward(
        self,
        pixel_values: torch.Tensor,
        time: torch.Tensor,
        # cond: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        if time.dim() == 0:
            time = time.repeat(pixel_values.shape[0])

        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)

        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        t = self.t_embedder(time).unsqueeze(1)

        conditioning_input = t.to(x.dtype)

        # Stash every scale_factor-th encoder state for the U-shaped skip connections.
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        x = self.encoder_norm(x)

        # NOTE: decoder_layers, decoder_norm, and decoder are referenced below but
        # are not defined in __init__ in this commit.
        reversed_residual = residual[::-1]
        for i, layer_module in enumerate(self.decoder_layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
            x = x + reversed_residual[i]

        x = self.decoder_norm(x)

        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    def get_residual(
        self,
        pixel_values: torch.Tensor,
        time: Optional[torch.Tensor],
        do_condition: bool,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)

        x = self.embeddings(pixel_values, bool_masked_pos=None)

        if do_condition:
            t = self.t_embedder(time).unsqueeze(1)
            # y = self.y_embedder(cond, self.training).unsqueeze(1)
            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
            conditioning_input = t.to(x.dtype)
        else:
            conditioning_input = None

        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)

            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=None,
                position_embeddings=position_embeddings,
                do_condition=do_condition,
            )

        return residual

    @staticmethod
    def from_pretrained_backbone(name: str):
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = UTransformer(config, 0).to("cuda:1")

        weight_dict = {}
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
        ) as f:
            for key in f.keys():
                # Map the upstream DINOv3 parameter names onto this module's layout.
                new_key = key.replace("layer.", "encoder_layers.").replace(
                    "norm.", "encoder_norm."
                )

                weight_dict[new_key] = f.get_tensor(key)

        instance.load_state_dict(weight_dict, strict=False)

        return instance
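get_residual runs only the (frozen-backbone) encoder and returns the skip states, which is what the dino_loss removed from src/rf.py consumed for cosine feature matching. A sketch of that pattern, where net is a Hourgrass instance and feature_distance is an illustrative helper, not part of the commit:

import torch

def feature_distance(net, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Unconditioned encoder residuals for both image batches.
    res_a = net.get_residual(a, None, do_condition=False)
    res_b = net.get_residual(b, None, do_condition=False)
    # Mean (1 - cosine similarity) across tokens, averaged over scales.
    return torch.stack(
        [
            (1 - torch.nn.functional.cosine_similarity(ra, rb, dim=-1)).mean()
            for ra, rb in zip(res_a, res_b)
        ]
    ).mean()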

@@ -188,30 +188,83 @@ class DinoV3ViTDecoder(nn.Module):

         self.projection = nn.Linear(
             config.hidden_size,
-            self.num_channels_out * (self.patch_size**2),
+            config.num_channels * (self.patch_size**2),
             bias=True,
         )

         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)

         nn.init.zeros_(self.projection.weight)
-        nn.init.zeros_(self.projection.bias)
+        nn.init.zeros_(
+            self.projection.bias
+        ) if self.projection.bias is not None else None

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]

         x = x[:, 1 + self.config.num_register_tokens :, :]

-        x = self.projection(x)
-
         p = self.config.patch_size
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p

         assert x.shape[1] == h_grid * w_grid

+        x = self.projection(x)
+
         x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)

-        return self.pixel_shuffle(x)
+        x = self.pixel_shuffle(x)
+
+        return x
+
+
+# how about transposed conv decoder
+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+#         self.patch_size = config.patch_size
+
+#         intermediate_channels = config.hidden_size // 4
+
+#         self.decoder_block = nn.Sequential(
+#             nn.ConvTranspose2d(
+#                 in_channels=config.hidden_size,
+#                 out_channels=intermediate_channels,
+#                 kernel_size=self.patch_size,
+#                 stride=self.patch_size,
+#                 bias=True,
+#             ),
+#             nn.LayerNorm(intermediate_channels),
+#             nn.GELU(),
+#             nn.Conv2d(
+#                 in_channels=intermediate_channels,
+#                 out_channels=config.num_channels,
+#                 kernel_size=1,
+#                 bias=True,
+#             ),
+#         )
+
+#         nn.init.zeros_(self.decoder_block[-1].weight)
+#         nn.init.zeros_(self.decoder_block[-1].bias)
+
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+
+#         x = x[:, 1 + self.config.num_register_tokens :, :]
+
+#         p = self.config.patch_size
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+#         assert x.shape[1] == h_grid * w_grid
+
+#         x = x.transpose(1, 2).reshape(
+#             batch_size, self.config.hidden_size, h_grid, w_grid
+#         )
+#         x = self.decoder_block(x)
+
+#         return x


 class UTransformer(nn.Module):
@@ -261,6 +314,9 @@ class UTransformer(nn.Module):
         bool_masked_pos: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
     ):
+        if time.dim() == 0:
+            time = time.repeat(pixel_values.shape[0])
+
         pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
         position_embeddings = self.rope_embeddings(pixel_values)

src/rf.py (230 lines)
@@ -1,117 +1,175 @@
 import lpips
 import torch
+import torch.optim as optim
+from pytorch_msssim import ms_ssim
+from torchcfm.conditional_flow_matching import (
+    ConditionalFlowMatcher,
+    ExactOptimalTransportConditionalFlowMatcher,
+    TargetConditionalFlowMatcher,
+    VariancePreservingConditionalFlowMatcher,
+)
+from torchdiffeq import odeint

+import wandb
 from src.dataset.preprocess import denormalize
+from src.gan import PatchDiscriminator, gan_disc_loss


 def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
     """Loss = sqrt(||x||₂² + c²) - c"""
     d = x.shape[1:].numel()
     c = c * (d**0.5)
     x = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
     return torch.sqrt(x**2 + c**2) - c


+lecam_loss_weight = 0.1
+lecam_anchor_real_logits = 0.0
+lecam_anchor_fake_logits = 0.0
+lecam_beta = 0.9
+use_lecam = True


 class RF:
-    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_mse_enhanced"):
+    def __init__(self, model, fm="otcfm", loss="mse"):
         self.model = model
-        self.ln = ln
-        self.ushaped = ushaped
-        self.loss_fn = loss_fn
+        self.loss = loss

-        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1") if "lpips" in loss_fn else None
-
-    def forward(self, x0, z1):
-        # x0 is gt / z is noise
-        b = x0.size(0)
-        if self.ushaped:
-            a = 4.0  # HYPERPARMS
-            u = torch.rand((b,), device=x0.device)
-            t = torch.acosh(1 + (torch.cosh(torch.tensor(a)) - 1) * u) / a
-            t = t * (torch.randint(0, 2, (b,), device=x0.device) * 2 - 1) * 0.5 + 0.5
-        elif self.ln:
-            nt = torch.randn((b,)).to(x0.device)
-            t = torch.sigmoid(nt)
+        sigma = 0.0
+        if fm == "otcfm":
+            self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
+        elif fm == "icfm":
+            self.fm = ConditionalFlowMatcher(sigma=sigma)
+        elif fm == "fm":
+            self.fm = TargetConditionalFlowMatcher(sigma=sigma)
+        elif fm == "si":
+            self.fm = VariancePreservingConditionalFlowMatcher(sigma=sigma)
         else:
-            t = torch.rand((b,)).to(x0.device)
-        texp = t.view([b, *([1] * len(x0.shape[1:]))])
-        zt = (1 - texp) * x0 + texp * z1
+            raise NotImplementedError(
+                f"Unknown model {fm}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
+            )

-        vtheta, residual = self.model(zt, t)
+        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1")
+        self.lpips2 = lpips.LPIPS(net="alex").to("cuda:1")

-        if self.loss_fn == "lpips_huber":
-            # https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
-            if not self.lpips:
-                raise Exception
+        discriminator = PatchDiscriminator().to("cuda:1")
+        discriminator.requires_grad_(True)
+        self.discriminator = discriminator
+        self.optimizer_D = optim.AdamW(
+            discriminator.parameters(),
+            lr=2e-4,
+            weight_decay=1e-3,
+            betas=(0.9, 0.95),
+        )

-            huber = torch.nn.functional.huber_loss(
-                z1 - x0, vtheta, reduction="none"
-            ).mean(dim=list(range(1, len(x0.shape))))
+    def gan_loss(self, real, fake):
+        global lecam_beta, lecam_anchor_real_logits, lecam_anchor_fake_logits, use_lecam

+        real_preds = self.discriminator(real)
+        fake_preds = self.discriminator(fake.detach())
+        d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss(
+            real_preds, fake_preds, "hinge"
+        )

+        lecam_anchor_real_logits = (
+            lecam_beta * lecam_anchor_real_logits + (1 - lecam_beta) * avg_real_logits
+        )
+        lecam_anchor_fake_logits = (
+            lecam_beta * lecam_anchor_fake_logits + (1 - lecam_beta) * avg_fake_logits
+        )
+        total_d_loss = d_loss.mean()
+        d_loss_item = total_d_loss.item()
+        if use_lecam:
+            # penalize the real logits to fake and fake logits to real.
+            lecam_loss = (real_preds - lecam_anchor_fake_logits).pow(2).mean() + (
+                fake_preds - lecam_anchor_real_logits
+            ).pow(2).mean()
+            lecam_loss_item = lecam_loss.item()
+            total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight

+            wandb.log(
+                {
+                    "gan/lecam_loss": lecam_loss_item,
+                    "gan/lecam_anchor_real_logits": lecam_anchor_real_logits,
+                    "gan/lecam_anchor_fake_logits": lecam_anchor_fake_logits,
+                }
+            )

+        wandb.log(
+            {
+                "gan/discriminator_loss": d_loss_item,
+                "gan/discriminator_accuracy": disc_acc,
+            }
+        )

+        self.optimizer_D.zero_grad()
+        total_d_loss.backward(retain_graph=True)
+        self.optimizer_D.step()
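gan_loss keeps exponential-moving-average anchors of the discriminator's mean real/fake logits and, when use_lecam is set, pulls real logits toward the fake anchor and vice versa with weight 0.1 (LeCam regularization). The EMA update in isolation (a self-contained sketch):

lecam_beta = 0.9
anchor_real = 0.0
for avg_real_logits in [1.2, 0.8, 1.0]:  # per-batch mean discriminator logits
    anchor_real = lecam_beta * anchor_real + (1 - lecam_beta) * avg_real_logits
print(round(anchor_real, 4))  # 0.2692 -- the anchor drifts slowly toward the running mean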

+    def forward(self, gt, cloud):
+        t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt)  # type: ignore
+        vt, _ = self.model(xt, t)

+        if self.loss == "mse":
+            loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
+            loss_list = {"train/mse": loss.mean().item()}
+        elif self.loss == "lpips_mse":
+            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
             lpips = self.lpips(
-                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+                denormalize(gt) * 2 - 1,
+                denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
             )
-            weight = t.view(-1)
-
-            loss = (1 - weight) * huber + lpips
-        elif self.loss_fn == "lpips_mse":
-            if not self.lpips:
-                raise Exception
-
+            loss_list = {
+                "train/mse": mse.mean().item(),
+                "train/lpips": lpips.mean().item(),
+            }
+            loss = mse + lpips * 2.0
+        elif self.loss == "gan_lpips_mse":
+            self.gan_loss(
+                denormalize(gt),
+                denormalize(xt + (1 - t[:, None, None, None]) * vt),
+            )
+            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
             lpips = self.lpips(
-                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+                denormalize(gt) * 2 - 1,
+                denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
             )
-            loss = ((z1 - x0 - vtheta) ** 2).mean(
-                dim=list(range(1, len(x0.shape)))
-            ) + 2.0 * lpips
-        elif self.loss_fn == "lpips_mse_enhanced":
-            if not self.lpips:
-                raise Exception
-
-            lpips = self.lpips(
-                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+            alexlpips = self.lpips2(
+                denormalize(gt) * 2 - 1,
+                denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
             )
-            dino_loss = torch.stack(
-                [
-                    (
-                        1
-                        - torch.nn.functional.cosine_similarity(
-                            v_residual, x0_residual, dim=-1
-                        )
-                    ).mean(dim=-1)
-                    for v_residual, x0_residual in zip(
-                        residual, self.model.get_residual(x0, None, False)
-                    )
-                ]
-            ).mean(dim=0) * (2 - t.view(-1))
-            loss = (
-                ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
-                + 2.0 * lpips
-                + dino_loss
+            gan = (
+                -self.discriminator(
+                    denormalize(xt + (1 - t[:, None, None, None]) * vt),
+                ).mean(-1)
+                * 0.01
             )
+            ssim = 1 - ms_ssim(
+                denormalize(gt),
+                denormalize(xt + (1 - t[:, None, None, None]) * vt),
+                data_range=1.0,
+                size_average=False,
+            )
+            loss_list = {
+                "train/mse": mse.mean().item(),
+                "train/lpips": lpips.mean().item(),
+                "train/alexlpips": alexlpips.mean().item(),
+                "train/gan": gan.mean().item(),
+                "train/ssim": ssim.mean().item(),
+            }
+            loss = mse + lpips * 4.0 + gan + alexlpips + ssim
         else:
-            loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
+            raise Exception

         tlist = loss.detach().cpu().reshape(-1).tolist()
         ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]

-        return loss.mean(), ttloss
+        return loss.mean(), ttloss, loss_list

     @torch.no_grad()
-    def sample(self, z1, sample_steps=5):
-        b = z1.size(0)
-        dt = 1.0 / sample_steps
-        dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
-        images = [z1]
-        z = z1
-        for i in range(sample_steps, 0, -1):
-            t = i / sample_steps
-            t = torch.tensor([t] * b).to(z.device)
-
-            vc, _ = self.model(z, t)
-
-            z = z - dt * vc
-            images.append(z)
-        return images
+    def sample(self, cloud, tol=1e-5, integration="dopri5") -> torch.Tensor:
+        t_span = torch.linspace(0, 1, 2, device=cloud.device)
+        traj = odeint(
+            lambda t, x: self.model(x, t)[0],
+            cloud,
+            t_span,
+            rtol=tol,
+            atol=tol,
+            method=integration,
+        )

+        return [traj[i] for i in range(traj.shape[0])]  # type: ignore
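Taken together, the new interface is: construct once, call forward per batch (the gan branch trains the discriminator internally), then integrate with odeint to sample. A hedged sketch of the loop, where model, loader, and optimizer are placeholders and wandb.init is assumed to have been called:

rf = RF(model, fm="otcfm", loss="gan_lpips_mse")

for gt, cloud in loader:                 # paired clean / degraded batches
    loss, ttloss, loss_list = rf.forward(gt, cloud)
    loss.backward()
    optimizer.step()                     # generator optimizer (discriminator steps inside gan_loss)
    optimizer.zero_grad()

samples = rf.sample(cloud, tol=1e-5, integration="dopri5")
restored = samples[-1]                   # trajectory endpoint at t = 1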

     @torch.no_grad()
     def sample_heun(self, z1, sample_steps=50):