neulus
2025-10-13 23:14:44 +09:00
parent c47d91a349
commit 3b03453e5d
28 changed files with 700 additions and 208 deletions

208
main_dit.py Normal file
View File

@@ -0,0 +1,208 @@
import math
import os
import torch
import torch.optim as optim
from torchvision.utils import make_grid
from tqdm import tqdm
import wandb
from src.benchmark import benchmark
from src.dataset.cuhk_cr2 import get_dataset
from src.dataset.preprocess import denormalize
from src.model.dit import DiT_Llama
from src.rf import RF
train_dataset, test_dataset = get_dataset()
device = "cuda:1"
batch_size = 8 * 4
accumulation_steps = 1
total_epoch = 500
steps_per_epoch = len(train_dataset) // batch_size
total_steps = steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
grad_norm = 1.0
model = DiT_Llama.from_pretrained_backbone(
"facebook/dinov3-vitl16-pretrain-sat493m",
patch_size=4,
dim=256,
n_layers=8,
n_heads=32,
).to(device)
rf = RF(model, "icfm", "lpips_mse")
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
# LR schedule: LambdaLR multiplier with linear warmup over the first 5% of steps, then cosine decay to zero
def get_lr(step: int) -> float:
if step < warmup_steps:
return step / warmup_steps
else:
progress = (step - warmup_steps) / (total_steps - warmup_steps)
return 0.5 * (1 + math.cos(math.pi * progress))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)
wandb.init(project="cloud-removal-kmu", resume="allow")
if not (wandb.run and wandb.run.name):
raise RuntimeError("wandb run was not initialized (no run name); cannot create the artifact directory")
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
start_epoch = 0
checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
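# resume from this run's rolling "final" checkpoint if one already exists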
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if "scheduler_state_dict" in checkpoint:
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = checkpoint["epoch"] + 1
for epoch in range(start_epoch, total_epoch):
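# per-epoch diagnostics: training loss binned by flow time t into 10 deciles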
lossbin = {i: 0 for i in range(10)}
losscnt = {i: 1e-6 for i in range(10)}
train_dataset = train_dataset.shuffle(seed=epoch)
for i in tqdm(
range(0, len(train_dataset), batch_size),
desc=f"Epoch {epoch + 1}/{total_epoch}",
):
batch = train_dataset[i : i + batch_size]
cloud = batch["cloud"].to(device)
gt = batch["gt"].to(device)
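# rf.forward returns the scalar loss, per-sample (t, loss) pairs for binning, and a dict of loss components for logging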
loss, blsct, loss_list = rf.forward(gt, cloud, condition=True)
loss = loss / accumulation_steps
loss.backward()
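# gradient accumulation: only step and zero the optimizer every `accumulation_steps` micro-batches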
if (i // batch_size + 1) % accumulation_steps == 0:
# total_norm = torch.nn.utils.clip_grad_norm_(
# model.parameters(), max_norm=grad_norm
# )
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# wandb.log(
# {
# "train/grad_norm": total_norm.item(),
# }
# )
wandb.log(
{
"train/loss": loss.item() * accumulation_steps,
"train/lr": scheduler.get_last_lr()[0],
}
)
wandb.log(loss_list)
for t, lss in blsct:
bin_idx = min(int(t * 10), 9)
lossbin[bin_idx] += lss
losscnt[bin_idx] += 1
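# flush leftover gradients when the number of batches is not a multiple of accumulation_steps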
if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
# total_norm = torch.nn.utils.clip_grad_norm_(
# model.parameters(), max_norm=grad_norm
# )
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# wandb.log(
# {
# "train/grad_norm": total_norm.item(),
# }
# )
epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
epoch_metrics["epoch"] = epoch
wandb.log(epoch_metrics)
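# every 50 epochs: sample the test split and log PSNR / SSIM / LPIPS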
if (epoch + 1) % 50 == 0:
rf.model.eval()
psnr_sum = 0
ssim_sum = 0
lpips_sum = 0
count = 0
with torch.no_grad():
for i in tqdm(
range(0, len(test_dataset), batch_size),
desc=f"Benchmark {epoch + 1}/{total_epoch}",
):
batch = test_dataset[i : i + batch_size]
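# rf.sample integrates the flow ODE conditioned on the cloudy input; the returned list is the trajectory and images[-1] is the final prediction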
images = rf.sample(batch["cloud"].to(device), condition=True)
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
if i == 0:
# log the first and last trajectory steps of the first test batch
for step, demo in enumerate([images[0], images[-1]]):
viz = wandb.Image(
make_grid(
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
),
caption=f"step {step}",
)
wandb.log({"viz/decoded": viz})
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()
count += image.shape[0]
avg_psnr = psnr_sum / count
avg_ssim = ssim_sum / count
avg_lpips = lpips_sum / count
wandb.log(
{
"eval/psnr": avg_psnr,
"eval/ssim": avg_ssim,
"eval/lpips": avg_lpips,
"epoch": epoch + 1,
}
)
rf.model.train()
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
},
f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
)
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
},
checkpoint_path,  # rolling "final" checkpoint used for resume
)
wandb.finish()

View File

@@ -1,6 +1,7 @@
import os
import torch
+from PIL import Image
from torchvision.utils import save_image
from tqdm import tqdm
@@ -10,7 +11,7 @@ from src.dataset.preprocess import denormalize
from src.model.utransformer import UTransformer
from src.rf import RF
-checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt"
+checkpoint_path = "artifact/firm-darkness-98/checkpoint_final.pt"
device = "cuda:1"
save_dir = "test_images"
@@ -28,7 +29,7 @@ rf.model.eval()
_, test_dataset = get_dataset()
-batch_size = 8
+batch_size = 8 * 4
psnr_sum = 0
ssim_sum = 0
lpips_sum = 0
@@ -39,7 +40,7 @@ max_save = 10
with torch.no_grad():
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
batch = test_dataset[i : i + batch_size]
-images = rf.sample_heun(batch["cloud"].to(device), 1)
+images = rf.sample(batch["cloud"].to(device), 1)
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
@@ -52,6 +53,23 @@ with torch.no_grad():
denormalize(batch["cloud"][j]).clamp(0, 1),
f"{save_dir}/input_{saved_count}.png",
)
+frames = []
+for step_img in images:
+frame = denormalize(step_img[j]).clamp(0, 1)
+frame_np = (frame.permute(1, 2, 0).cpu().numpy() * 255).astype(
+"uint8"
+)
+frames.append(Image.fromarray(frame_np))
+frames[0].save(
+f"{save_dir}/transform_{saved_count}.gif",
+save_all=True,
+append_images=frames[1:],
+duration=100,
+loop=0,
+)
saved_count += 1
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())

View File

@@ -1,15 +1,22 @@
+import lpips
+from pytorch_msssim import ssim
from torchmetrics.image import (
-LearnedPerceptualImagePatchSimilarity,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
-ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
-lpips = LearnedPerceptualImagePatchSimilarity(
-net_type="alex", reduction="none", normalize=True
-)
+lp = lpips.LPIPS(net="alex")
def benchmark(image1, image2):
-return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
+return (
+psnr(image1, image2),
+ssim(
+image1,
+image2,
+data_range=1.0,
+size_average=False,
+),
+lp(image1 * 2 - 1, image2 * 2 - 1),
+)

View File

@@ -373,7 +373,7 @@ class DINOv3ViTModel(nn.Module):
self.config = config
self.embeddings = DINOv3ViTEmbeddings(config)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
-self.layers = nn.ModuleList(
+self.layer = nn.ModuleList(
[DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)]
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -420,7 +420,7 @@ class DINOv3ViTModel(nn.Module):
position_embeddings = self.rope_embeddings(pixel_values)
latents = []
-for i, layer_module in enumerate(self.layers):
+for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,

387
src/model/dit.py Normal file
View File

@@ -0,0 +1,387 @@
# Modeling code for the DiT-LLaMA model, heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
from src.model.dino import DINOv3ViTModel
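# adaLN modulation: apply a per-sample shift and scale, broadcast over the token dimension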
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size),
)
self.frequency_embedding_size = frequency_embedding_size
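# sinusoidal timestep embedding: cos/sin at geometrically spaced frequencies, mapped to hidden_size by the MLP in forward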
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
).to(t.device)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
dtype=next(self.parameters()).dtype
)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = int(dropout_prob > 0)
self.embedding_table = nn.Embedding(
num_classes + use_cfg_embedding, hidden_size
)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
if force_drop_ids is None:
drop_ids = (
torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
)
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class Attention(nn.Module):
def __init__(self, dim, n_heads):
super().__init__()
self.n_heads = n_heads
self.n_rep = 1
self.head_dim = dim // n_heads
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)
@staticmethod
def reshape_for_broadcast(freqs_cis, x):
ndim = x.ndim
assert 0 <= 1 < ndim
# assert freqs_cis.shape == (x.shape[1], x.shape[-1])
_freqs_cis = freqs_cis[: x.shape[1]]
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return _freqs_cis.view(*shape)
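# rotary position embedding (RoPE): channel pairs are viewed as complex numbers and rotated by freqs_cis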
@staticmethod
def apply_rotary_emb(xq, xk, freqs_cis):
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis_xq = Attention.reshape_for_broadcast(freqs_cis, xq_)
freqs_cis_xk = Attention.reshape_for_broadcast(freqs_cis, xk_)
xq_out = torch.view_as_real(xq_ * freqs_cis_xq).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis_xk).flatten(3)
return xq_out, xk_out
def forward(self, x, freqs_cis):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
dtype = xq.dtype
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
xq, xk = xq.to(dtype), xk.to(dtype)
output = F.scaled_dot_product_attention(
xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3),
dropout_p=0.0,
is_causal=False,
).permute(0, 2, 1, 3)
output = output.flatten(-2)
return self.wo(output)
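# LLaMA-style SwiGLU MLP: w2(silu(w1(x)) * w3(x)), hidden width rounded up to a multiple of `multiple_of`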
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
class TransformerBlock(nn.Module):
def __init__(
self,
layer_id,
dim,
n_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = Attention(dim, n_heads)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(min(dim, 1024), 6 * dim, bias=True),
)
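# adaLN_modulation emits six vectors per sample: shift/scale/gate for the attention branch and for the MLP branch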
def forward(self, x, freqs_cis, adaln_input=None):
if adaln_input is not None:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.adaLN_modulation(adaln_input).chunk(6, dim=1)
)
x = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
)
x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
)
else:
x = x + self.attention(self.attention_norm(x), freqs_cis)
x = x + self.feed_forward(self.ffn_norm(x))
return x
class FinalLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True),
)
# zero-init the output projection so the final layer starts by predicting zeros
nn.init.constant_(self.linear.weight, 0)
nn.init.constant_(self.linear.bias, 0)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DiT_Llama(nn.Module):
def __init__(
self,
dino_cfg: DINOv3ViTConfig,
in_channels=3,
input_size=32,
patch_size=2,
dim=512,
n_layers=5,
n_heads=16,
multiple_of=256,
ffn_dim_multiplier=None,
norm_eps=1e-5,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.input_size = input_size
self.patch_size = patch_size
self.init_conv_seq = nn.Sequential(
nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
nn.SiLU(),
nn.GroupNorm(32, dim // 2),
nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
nn.SiLU(),
nn.GroupNorm(32, dim // 2),
)
self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
nn.init.constant_(self.x_embedder.bias, 0)
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.y_embedder = DINOv3ViTModel(dino_cfg)
self.thing = nn.Linear(dino_cfg.hidden_size, min(dim, 1024))
self.layers = nn.ModuleList(
[
TransformerBlock(
layer_id,
dim,
n_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
)
for layer_id in range(n_layers)
]
)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 16384)
# freeze the pretrained DINOv3 conditioning encoder
self.y_embedder.requires_grad_(False)
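# unpatchify: fold (N, T, p*p*C) tokens back into an (N, C, H, W) image; patchify below is the inverse applied to the conv-stem features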
def unpatchify(self, x):
c = self.out_channels
p = self.patch_size
h = w = int(x.shape[1] ** 0.5)
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def patchify(self, x):
B, C, H, W = x.size()
x = x.view(
B,
C,
H // self.patch_size,
self.patch_size,
W // self.patch_size,
self.patch_size,
)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
return x
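# conditioning: timestep embedding plus the frozen DINOv3 pooled embedding of the conditioning (cloudy) image, summed and used as the adaLN input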
def forward(self, x, t, y):
self.freqs_cis = self.freqs_cis.to(x.device)
x = self.init_conv_seq(x)
x = self.patchify(x)
x = self.x_embedder(x)
t = self.t_embedder(t) # (N, D)
y = self.thing(self.y_embedder(y)["pooler_output"]) # (N, D)
adaln_input = t.to(x.dtype) + y.to(x.dtype)
for layer in self.layers:
x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
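# classifier-free guidance: run the duplicated batch once and blend conditional/unconditional predictions with cfg_scale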
def forward_with_cfg(self, x, t, y, cfg_scale):
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.forward(combined, t, y)
eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
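# precompute the RoPE table: unit-magnitude complex rotations for every position up to `end`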
@staticmethod
def precompute_freqs_cis(dim, end, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
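# build a DiT_Llama and load pretrained DINOv3 weights from the HF Hub into the frozen conditioning encoder (y_embedder)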
@staticmethod
def from_pretrained_backbone(
name: str,
in_channels=3,
input_size=32,
patch_size=2,
dim=512,
n_layers=5,
n_heads=16,
multiple_of=256,
ffn_dim_multiplier=None,
norm_eps=1e-5,
):
config = DINOv3ViTConfig.from_pretrained(name)
instance = DiT_Llama(
config,
in_channels,
input_size,
patch_size,
dim,
n_layers,
n_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
).to("cuda:1")
weight_dict = {}
with safe_open(
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
) as f:
for key in f.keys():
weight_dict[key] = f.get_tensor(key)
instance.y_embedder.load_state_dict(weight_dict, strict=True)
return instance

View File

@@ -1,144 +0,0 @@
from typing import Optional
import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
from src.model.dino import (
DINOv3ViTEmbeddings,
DINOv3ViTRopePositionEmbedding,
)
from src.model.utransformer import DinoConditionedLayer, TimestepEmbedder
class Hourgrass(nn.Module):
def __init__(
self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
):
super().__init__()
self.config = config
self.scale_factor = scale_factor
assert config.num_hidden_layers % scale_factor == 0
self.embeddings = DINOv3ViTEmbeddings(config)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
self.t_embedder = TimestepEmbedder(config.hidden_size)
self.encoder_layers = nn.ModuleList(
[
DinoConditionedLayer(config, True)
for _ in range(config.num_hidden_layers)
]
)
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# freeze pretrained
self.embeddings.requires_grad_(False)
self.rope_embeddings.requires_grad_(False)
self.encoder_norm.requires_grad_(False)
def forward(
self,
pixel_values: torch.Tensor,
time: torch.Tensor,
# cond: torch.Tensor,
bool_masked_pos: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
if time.dim() == 0:
time = time.repeat(pixel_values.shape[0])
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
position_embeddings = self.rope_embeddings(pixel_values)
x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
t = self.t_embedder(time).unsqueeze(1)
conditioning_input = t.to(x.dtype)
residual = []
for i, layer_module in enumerate(self.encoder_layers):
if i % self.scale_factor == 0:
residual.append(x)
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
x = self.encoder_norm(x)
reversed_residual = residual[::-1]
for i, layer_module in enumerate(self.decoder_layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
x = x + reversed_residual[i]
x = self.decoder_norm(x)
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
def get_residual(
self,
pixel_values: torch.Tensor,
time: Optional[torch.Tensor],
do_condition: bool,
):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
position_embeddings = self.rope_embeddings(pixel_values)
x = self.embeddings(pixel_values, bool_masked_pos=None)
if do_condition:
t = self.t_embedder(time).unsqueeze(1)
# y = self.y_embedder(cond, self.training).unsqueeze(1)
# conditioning_input = t.to(x.dtype) + y.to(x.dtype)
conditioning_input = t.to(x.dtype)
else:
conditioning_input = None
residual = []
for i, layer_module in enumerate(self.encoder_layers):
if i % self.scale_factor == 0:
residual.append(x)
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=None,
position_embeddings=position_embeddings,
do_condition=do_condition,
)
return residual
@staticmethod
def from_pretrained_backbone(name: str):
config = DINOv3ViTConfig.from_pretrained(name)
instance = UTransformer(config, 0).to("cuda:1")
weight_dict = {}
with safe_open(
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
) as f:
for key in f.keys():
new_key = key.replace("layer.", "encoder_layers.").replace(
"norm.", "encoder_norm."
)
weight_dict[new_key] = f.get_tensor(key)
instance.load_state_dict(weight_dict, strict=False)
return instance

View File

@@ -78,6 +78,7 @@ class LabelEmbedder(nn.Module):
class DinoConditionedLayer(DINOv3ViTLayer):
def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
super().__init__(config)
+self.is_encoder = is_encoder
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cond = nn.MultiheadAttention(
@@ -298,6 +299,17 @@ class UTransformer(nn.Module):
for _ in range(config.num_hidden_layers // scale_factor)
]
)
+self.residual_merger = nn.ModuleList(
+[
+nn.Sequential(
+nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
+)
+for _ in range(config.num_hidden_layers // scale_factor)
+]
+)
+self.rest_decoder = nn.ModuleList(
+[DinoConditionedLayer(config, False) for _ in range(4)]
+)
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = DinoV3ViTDecoder(config)
@@ -348,8 +360,22 @@ class UTransformer(nn.Module):
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
+do_condition=False,
)
+shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
+2, dim=-1
+)
+x = x * (1 + scale) + shift
+for i, layer_module in enumerate(self.rest_decoder):
+layer_head_mask = head_mask[i] if head_mask is not None else None
+x = layer_module(
+x,
+conditioning_input=conditioning_input,
+attention_mask=layer_head_mask,
+position_embeddings=position_embeddings,
+do_condition=False,
+)
-x = x + reversed_residual[i]
x = self.decoder_norm(x)

View File

@@ -25,6 +25,7 @@ class RF:
def __init__(self, model, fm="otcfm", loss="mse"):
self.model = model
self.loss = loss
+self.iter = 0
sigma = 0.0
if fm == "otcfm":
@@ -97,9 +98,14 @@
total_d_loss.backward(retain_graph=True)
self.optimizer_D.step()
-def forward(self, gt, cloud):
-t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt)  # type: ignore
-vt, _ = self.model(xt, t)
+def forward(self, gt, cloud, condition=False):
+t, xt, ut = self.fm.sample_location_and_conditional_flow(  # type: ignore
+cloud if not condition else torch.randn_like(cloud), gt
+)
+if condition:
+vt = self.model(xt, t, cloud)
+else:
+vt, _ = self.model(xt, t)
if self.loss == "mse":
loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
@@ -116,10 +122,12 @@
}
loss = mse + lpips * 2.0
elif self.loss == "gan_lpips_mse":
-self.gan_loss(
-denormalize(gt),
-denormalize(xt + (1 - t[:, None, None, None]) * vt),
-)
+self.iter += 1
+# if self.iter % 4 == 0:
+# self.gan_loss(
+# denormalize(gt),
+# denormalize(xt + (1 - t[:, None, None, None]) * vt),
+# )
mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
lpips = self.lpips(
denormalize(gt) * 2 - 1,
@@ -129,12 +137,9 @@
denormalize(gt) * 2 - 1,
denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
)
-gan = (
--self.discriminator(
-denormalize(xt + (1 - t[:, None, None, None]) * vt),
-).mean(-1)
-* 0.01
-)
+# gan = -self.discriminator(
+# denormalize(xt + (1 - t[:, None, None, None]) * vt),
+# ).mean(-1)
ssim = 1 - ms_ssim(
denormalize(gt),
denormalize(xt + (1 - t[:, None, None, None]) * vt),
@@ -145,10 +150,10 @@
"train/mse": mse.mean().item(),
"train/lpips": lpips.mean().item(),
"train/alexlpips": alexlpips.mean().item(),
-"train/gan": gan.mean().item(),
+# "train/gan": gan.mean().item(),
"train/ssim": ssim.mean().item(),
}
-loss = mse + lpips * 4.0 + gan + alexlpips + ssim
+loss = mse + lpips * 2.0 + alexlpips + ssim
else:
raise Exception
@@ -158,43 +163,28 @@
return loss.mean(), ttloss, loss_list
@torch.no_grad()
-def sample(self, cloud, tol=1e-5, integration="dopri5") -> torch.Tensor:
+def sample(
+self, cloud, tol=1e-5, integration="dopri5", condition=False
+) -> torch.Tensor:
t_span = torch.linspace(0, 1, 2, device=cloud.device)
-traj = odeint(
-lambda t, x: self.model(x, t)[0],
-cloud,
-t_span,
-rtol=tol,
-atol=tol,
-method=integration,
-)
+if condition:
+x = torch.randn_like(cloud)
+traj = odeint(
+lambda t, x: self.model(x, t, cloud),
+x,
+t_span,
+rtol=tol,
+atol=tol,
+method=integration,
+)
+else:
+traj = odeint(
+lambda t, x: self.model(x, t)[0],
+cloud,
+t_span,
+rtol=tol,
+atol=tol,
+method=integration,
+)
return [traj[i] for i in range(traj.shape[0])]  # type: ignore
-@torch.no_grad()
-def sample_heun(self, z1, sample_steps=50):
-b = z1.size(0)
-dt = 1.0 / sample_steps
-images = [z1]
-z = z1
-for i in range(sample_steps, 0, -1):
-t_current = i / sample_steps
-t_next = (i - 1) / sample_steps
-t_current_tensor = torch.tensor([t_current] * b, device=z.device)
-v_current, _ = self.model(z, t_current_tensor)
-z_pred = z - dt * v_current
-t_next_tensor = torch.tensor([t_next] * b, device=z.device)
-v_next, _ = self.model(z_pred, t_next_tensor)
-v_avg = 0.5 * (v_current + v_next)
-z = z - dt * v_avg
-images.append(z)
-return images

Binary files changed (10 images, content not shown):
396 KiB → 444 KiB, 401 KiB → 467 KiB, 386 KiB → 403 KiB, 379 KiB → 409 KiB, 412 KiB → 462 KiB,
407 KiB → 417 KiB, 440 KiB → 437 KiB, 441 KiB → 454 KiB, 428 KiB → 432 KiB, 386 KiB → 446 KiB

New binary files (10 GIFs, content not shown):
test_images/transform_0.gif (481 KiB), test_images/transform_1.gif (450 KiB),
test_images/transform_2.gif (434 KiB), test_images/transform_3.gif (440 KiB),
test_images/transform_4.gif (457 KiB), test_images/transform_5.gif (520 KiB),
test_images/transform_6.gif (510 KiB), test_images/transform_7.gif (445 KiB),
test_images/transform_8.gif (486 KiB), test_images/transform_9.gif (440 KiB)