things
208
main_dit.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torchvision.utils import make_grid
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from src.benchmark import benchmark
|
||||
from src.dataset.cuhk_cr2 import get_dataset
|
||||
from src.dataset.preprocess import denormalize
|
||||
from src.model.dit import DiT_Llama
|
||||
from src.rf import RF
|
||||
|
||||
train_dataset, test_dataset = get_dataset()
|
||||
|
||||
device = "cuda:1"
|
||||
|
||||
batch_size = 8 * 4
|
||||
accumulation_steps = 1
|
||||
total_epoch = 500
|
||||
|
||||
steps_per_epoch = len(train_dataset) // batch_size
|
||||
total_steps = steps_per_epoch * total_epoch
|
||||
warmup_steps = int(0.05 * total_steps)
|
||||
|
||||
grad_norm = 1.0
|
||||
|
||||
|
||||
model = DiT_Llama.from_pretrained_backbone(
|
||||
"facebook/dinov3-vitl16-pretrain-sat493m",
|
||||
patch_size=4,
|
||||
dim=256,
|
||||
n_layers=8,
|
||||
n_heads=32,
|
||||
).to(device)
|
||||
rf = RF(model, "icfm", "lpips_mse")
|
||||
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
||||
|
||||
|
||||
# scheduler
|
||||
def get_lr(step: int) -> float:
|
||||
if step < warmup_steps:
|
||||
return step / warmup_steps
|
||||
else:
|
||||
progress = (step - warmup_steps) / (total_steps - warmup_steps)
|
||||
return 0.5 * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)
|
||||
|
||||
wandb.init(project="cloud-removal-kmu", resume="allow")
|
||||
|
||||
if not (wandb.run and wandb.run.name):
|
||||
raise Exception("nope")
|
||||
|
||||
os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
|
||||
|
||||
start_epoch = 0
|
||||
checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
|
||||
if os.path.exists(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
if "scheduler_state_dict" in checkpoint:
|
||||
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
start_epoch = checkpoint["epoch"] + 1
|
||||
|
||||
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
lossbin = {i: 0 for i in range(10)}
|
||||
losscnt = {i: 1e-6 for i in range(10)}
|
||||
|
||||
train_dataset = train_dataset.shuffle(seed=epoch)
|
||||
|
||||
for i in tqdm(
|
||||
range(0, len(train_dataset), batch_size),
|
||||
desc=f"Epoch {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = train_dataset[i : i + batch_size]
|
||||
cloud = batch["cloud"].to(device)
|
||||
gt = batch["gt"].to(device)
|
||||
|
||||
loss, blsct, loss_list = rf.forward(gt, cloud, condition=True)
|
||||
loss = loss / accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
if (i // batch_size + 1) % accumulation_steps == 0:
|
||||
# total_norm = torch.nn.utils.clip_grad_norm_(
|
||||
# model.parameters(), max_norm=grad_norm
|
||||
# )
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# wandb.log(
|
||||
# {
|
||||
# "train/grad_norm": total_norm.item(),
|
||||
# }
|
||||
# )
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": loss.item() * accumulation_steps,
|
||||
"train/lr": scheduler.get_last_lr()[0],
|
||||
}
|
||||
)
|
||||
wandb.log(loss_list)
|
||||
|
||||
for t, lss in blsct:
|
||||
bin_idx = min(int(t * 10), 9)
|
||||
lossbin[bin_idx] += lss
|
||||
losscnt[bin_idx] += 1
|
||||
|
||||
if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
|
||||
# total_norm = torch.nn.utils.clip_grad_norm_(
|
||||
# model.parameters(), max_norm=grad_norm
|
||||
# )
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
# wandb.log(
|
||||
# {
|
||||
# "train/grad_norm": total_norm.item(),
|
||||
# }
|
||||
# )
|
||||
|
||||
epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
|
||||
epoch_metrics["epoch"] = epoch
|
||||
wandb.log(epoch_metrics)
|
||||
|
||||
if (epoch + 1) % 50 == 0:
|
||||
rf.model.eval()
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
count = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for i in tqdm(
|
||||
range(0, len(test_dataset), batch_size),
|
||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample(batch["cloud"].to(device), condition=True)
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
|
||||
if i == 0:
|
||||
for step, demo in enumerate([images[0], images[-1]]):
|
||||
images = wandb.Image(
|
||||
make_grid(
|
||||
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
|
||||
),
|
||||
caption=f"step {step}",
|
||||
)
|
||||
wandb.log({"viz/decoded": images})
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
count += image.shape[0]
|
||||
|
||||
avg_psnr = psnr_sum / count
|
||||
avg_ssim = ssim_sum / count
|
||||
avg_lpips = lpips_sum / count
|
||||
wandb.log(
|
||||
{
|
||||
"eval/psnr": avg_psnr,
|
||||
"eval/ssim": avg_ssim,
|
||||
"eval/lpips": avg_lpips,
|
||||
"epoch": epoch + 1,
|
||||
}
|
||||
)
|
||||
rf.model.train()
|
||||
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
},
|
||||
f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
|
||||
)
|
||||
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch, # type: ignore
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
},
|
||||
f"artifact/{wandb.run.name}/checkpoint_final.pt",
|
||||
)
|
||||
wandb.finish()
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -10,7 +11,7 @@ from src.dataset.preprocess import denormalize
|
||||
from src.model.utransformer import UTransformer
|
||||
from src.rf import RF
|
||||
|
||||
checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt"
|
||||
checkpoint_path = "artifact/firm-darkness-98/checkpoint_final.pt"
|
||||
device = "cuda:1"
|
||||
save_dir = "test_images"
|
||||
|
||||
@@ -28,7 +29,7 @@ rf.model.eval()
|
||||
|
||||
_, test_dataset = get_dataset()
|
||||
|
||||
batch_size = 8
|
||||
batch_size = 8 * 4
|
||||
psnr_sum = 0
|
||||
ssim_sum = 0
|
||||
lpips_sum = 0
|
||||
@@ -39,7 +40,7 @@ max_save = 10
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample_heun(batch["cloud"].to(device), 1)
|
||||
images = rf.sample(batch["cloud"].to(device), 1)
|
||||
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
@@ -52,6 +53,23 @@ with torch.no_grad():
|
||||
denormalize(batch["cloud"][j]).clamp(0, 1),
|
||||
f"{save_dir}/input_{saved_count}.png",
|
||||
)
|
||||
|
||||
frames = []
|
||||
for step_img in images:
|
||||
frame = denormalize(step_img[j]).clamp(0, 1)
|
||||
frame_np = (frame.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
||||
"uint8"
|
||||
)
|
||||
frames.append(Image.fromarray(frame_np))
|
||||
|
||||
frames[0].save(
|
||||
f"{save_dir}/transform_{saved_count}.gif",
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=100,
|
||||
loop=0,
|
||||
)
|
||||
|
||||
saved_count += 1
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
import lpips
|
||||
from pytorch_msssim import ssim
|
||||
from torchmetrics.image import (
|
||||
LearnedPerceptualImagePatchSimilarity,
|
||||
PeakSignalNoiseRatio,
|
||||
StructuralSimilarityIndexMeasure,
|
||||
)
|
||||
|
||||
psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
|
||||
ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(
|
||||
net_type="alex", reduction="none", normalize=True
|
||||
)
|
||||
lp = lpips.LPIPS(net="alex")
|
||||
|
||||
|
||||
def benchmark(image1, image2):
|
||||
return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
|
||||
return (
|
||||
psnr(image1, image2),
|
||||
ssim(
|
||||
image1,
|
||||
image2,
|
||||
data_range=1.0,
|
||||
size_average=False,
|
||||
),
|
||||
lp(image1 * 2 - 1, image2 * 2 - 1),
|
||||
)
|
||||
|
||||
@@ -373,7 +373,7 @@ class DINOv3ViTModel(nn.Module):
|
||||
self.config = config
|
||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
self.layer = nn.ModuleList(
|
||||
[DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -420,7 +420,7 @@ class DINOv3ViTModel(nn.Module):
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
latents = []
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
|
||||
387
src/model/dit.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory
|
||||
# this is modeling code for DiT-LLaMA model
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
from src.model.dino import DINOv3ViTModel
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half) / half
|
||||
).to(t.device)
|
||||
args = t[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob):
|
||||
super().__init__()
|
||||
use_cfg_embedding = int(dropout_prob > 0)
|
||||
self.embedding_table = nn.Embedding(
|
||||
num_classes + use_cfg_embedding, hidden_size
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
|
||||
drop_ids = drop_ids.cuda()
|
||||
drop_ids = drop_ids.to(labels.device)
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, n_heads):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.n_rep = 1
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
|
||||
self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)
|
||||
|
||||
@staticmethod
|
||||
def reshape_for_broadcast(freqs_cis, x):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
# assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
_freqs_cis = freqs_cis[: x.shape[1]]
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return _freqs_cis.view(*shape)
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(xq, xk, freqs_cis):
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis_xq = Attention.reshape_for_broadcast(freqs_cis, xq_)
|
||||
freqs_cis_xk = Attention.reshape_for_broadcast(freqs_cis, xk_)
|
||||
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis_xq).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis_xk).flatten(3)
|
||||
return xq_out, xk_out
|
||||
|
||||
def forward(self, x, freqs_cis):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
dtype = xq.dtype
|
||||
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
|
||||
xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
xq, xk = xq.to(dtype), xk.to(dtype)
|
||||
|
||||
output = F.scaled_dot_product_attention(
|
||||
xq.permute(0, 2, 1, 3),
|
||||
xk.permute(0, 2, 1, 3),
|
||||
xv.permute(0, 2, 1, 3),
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
).permute(0, 2, 1, 3)
|
||||
output = output.flatten(-2)
|
||||
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
if ffn_dim_multiplier:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = Attention(dim, n_heads)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(dim, 1024), 6 * dim, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, freqs_cis, adaln_input=None):
|
||||
if adaln_input is not None:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.adaLN_modulation(adaln_input).chunk(6, dim=1)
|
||||
)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1) * self.attention(
|
||||
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
|
||||
modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
else:
|
||||
x = x + self.attention(self.attention_norm(x), freqs_cis)
|
||||
x = x + self.feed_forward(self.ffn_norm(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, bias=True
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True),
|
||||
)
|
||||
# # init zero
|
||||
nn.init.constant_(self.linear.weight, 0)
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class DiT_Llama(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dino_cfg: DINOv3ViTConfig,
|
||||
in_channels=3,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
dim=512,
|
||||
n_layers=5,
|
||||
n_heads=16,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.input_size = input_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.init_conv_seq = nn.Sequential(
|
||||
nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
|
||||
nn.SiLU(),
|
||||
nn.GroupNorm(32, dim // 2),
|
||||
nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
|
||||
nn.SiLU(),
|
||||
nn.GroupNorm(32, dim // 2),
|
||||
)
|
||||
|
||||
self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
|
||||
nn.init.constant_(self.x_embedder.bias, 0)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024))
|
||||
self.y_embedder = DINOv3ViTModel(dino_cfg)
|
||||
self.thing = nn.Linear(dino_cfg.hidden_size, min(dim, 1024))
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
|
||||
|
||||
self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 16384)
|
||||
|
||||
# freeze
|
||||
self.y_embedder.requires_grad_(False)
|
||||
|
||||
def unpatchify(self, x):
|
||||
c = self.out_channels
|
||||
p = self.patch_size
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
||||
return imgs
|
||||
|
||||
def patchify(self, x):
|
||||
B, C, H, W = x.size()
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
H // self.patch_size,
|
||||
self.patch_size,
|
||||
W // self.patch_size,
|
||||
self.patch_size,
|
||||
)
|
||||
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
return x
|
||||
|
||||
def forward(self, x, t, y):
|
||||
self.freqs_cis = self.freqs_cis.to(x.device)
|
||||
|
||||
x = self.init_conv_seq(x)
|
||||
|
||||
x = self.patchify(x)
|
||||
x = self.x_embedder(x)
|
||||
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
y = self.thing(self.y_embedder(y)["pooler_output"]) # (N, D)
|
||||
adaln_input = t.to(x.dtype) + y.to(x.dtype)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_cfg(self, x, t, y, cfg_scale):
|
||||
half = x[: len(x) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.forward(combined, t, y)
|
||||
eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim, end, theta=10000.0):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end)
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained_backbone(
|
||||
name: str,
|
||||
in_channels=3,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
dim=512,
|
||||
n_layers=5,
|
||||
n_heads=16,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
config = DINOv3ViTConfig.from_pretrained(name)
|
||||
instance = DiT_Llama(
|
||||
config,
|
||||
in_channels,
|
||||
input_size,
|
||||
patch_size,
|
||||
dim,
|
||||
n_layers,
|
||||
n_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
).to("cuda:1")
|
||||
|
||||
weight_dict = {}
|
||||
with safe_open(
|
||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
weight_dict[key] = f.get_tensor(key)
|
||||
|
||||
instance.y_embedder.load_state_dict(weight_dict, strict=True)
|
||||
|
||||
return instance
|
||||
@@ -1,144 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
|
||||
|
||||
from src.model.dino import (
|
||||
DINOv3ViTEmbeddings,
|
||||
DINOv3ViTRopePositionEmbedding,
|
||||
)
|
||||
from src.model.utransformer import DinoConditionedLayer, TimestepEmbedder
|
||||
|
||||
|
||||
class Hourgrass(nn.Module):
|
||||
def __init__(
|
||||
self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
assert config.num_hidden_layers % scale_factor == 0
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||
self.t_embedder = TimestepEmbedder(config.hidden_size)
|
||||
|
||||
self.encoder_layers = nn.ModuleList(
|
||||
[
|
||||
DinoConditionedLayer(config, True)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
# freeze pretrained
|
||||
self.embeddings.requires_grad_(False)
|
||||
self.rope_embeddings.requires_grad_(False)
|
||||
self.encoder_norm.requires_grad_(False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
time: torch.Tensor,
|
||||
# cond: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if time.dim() == 0:
|
||||
time = time.repeat(pixel_values.shape[0])
|
||||
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
t = self.t_embedder(time).unsqueeze(1)
|
||||
|
||||
conditioning_input = t.to(x.dtype)
|
||||
|
||||
residual = []
|
||||
for i, layer_module in enumerate(self.encoder_layers):
|
||||
if i % self.scale_factor == 0:
|
||||
residual.append(x)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
x = self.encoder_norm(x)
|
||||
|
||||
reversed_residual = residual[::-1]
|
||||
for i, layer_module in enumerate(self.decoder_layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
x = x + reversed_residual[i]
|
||||
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
|
||||
|
||||
def get_residual(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
time: Optional[torch.Tensor],
|
||||
do_condition: bool,
|
||||
):
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
x = self.embeddings(pixel_values, bool_masked_pos=None)
|
||||
|
||||
if do_condition:
|
||||
t = self.t_embedder(time).unsqueeze(1)
|
||||
# y = self.y_embedder(cond, self.training).unsqueeze(1)
|
||||
# conditioning_input = t.to(x.dtype) + y.to(x.dtype)
|
||||
conditioning_input = t.to(x.dtype)
|
||||
else:
|
||||
conditioning_input = None
|
||||
|
||||
residual = []
|
||||
for i, layer_module in enumerate(self.encoder_layers):
|
||||
if i % self.scale_factor == 0:
|
||||
residual.append(x)
|
||||
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=None,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=do_condition,
|
||||
)
|
||||
|
||||
return residual
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained_backbone(name: str):
|
||||
config = DINOv3ViTConfig.from_pretrained(name)
|
||||
instance = UTransformer(config, 0).to("cuda:1")
|
||||
|
||||
weight_dict = {}
|
||||
with safe_open(
|
||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
new_key = key.replace("layer.", "encoder_layers.").replace(
|
||||
"norm.", "encoder_norm."
|
||||
)
|
||||
|
||||
weight_dict[new_key] = f.get_tensor(key)
|
||||
|
||||
instance.load_state_dict(weight_dict, strict=False)
|
||||
|
||||
return instance
|
||||
@@ -78,6 +78,7 @@ class LabelEmbedder(nn.Module):
|
||||
class DinoConditionedLayer(DINOv3ViTLayer):
|
||||
def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
|
||||
super().__init__(config)
|
||||
self.is_encoder = is_encoder
|
||||
|
||||
self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.cond = nn.MultiheadAttention(
|
||||
@@ -298,6 +299,17 @@ class UTransformer(nn.Module):
|
||||
for _ in range(config.num_hidden_layers // scale_factor)
|
||||
]
|
||||
)
|
||||
self.residual_merger = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size)
|
||||
)
|
||||
for _ in range(config.num_hidden_layers // scale_factor)
|
||||
]
|
||||
)
|
||||
self.rest_decoder = nn.ModuleList(
|
||||
[DinoConditionedLayer(config, False) for _ in range(4)]
|
||||
)
|
||||
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.decoder = DinoV3ViTDecoder(config)
|
||||
|
||||
@@ -348,8 +360,22 @@ class UTransformer(nn.Module):
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
shift, scale = self.residual_merger[i](reversed_residual[i]).chunk(
|
||||
2, dim=-1
|
||||
)
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
for i, layer_module in enumerate(self.rest_decoder):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
x = x + reversed_residual[i]
|
||||
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
|
||||
94
src/rf.py
@@ -25,6 +25,7 @@ class RF:
|
||||
def __init__(self, model, fm="otcfm", loss="mse"):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.iter = 0
|
||||
|
||||
sigma = 0.0
|
||||
if fm == "otcfm":
|
||||
@@ -97,9 +98,14 @@ class RF:
|
||||
total_d_loss.backward(retain_graph=True)
|
||||
self.optimizer_D.step()
|
||||
|
||||
def forward(self, gt, cloud):
|
||||
t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt) # type: ignore
|
||||
vt, _ = self.model(xt, t)
|
||||
def forward(self, gt, cloud, condition=False):
|
||||
t, xt, ut = self.fm.sample_location_and_conditional_flow( # type: ignore
|
||||
cloud if not condition else torch.randn_like(cloud), gt
|
||||
)
|
||||
if condition:
|
||||
vt = self.model(xt, t, cloud)
|
||||
else:
|
||||
vt, _ = self.model(xt, t)
|
||||
|
||||
if self.loss == "mse":
|
||||
loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
|
||||
@@ -116,10 +122,12 @@ class RF:
|
||||
}
|
||||
loss = mse + lpips * 2.0
|
||||
elif self.loss == "gan_lpips_mse":
|
||||
self.gan_loss(
|
||||
denormalize(gt),
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
)
|
||||
self.iter += 1
|
||||
# if self.iter % 4 == 0:
|
||||
# self.gan_loss(
|
||||
# denormalize(gt),
|
||||
# denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
# )
|
||||
mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
|
||||
lpips = self.lpips(
|
||||
denormalize(gt) * 2 - 1,
|
||||
@@ -129,12 +137,9 @@ class RF:
|
||||
denormalize(gt) * 2 - 1,
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
|
||||
)
|
||||
gan = (
|
||||
-self.discriminator(
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
).mean(-1)
|
||||
* 0.01
|
||||
)
|
||||
# gan = -self.discriminator(
|
||||
# denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
# ).mean(-1)
|
||||
ssim = 1 - ms_ssim(
|
||||
denormalize(gt),
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
@@ -145,10 +150,10 @@ class RF:
|
||||
"train/mse": mse.mean().item(),
|
||||
"train/lpips": lpips.mean().item(),
|
||||
"train/alexlpips": alexlpips.mean().item(),
|
||||
"train/gan": gan.mean().item(),
|
||||
# "train/gan": gan.mean().item(),
|
||||
"train/ssim": ssim.mean().item(),
|
||||
}
|
||||
loss = mse + lpips * 4.0 + gan + alexlpips + ssim
|
||||
loss = mse + lpips * 2.0 + alexlpips + ssim
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
@@ -158,43 +163,28 @@ class RF:
|
||||
return loss.mean(), ttloss, loss_list
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, cloud, tol=1e-5, integration="dopri5") -> torch.Tensor:
|
||||
def sample(
|
||||
self, cloud, tol=1e-5, integration="dopri5", condition=False
|
||||
) -> torch.Tensor:
|
||||
t_span = torch.linspace(0, 1, 2, device=cloud.device)
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t)[0],
|
||||
cloud,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
atol=tol,
|
||||
method=integration,
|
||||
)
|
||||
if condition:
|
||||
x = torch.randn_like(cloud)
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t, cloud),
|
||||
x,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
atol=tol,
|
||||
method=integration,
|
||||
)
|
||||
else:
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t)[0],
|
||||
cloud,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
atol=tol,
|
||||
method=integration,
|
||||
)
|
||||
|
||||
return [traj[i] for i in range(traj.shape[0])] # type: ignore
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(self, z1, sample_steps=50):
|
||||
b = z1.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
|
||||
images = [z1]
|
||||
z = z1
|
||||
|
||||
for i in range(sample_steps, 0, -1):
|
||||
t_current = i / sample_steps
|
||||
t_next = (i - 1) / sample_steps
|
||||
|
||||
t_current_tensor = torch.tensor([t_current] * b, device=z.device)
|
||||
|
||||
v_current, _ = self.model(z, t_current_tensor)
|
||||
z_pred = z - dt * v_current
|
||||
|
||||
t_next_tensor = torch.tensor([t_next] * b, device=z.device)
|
||||
v_next, _ = self.model(z_pred, t_next_tensor)
|
||||
|
||||
v_avg = 0.5 * (v_current + v_next)
|
||||
|
||||
z = z - dt * v_avg
|
||||
|
||||
images.append(z)
|
||||
|
||||
return images
|
||||
|
||||
|
Before Width: | Height: | Size: 396 KiB After Width: | Height: | Size: 444 KiB |
|
Before Width: | Height: | Size: 401 KiB After Width: | Height: | Size: 467 KiB |
|
Before Width: | Height: | Size: 386 KiB After Width: | Height: | Size: 403 KiB |
|
Before Width: | Height: | Size: 379 KiB After Width: | Height: | Size: 409 KiB |
|
Before Width: | Height: | Size: 412 KiB After Width: | Height: | Size: 462 KiB |
|
Before Width: | Height: | Size: 407 KiB After Width: | Height: | Size: 417 KiB |
|
Before Width: | Height: | Size: 440 KiB After Width: | Height: | Size: 437 KiB |
|
Before Width: | Height: | Size: 441 KiB After Width: | Height: | Size: 454 KiB |
|
Before Width: | Height: | Size: 428 KiB After Width: | Height: | Size: 432 KiB |
|
Before Width: | Height: | Size: 386 KiB After Width: | Height: | Size: 446 KiB |
BIN
test_images/transform_0.gif
Normal file
|
After Width: | Height: | Size: 481 KiB |
BIN
test_images/transform_1.gif
Normal file
|
After Width: | Height: | Size: 450 KiB |
BIN
test_images/transform_2.gif
Normal file
|
After Width: | Height: | Size: 434 KiB |
BIN
test_images/transform_3.gif
Normal file
|
After Width: | Height: | Size: 440 KiB |
BIN
test_images/transform_4.gif
Normal file
|
After Width: | Height: | Size: 457 KiB |
BIN
test_images/transform_5.gif
Normal file
|
After Width: | Height: | Size: 520 KiB |
BIN
test_images/transform_6.gif
Normal file
|
After Width: | Height: | Size: 510 KiB |
BIN
test_images/transform_7.gif
Normal file
|
After Width: | Height: | Size: 445 KiB |
BIN
test_images/transform_8.gif
Normal file
|
After Width: | Height: | Size: 486 KiB |
BIN
test_images/transform_9.gif
Normal file
|
After Width: | Height: | Size: 440 KiB |