commit c47d91a349 (parent 6bb6c09638)
Author: neulus
Date:   2025-10-10 15:55:35 +09:00

10 changed files with 1381 additions and 112 deletions


@@ -4,7 +4,7 @@ from pathlib import Path
 from datasets import Dataset, DatasetDict, Image
 from src.dataset.preprocess import make_transform

-transform = make_transform(512)
+transform = make_transform(256)

 def get_dataset() -> tuple[Dataset, Dataset]:

src/dataset/preprocess.py

@@ -18,3 +18,9 @@ def denormalize(tensor: torch.Tensor) -> torch.Tensor:
     mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
     std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
     return tensor * std + mean
+
+
+def normalize(tensor: torch.Tensor) -> torch.Tensor:
+    mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
+    std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
+    return (tensor - mean) / std
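The new normalize is the exact inverse of the existing denormalize (same per-channel mean/std), so the round trip is the identity up to float error. A minimal sanity check:

    x = torch.rand(3, 64, 64)
    assert torch.allclose(denormalize(normalize(x)), x, atol=1e-6)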

src/gan.py (new file, 154 lines)

@@ -0,0 +1,154 @@
from collections import namedtuple

import torch
import torch.nn as nn
from torchvision import models


class ScalingLayer(nn.Module):
    """Shifts and scales RGB inputs with the channel statistics used by LPIPS."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class vgg16(torch.nn.Module):
    """VGG16 backbone split into five feature stages, tapped after each ReLU block."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out
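For reference, a sketch of what the staged extractor returns; the shapes assume a 1×3×256×256 input and match the shape comment in PatchDiscriminator.forward below:

    feats = vgg16(requires_grad=False, pretrained=True)
    out = feats(torch.randn(1, 3, 256, 256))
    # out.relu1_2: (1, 64, 256, 256)   out.relu2_2: (1, 128, 128, 128)
    # out.relu3_3: (1, 256, 64, 64)    out.relu4_3: (1, 512, 32, 32)
    # out.relu5_3: (1, 512, 16, 16)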
class PatchDiscriminator(nn.Module):
    """Multi-scale patch discriminator over VGG16 features; every head's final conv is zero-initialized so the initial logits are 0."""

    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.scaling_layer = ScalingLayer()
        _vgg = models.vgg16(pretrained=True)
        self.slice1 = nn.Sequential(_vgg.features[:4])
        self.slice2 = nn.Sequential(_vgg.features[4:9])
        self.slice3 = nn.Sequential(_vgg.features[9:16])
        self.slice4 = nn.Sequential(_vgg.features[16:23])
        self.slice5 = nn.Sequential(_vgg.features[23:30])
        self.binary_classifier1 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=4, stride=4, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=4, stride=4, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier1[-1].weight)
        self.binary_classifier2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=4, stride=4, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier2[-1].weight)
        self.binary_classifier3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier3[-1].weight)
        self.binary_classifier4 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=2, stride=2, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier4[-1].weight)
        self.binary_classifier5 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=True),
        )
        nn.init.zeros_(self.binary_classifier5[-1].weight)

    def forward(self, x):
        x = self.scaling_layer(x)
        features1 = self.slice1(x)
        features2 = self.slice2(features1)
        features3 = self.slice3(features2)
        features4 = self.slice4(features3)
        features5 = self.slice5(features4)
        # feature shapes for a 1x3x256x256 input:
        # (1, 64, 256, 256), (1, 128, 128, 128), (1, 256, 64, 64),
        # (1, 512, 32, 32), (1, 512, 16, 16)
        bc1 = self.binary_classifier1(features1).flatten(1)
        bc2 = self.binary_classifier2(features2).flatten(1)
        bc3 = self.binary_classifier3(features3).flatten(1)
        bc4 = self.binary_classifier4(features4).flatten(1)
        bc5 = self.binary_classifier5(features5).flatten(1)
        return bc1 + bc2 + bc3 + bc4 + bc5
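A quick shape check, as a sketch assuming 256×256 inputs as in the comment above: each head strides its feature map down to 16×16 logits, so all five flattened terms have 256 entries per sample and the elementwise sum is well defined:

    disc = PatchDiscriminator()
    logits = disc(torch.randn(2, 3, 256, 256))
    assert logits.shape == (2, 256)  # 16*16 patch logits per image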
def gan_disc_loss(real_preds, fake_preds, disc_type="bce"):
    if disc_type == "bce":
        real_loss = nn.functional.binary_cross_entropy_with_logits(
            real_preds, torch.ones_like(real_preds)
        )
        fake_loss = nn.functional.binary_cross_entropy_with_logits(
            fake_preds, torch.zeros_like(fake_preds)
        )
    elif disc_type == "hinge":
        real_loss = nn.functional.relu(1 - real_preds).mean()
        fake_loss = nn.functional.relu(1 + fake_preds).mean()
    else:
        raise ValueError(f"unknown disc_type: {disc_type}")
    # track the discriminator's online performance
    avg_real_preds = real_preds.mean().item()
    avg_fake_preds = fake_preds.mean().item()
    with torch.no_grad():
        acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item()
        acc = acc / (real_preds.numel() + fake_preds.numel())
    return (real_loss + fake_loss) * 0.5, avg_real_preds, avg_fake_preds, acc
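A worked example of the hinge branch: with all real logits at +2 and all fake logits at -0.5, the real term relu(1 - 2) averages 0, the fake term relu(1 + (-0.5)) averages 0.5, so the returned loss is (0 + 0.5) / 2 = 0.25 and the accuracy is 1.0:

    real = torch.full((4, 256), 2.0)
    fake = torch.full((4, 256), -0.5)
    loss, avg_r, avg_f, acc = gan_disc_loss(real, fake, "hinge")
    # loss.item() == 0.25, acc == 1.0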

src/model/dino.py

@@ -419,6 +419,7 @@ class DINOv3ViTModel(nn.Module):
         hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
         position_embeddings = self.rope_embeddings(pixel_values)

+        latents = []
         for i, layer_module in enumerate(self.layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             hidden_states = layer_module(
@@ -426,11 +427,13 @@ class DINOv3ViTModel(nn.Module):
                 attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
+            latents.append(hidden_states)

         sequence_output = self.norm(hidden_states)
         pooled_output = sequence_output[:, 0, :]

         return {
             "last_hidden_state": sequence_output,
+            "latents": latents,
             "pooler_output": pooled_output,
         }
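This change threads every transformer block's hidden state out as latents, so downstream losses can compare intermediate DINO features rather than only the final output. A usage sketch, with the call signature assumed from this forward:

    out = dino_model(pixel_values)  # hypothetical instance of DINOv3ViTModel
    assert len(out["latents"]) == len(dino_model.layers)
    shallow, deep = out["latents"][0], out["latents"][-1]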

src/model/hourglass.py (new file, 144 lines)

@@ -0,0 +1,144 @@
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

from src.model.dino import (
    DINOv3ViTEmbeddings,
    DINOv3ViTRopePositionEmbedding,
)
from src.model.utransformer import DinoConditionedLayer, TimestepEmbedder


class Hourgrass(nn.Module):
    def __init__(
        self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
    ):
        super().__init__()
        self.config = config
        self.scale_factor = scale_factor
        assert config.num_hidden_layers % scale_factor == 0
        self.embeddings = DINOv3ViTEmbeddings(config)
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        self.encoder_layers = nn.ModuleList(
            [
                DinoConditionedLayer(config, True)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # freeze pretrained
        self.embeddings.requires_grad_(False)
        self.rope_embeddings.requires_grad_(False)
        self.encoder_norm.requires_grad_(False)

    def forward(
        self,
        pixel_values: torch.Tensor,
        time: torch.Tensor,
        # cond: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ):
        if time.dim() == 0:
            time = time.repeat(pixel_values.shape[0])
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)
        x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        t = self.t_embedder(time).unsqueeze(1)
        conditioning_input = t.to(x.dtype)
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
        x = self.encoder_norm(x)
        reversed_residual = residual[::-1]
        for i, layer_module in enumerate(self.decoder_layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )
            x = x + reversed_residual[i]
        x = self.decoder_norm(x)
        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual

    def get_residual(
        self,
        pixel_values: torch.Tensor,
        time: Optional[torch.Tensor],
        do_condition: bool,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        position_embeddings = self.rope_embeddings(pixel_values)
        x = self.embeddings(pixel_values, bool_masked_pos=None)
        if do_condition:
            t = self.t_embedder(time).unsqueeze(1)
            # y = self.y_embedder(cond, self.training).unsqueeze(1)
            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
            conditioning_input = t.to(x.dtype)
        else:
            conditioning_input = None
        residual = []
        for i, layer_module in enumerate(self.encoder_layers):
            if i % self.scale_factor == 0:
                residual.append(x)
            x = layer_module(
                x,
                conditioning_input=conditioning_input,
                attention_mask=None,
                position_embeddings=position_embeddings,
                do_condition=do_condition,
            )
        return residual

    @staticmethod
    def from_pretrained_backbone(name: str):
        config = DINOv3ViTConfig.from_pretrained(name)
        instance = Hourgrass(config, 0).to("cuda:1")
        weight_dict = {}
        with safe_open(
            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
        ) as f:
            for key in f.keys():
                # map backbone parameter names onto this module's encoder
                new_key = key.replace("layer.", "encoder_layers.").replace(
                    "norm.", "encoder_norm."
                )
                weight_dict[new_key] = f.get_tensor(key)
        instance.load_state_dict(weight_dict, strict=False)
        return instance
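Note that forward references self.decoder_layers, self.decoder_norm, and self.decoder, none of which appear in the hunk as rendered here; their definitions presumably sit in elided lines of the 144-line file. A plausible __init__ sketch, purely an assumption inferred from how forward consumes them (one decoder layer per residual tap, plus the DinoV3ViTDecoder head from src/model/utransformer.py):

    # hypothetical reconstruction, not part of the commit as shown
    self.decoder_layers = nn.ModuleList(
        [
            DinoConditionedLayer(config, True)
            for _ in range(config.num_hidden_layers // scale_factor)
        ]
    )
    self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.decoder = DinoV3ViTDecoder(config)  # would need importing from src.model.utransformer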

src/model/utransformer.py

@@ -188,30 +188,83 @@ class DinoV3ViTDecoder(nn.Module):
         self.projection = nn.Linear(
             config.hidden_size,
-            self.num_channels_out * (self.patch_size**2),
+            config.num_channels * (self.patch_size**2),
             bias=True,
         )
         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
         nn.init.zeros_(self.projection.weight)
-        nn.init.zeros_(self.projection.bias)
+        nn.init.zeros_(
+            self.projection.bias
+        ) if self.projection.bias is not None else None

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
         x = x[:, 1 + self.config.num_register_tokens :, :]
-        x = self.projection(x)
         p = self.config.patch_size
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
         assert x.shape[1] == h_grid * w_grid
+        x = self.projection(x)
         x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2)
-        return self.pixel_shuffle(x)
+        x = self.pixel_shuffle(x)
+        return x
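To make the decode path concrete: after dropping the CLS and register tokens, the projection maps each patch token to num_channels·p² values, the reshape/permute lays the tokens back on the (h_grid, w_grid) grid, and PixelShuffle(p) unfolds each token into a p×p pixel block, recovering the full-resolution image:

    # shape walk-through for an H x W image with patch size p, h = H/p, w = W/p
    # x: (B, h*w, hidden)  -> projection   -> (B, h*w, C*p*p)
    # -> reshape/permute   -> (B, C*p*p, h, w)
    # -> PixelShuffle(p)   -> (B, C, h*p, w*p) == (B, C, H, W)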
+
+
+# how about transposed conv decoder
+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+#         self.patch_size = config.patch_size
+#         intermediate_channels = config.hidden_size // 4
+#         self.decoder_block = nn.Sequential(
+#             nn.ConvTranspose2d(
+#                 in_channels=config.hidden_size,
+#                 out_channels=intermediate_channels,
+#                 kernel_size=self.patch_size,
+#                 stride=self.patch_size,
+#                 bias=True,
+#             ),
+#             nn.LayerNorm(intermediate_channels),
+#             nn.GELU(),
+#             nn.Conv2d(
+#                 in_channels=intermediate_channels,
+#                 out_channels=config.num_channels,
+#                 kernel_size=1,
+#                 bias=True,
+#             ),
+#         )
+#         nn.init.zeros_(self.decoder_block[-1].weight)
+#         nn.init.zeros_(self.decoder_block[-1].bias)
+#
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+#         x = x[:, 1 + self.config.num_register_tokens :, :]
+#         p = self.config.patch_size
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+#         assert x.shape[1] == h_grid * w_grid
+#         x = x.transpose(1, 2).reshape(
+#             batch_size, self.config.hidden_size, h_grid, w_grid
+#         )
+#         x = self.decoder_block(x)
+#         return x
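One caveat on this commented-out alternative, should it be revived: nn.LayerNorm(intermediate_channels) applied to an NCHW tensor normalizes over the trailing W dimension, not the channel dimension (and only runs at all when W happens to equal intermediate_channels). The usual fix is to go channels-last around the norm, e.g.:

    # hypothetical fix for the sketch above
    x = x.permute(0, 2, 3, 1)                                   # NCHW -> NHWC
    x = nn.functional.layer_norm(x, (intermediate_channels,))
    x = x.permute(0, 3, 1, 2)                                   # NHWC -> NCHW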
 class UTransformer(nn.Module):
@@ -261,6 +314,9 @@ class UTransformer(nn.Module):
         bool_masked_pos: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
     ):
+        if time.dim() == 0:
+            time = time.repeat(pixel_values.shape[0])
+
         pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
         position_embeddings = self.rope_embeddings(pixel_values)

src/rf.py (230 changed lines)

@@ -1,117 +1,175 @@
 import lpips
 import torch
+import torch.optim as optim
+from pytorch_msssim import ms_ssim
+from torchcfm.conditional_flow_matching import (
+    ConditionalFlowMatcher,
+    ExactOptimalTransportConditionalFlowMatcher,
+    TargetConditionalFlowMatcher,
+    VariancePreservingConditionalFlowMatcher,
+)
+from torchdiffeq import odeint
+import wandb
+
 from src.dataset.preprocess import denormalize
+from src.gan import PatchDiscriminator, gan_disc_loss


 def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
     """Loss = sqrt(||x||₂² + c²) - c, with c scaled to c·√d for per-sample dim d."""
     d = x.shape[1:].numel()
     c = c * (d**0.5)
     x = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
     return torch.sqrt(x**2 + c**2) - c
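A quick check of the scaling: near zero residual the loss behaves quadratically, roughly ‖x‖² / (2c√d), while for large residuals it approaches the plain L2 norm:

    print(pseudo_huber_loss(torch.zeros(1, 3, 8, 8)))  # tensor([0.])
    print(pseudo_huber_loss(torch.ones(1, 3, 8, 8)))   # ≈ 13.85, essentially sqrt(192)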
+lecam_loss_weight = 0.1
+lecam_anchor_real_logits = 0.0
+lecam_anchor_fake_logits = 0.0
+lecam_beta = 0.9
+use_lecam = True
+
+
 class RF:
-    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_mse_enhanced"):
+    def __init__(self, model, fm="otcfm", loss="mse"):
         self.model = model
-        self.ln = ln
-        self.ushaped = ushaped
-        self.loss_fn = loss_fn
+        self.loss = loss
-        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1") if "lpips" in loss_fn else None
-
-    def forward(self, x0, z1):
-        # x0 is gt / z is noise
-        b = x0.size(0)
-        if self.ushaped:
-            a = 4.0  # HYPERPARMS
-            u = torch.rand((b,), device=x0.device)
-            t = torch.acosh(1 + (torch.cosh(torch.tensor(a)) - 1) * u) / a
-            t = t * (torch.randint(0, 2, (b,), device=x0.device) * 2 - 1) * 0.5 + 0.5
-        elif self.ln:
-            nt = torch.randn((b,)).to(x0.device)
-            t = torch.sigmoid(nt)
+        sigma = 0.0
+        if fm == "otcfm":
+            self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
+        elif fm == "icfm":
+            self.fm = ConditionalFlowMatcher(sigma=sigma)
+        elif fm == "fm":
+            self.fm = TargetConditionalFlowMatcher(sigma=sigma)
+        elif fm == "si":
+            self.fm = VariancePreservingConditionalFlowMatcher(sigma=sigma)
         else:
-            t = torch.rand((b,)).to(x0.device)
-        texp = t.view([b, *([1] * len(x0.shape[1:]))])
-        zt = (1 - texp) * x0 + texp * z1
+            raise NotImplementedError(
+                f"Unknown model {fm}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
+            )
-        vtheta, residual = self.model(zt, t)
+        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1")
+        self.lpips2 = lpips.LPIPS(net="alex").to("cuda:1")
-        if self.loss_fn == "lpips_huber":
-            # https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
-            if not self.lpips:
-                raise Exception
+        discriminator = PatchDiscriminator().to("cuda:1")
+        discriminator.requires_grad_(True)
+        self.discriminator = discriminator
+        self.optimizer_D = optim.AdamW(
+            discriminator.parameters(),
+            lr=2e-4,
+            weight_decay=1e-3,
+            betas=(0.9, 0.95),
+        )
-            huber = torch.nn.functional.huber_loss(
-                z1 - x0, vtheta, reduction="none"
-            ).mean(dim=list(range(1, len(x0.shape))))
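For orientation: the torchcfm matchers share a single API, sample_location_and_conditional_flow(x0, x1), which draws t ~ U[0, 1], a point xt on the conditional path, and the target velocity ut. With sigma = 0 the independent coupling ('icfm') gives straight paths xt = (1 - t)·x0 + t·x1 with ut = x1 - x0, while 'otcfm' first re-pairs the minibatch via an optimal-transport plan before building the same paths. A minimal sketch:

    fm = ConditionalFlowMatcher(sigma=0.0)
    x0, x1 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
    t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
    # sigma=0: xt = (1 - t)[:, None, None, None] * x0 + t[:, None, None, None] * x1
    #          ut = x1 - x0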
+    def gan_loss(self, real, fake):
+        global lecam_beta, lecam_anchor_real_logits, lecam_anchor_fake_logits, use_lecam
+        real_preds = self.discriminator(real)
+        fake_preds = self.discriminator(fake.detach())
+        d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss(
+            real_preds, fake_preds, "hinge"
+        )
+        # EMA of the average logits, used as LeCam anchors
+        lecam_anchor_real_logits = (
+            lecam_beta * lecam_anchor_real_logits + (1 - lecam_beta) * avg_real_logits
+        )
+        lecam_anchor_fake_logits = (
+            lecam_beta * lecam_anchor_fake_logits + (1 - lecam_beta) * avg_fake_logits
+        )
+        total_d_loss = d_loss.mean()
+        d_loss_item = total_d_loss.item()
+        if use_lecam:
+            # penalize the real logits toward the fake anchor and vice versa
+            lecam_loss = (real_preds - lecam_anchor_fake_logits).pow(2).mean() + (
+                fake_preds - lecam_anchor_real_logits
+            ).pow(2).mean()
+            lecam_loss_item = lecam_loss.item()
+            total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight
+            wandb.log(
+                {
+                    "gan/lecam_loss": lecam_loss_item,
+                    "gan/lecam_anchor_real_logits": lecam_anchor_real_logits,
+                    "gan/lecam_anchor_fake_logits": lecam_anchor_fake_logits,
+                }
+            )
+        wandb.log(
+            {
+                "gan/discriminator_loss": d_loss_item,
+                "gan/discriminator_accuracy": disc_acc,
+            }
+        )
+        self.optimizer_D.zero_grad()
+        total_d_loss.backward(retain_graph=True)
+        self.optimizer_D.step()
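For context on the regularizer: LeCam (Tseng et al., CVPR 2021) pulls the discriminator's logits toward slowly moving EMA anchors of the opposite class,

    R_LC = E[(D(x_real) - α_fake)²] + E[(D(x_fake) - α_real)²]

which bounds how far the real and fake logit distributions can drift apart and tempers an over-confident discriminator. Here α_real/α_fake are the module-level EMA scalars updated above with lecam_beta = 0.9, and the penalty enters with lecam_loss_weight = 0.1. The retain_graph=True on the backward call appears intended to keep activations alive for the generator-side loss computed right afterwards.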
+    def forward(self, gt, cloud):
+        t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt)  # type: ignore
+        vt, _ = self.model(xt, t)
+        if self.loss == "mse":
+            loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
+            loss_list = {"train/mse": loss.mean().item()}
+        elif self.loss == "lpips_mse":
+            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
             lpips = self.lpips(
-                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+                denormalize(gt) * 2 - 1,
+                denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
             )
-            weight = t.view(-1)
-            loss = (1 - weight) * huber + lpips
-        elif self.loss_fn == "lpips_mse":
-            if not self.lpips:
-                raise Exception
+            loss_list = {
+                "train/mse": mse.mean().item(),
+                "train/lpips": lpips.mean().item(),
+            }
+            loss = mse + lpips * 2.0
+        elif self.loss == "gan_lpips_mse":
+            self.gan_loss(
+                denormalize(gt),
+                denormalize(xt + (1 - t[:, None, None, None]) * vt),
+            )
+            mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
             lpips = self.lpips(
-                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+                denormalize(gt) * 2 - 1,
+                denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
             )
-            loss = ((z1 - x0 - vtheta) ** 2).mean(
-                dim=list(range(1, len(x0.shape)))
-            ) + 2.0 * lpips
-        elif self.loss_fn == "lpips_mse_enhanced":
-            if not self.lpips:
-                raise Exception
-            lpips = self.lpips(
-                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+            alexlpips = self.lpips2(
+                denormalize(gt) * 2 - 1,
+                denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
             )
-            dino_loss = torch.stack(
-                [
-                    (
-                        1
-                        - torch.nn.functional.cosine_similarity(
-                            v_residual, x0_residual, dim=-1
-                        )
-                    ).mean(dim=-1)
-                    for v_residual, x0_residual in zip(
-                        residual, self.model.get_residual(x0, None, False)
-                    )
-                ]
-            ).mean(dim=0) * (2 - t.view(-1))
-            loss = (
-                ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
-                + 2.0 * lpips
-                + dino_loss
+            gan = (
+                -self.discriminator(
+                    denormalize(xt + (1 - t[:, None, None, None]) * vt),
+                ).mean(-1)
+                * 0.01
             )
+            ssim = 1 - ms_ssim(
+                denormalize(gt),
+                denormalize(xt + (1 - t[:, None, None, None]) * vt),
+                data_range=1.0,
+                size_average=False,
+            )
+            loss_list = {
+                "train/mse": mse.mean().item(),
+                "train/lpips": lpips.mean().item(),
+                "train/alexlpips": alexlpips.mean().item(),
+                "train/gan": gan.mean().item(),
+                "train/ssim": ssim.mean().item(),
+            }
+            loss = mse + lpips * 4.0 + gan + alexlpips + ssim
         else:
-            loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
+            raise Exception
         tlist = loss.detach().cpu().reshape(-1).tolist()
         ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
-        return loss.mean(), ttloss
+        return loss.mean(), ttloss, loss_list
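A note on the recurring expression xt + (1 - t)·vt: on the sigma = 0 path the identities

    xt = (1 - t)·x0 + t·x1,    ut = x1 - x0
    xt + (1 - t)·ut = (1 - t)·x0 + t·x1 + (1 - t)·(x1 - x0) = x1

hold exactly, so substituting the predicted vt for ut gives the model's one-jump estimate of the clean target (here x0 = cloud, x1 = gt). That estimate, pushed through denormalize, is what the LPIPS, GAN, and MS-SSIM terms score.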
     @torch.no_grad()
-    def sample(self, z1, sample_steps=5):
-        b = z1.size(0)
-        dt = 1.0 / sample_steps
-        dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
-        images = [z1]
-        z = z1
-        for i in range(sample_steps, 0, -1):
-            t = i / sample_steps
-            t = torch.tensor([t] * b).to(z.device)
+    def sample(self, cloud, tol=1e-5, integration="dopri5") -> torch.Tensor:
+        t_span = torch.linspace(0, 1, 2, device=cloud.device)
+        traj = odeint(
+            lambda t, x: self.model(x, t)[0],
+            cloud,
+            t_span,
+            rtol=tol,
+            atol=tol,
+            method=integration,
+        )
-            vc, _ = self.model(z, t)
-            z = z - dt * vc
-            images.append(z)
-        return images
+        return [traj[i] for i in range(traj.shape[0])]  # type: ignore
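Usage sketch for the new ODE sampler (names are illustrative): torchdiffeq integrates dx/dt = v_θ(x, t) from t = 0 at cloud to t = 1, and since t_span holds only the two endpoints, the returned list is exactly the start state and the final prediction:

    rf = RF(model)                    # assumes a trained velocity model
    start, final = rf.sample(cloud)   # start is cloud itself
    img = denormalize(final).clamp(0, 1)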
     @torch.no_grad()
     def sample_heun(self, z1, sample_steps=50):