diff --git a/main_dit.py b/main_dit.py new file mode 100644 index 0000000..a18956b --- /dev/null +++ b/main_dit.py @@ -0,0 +1,208 @@ +import math +import os + +import torch +import torch.optim as optim +from torchvision.utils import make_grid +from tqdm import tqdm + +import wandb +from src.benchmark import benchmark +from src.dataset.cuhk_cr2 import get_dataset +from src.dataset.preprocess import denormalize +from src.model.dit import DiT_Llama +from src.rf import RF + +train_dataset, test_dataset = get_dataset() + +device = "cuda:1" + +batch_size = 8 * 4 +accumulation_steps = 1 +total_epoch = 500 + +steps_per_epoch = len(train_dataset) // batch_size +total_steps = steps_per_epoch * total_epoch +warmup_steps = int(0.05 * total_steps) + +grad_norm = 1.0 + + +model = DiT_Llama.from_pretrained_backbone( + "facebook/dinov3-vitl16-pretrain-sat493m", + patch_size=4, + dim=256, + n_layers=8, + n_heads=32, +).to(device) +rf = RF(model, "icfm", "lpips_mse") +optimizer = optim.AdamW(model.parameters(), lr=3e-4) + + +# scheduler +def get_lr(step: int) -> float: + if step < warmup_steps: + return step / warmup_steps + else: + progress = (step - warmup_steps) / (total_steps - warmup_steps) + return 0.5 * (1 + math.cos(math.pi * progress)) + + +scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr) + +wandb.init(project="cloud-removal-kmu", resume="allow") + +if not (wandb.run and wandb.run.name): + raise Exception("nope") + +os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True) + +start_epoch = 0 +checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt" +if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if "scheduler_state_dict" in checkpoint: + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + start_epoch = checkpoint["epoch"] + 1 + + +for epoch in range(start_epoch, total_epoch): + lossbin = {i: 0 for i in range(10)} + losscnt = {i: 1e-6 for i in range(10)} + + train_dataset = train_dataset.shuffle(seed=epoch) + + for i in tqdm( + range(0, len(train_dataset), batch_size), + desc=f"Epoch {epoch + 1}/{total_epoch}", + ): + batch = train_dataset[i : i + batch_size] + cloud = batch["cloud"].to(device) + gt = batch["gt"].to(device) + + loss, blsct, loss_list = rf.forward(gt, cloud, condition=True) + loss = loss / accumulation_steps + loss.backward() + + if (i // batch_size + 1) % accumulation_steps == 0: + # total_norm = torch.nn.utils.clip_grad_norm_( + # model.parameters(), max_norm=grad_norm + # ) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + # wandb.log( + # { + # "train/grad_norm": total_norm.item(), + # } + # ) + + wandb.log( + { + "train/loss": loss.item() * accumulation_steps, + "train/lr": scheduler.get_last_lr()[0], + } + ) + wandb.log(loss_list) + + for t, lss in blsct: + bin_idx = min(int(t * 10), 9) + lossbin[bin_idx] += lss + losscnt[bin_idx] += 1 + + if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0: + # total_norm = torch.nn.utils.clip_grad_norm_( + # model.parameters(), max_norm=grad_norm + # ) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + # wandb.log( + # { + # "train/grad_norm": total_norm.item(), + # } + # ) + + epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)} + epoch_metrics["epoch"] = epoch + wandb.log(epoch_metrics) + + if (epoch + 1) % 50 == 0: + rf.model.eval() + psnr_sum = 0 + 
ssim_sum = 0 + lpips_sum = 0 + count = 0 + + with torch.no_grad(): + for i in tqdm( + range(0, len(test_dataset), batch_size), + desc=f"Benchmark {epoch + 1}/{total_epoch}", + ): + batch = test_dataset[i : i + batch_size] + images = rf.sample(batch["cloud"].to(device), condition=True) + image = denormalize(images[-1]).clamp(0, 1) + original = denormalize(batch["gt"]).clamp(0, 1) + + if i == 0: + for step, demo in enumerate([images[0], images[-1]]): + images = wandb.Image( + make_grid( + denormalize(demo).clamp(0, 1).float()[:4], nrow=2 + ), + caption=f"step {step}", + ) + wandb.log({"viz/decoded": images}) + + psnr, ssim, lpips = benchmark(image.cpu(), original.cpu()) + psnr_sum += psnr.sum().item() + ssim_sum += ssim.sum().item() + lpips_sum += lpips.sum().item() + count += image.shape[0] + + avg_psnr = psnr_sum / count + avg_ssim = ssim_sum / count + avg_lpips = lpips_sum / count + wandb.log( + { + "eval/psnr": avg_psnr, + "eval/ssim": avg_ssim, + "eval/lpips": avg_lpips, + "epoch": epoch + 1, + } + ) + rf.model.train() + + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + }, + f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt", + ) + + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + }, + checkpoint_path, + ) + +torch.save( + { + "epoch": epoch, # type: ignore + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + }, + f"artifact/{wandb.run.name}/checkpoint_final.pt", +) +wandb.finish() diff --git a/quick_eval.py b/quick_eval.py index a6f35b2..ca62ea7 100644 --- a/quick_eval.py +++ b/quick_eval.py @@ -1,6 +1,7 @@ import os import torch +from PIL import Image from torchvision.utils import save_image from tqdm import tqdm @@ -10,7 +11,7 @@ from src.dataset.preprocess import denormalize from src.model.utransformer import UTransformer from src.rf import RF -checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt" +checkpoint_path = "artifact/firm-darkness-98/checkpoint_final.pt" device = "cuda:1" save_dir = "test_images" @@ -28,7 +29,7 @@ rf.model.eval() _, test_dataset = get_dataset() -batch_size = 8 +batch_size = 8 * 4 psnr_sum = 0 ssim_sum = 0 lpips_sum = 0 @@ -39,7 +40,7 @@ max_save = 10 with torch.no_grad(): for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"): batch = test_dataset[i : i + batch_size] - images = rf.sample_heun(batch["cloud"].to(device), 1) + images = rf.sample(batch["cloud"].to(device), 1) image = denormalize(images[-1]).clamp(0, 1) original = denormalize(batch["gt"]).clamp(0, 1) @@ -52,6 +53,23 @@ with torch.no_grad(): denormalize(batch["cloud"][j]).clamp(0, 1), f"{save_dir}/input_{saved_count}.png", ) + + frames = [] + for step_img in images: + frame = denormalize(step_img[j]).clamp(0, 1) + frame_np = (frame.permute(1, 2, 0).cpu().numpy() * 255).astype( + "uint8" + ) + frames.append(Image.fromarray(frame_np)) + + frames[0].save( + f"{save_dir}/transform_{saved_count}.gif", + save_all=True, + append_images=frames[1:], + duration=100, + loop=0, + ) + saved_count += 1 psnr, ssim, lpips = benchmark(image.cpu(), original.cpu()) diff --git a/src/benchmark/__init__.py b/src/benchmark/__init__.py index 0072fdc..33fa94b 100644 --- a/src/benchmark/__init__.py +++ b/src/benchmark/__init__.py @@ -1,15 
+1,22 @@ +import lpips +from pytorch_msssim import ssim from torchmetrics.image import ( - LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, ) psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3)) -ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none") -lpips = LearnedPerceptualImagePatchSimilarity( - net_type="alex", reduction="none", normalize=True -) +lp = lpips.LPIPS(net="alex") def benchmark(image1, image2): - return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2) + return ( + psnr(image1, image2), + ssim( + image1, + image2, + data_range=1.0, + size_average=False, + ), + lp(image1 * 2 - 1, image2 * 2 - 1), + ) diff --git a/src/model/dino.py b/src/model/dino.py index ed4f3fd..837d32d 100644 --- a/src/model/dino.py +++ b/src/model/dino.py @@ -373,7 +373,7 @@ class DINOv3ViTModel(nn.Module): self.config = config self.embeddings = DINOv3ViTEmbeddings(config) self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config) - self.layers = nn.ModuleList( + self.layer = nn.ModuleList( [DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -420,7 +420,7 @@ class DINOv3ViTModel(nn.Module): position_embeddings = self.rope_embeddings(pixel_values) latents = [] - for i, layer_module in enumerate(self.layers): + for i, layer_module in enumerate(self.layer): layer_head_mask = head_mask[i] if head_mask is not None else None hidden_states = layer_module( hidden_states, diff --git a/src/model/dit.py b/src/model/dit.py new file mode 100644 index 0000000..8f644d3 --- /dev/null +++ b/src/model/dit.py @@ -0,0 +1,387 @@ +# Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory +# this is modeling code for DiT-LLaMA model + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig + +from src.model.dino import DINOv3ViTModel + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half) / half + ).to(t.device) + args = t[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to( + dtype=next(self.parameters()).dtype + ) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = int(dropout_prob > 0) + self.embedding_table = nn.Embedding( + num_classes + use_cfg_embedding, hidden_size + ) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + if force_drop_ids is None: + drop_ids = 
torch.rand(labels.shape[0]) < self.dropout_prob + drop_ids = drop_ids.cuda() + drop_ids = drop_ids.to(labels.device) + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class Attention(nn.Module): + def __init__(self, dim, n_heads): + super().__init__() + + self.n_heads = n_heads + self.n_rep = 1 + self.head_dim = dim // n_heads + + self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False) + self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False) + self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) + + self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim) + self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim) + + @staticmethod + def reshape_for_broadcast(freqs_cis, x): + ndim = x.ndim + assert 0 <= 1 < ndim + # assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + _freqs_cis = freqs_cis[: x.shape[1]] + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return _freqs_cis.view(*shape) + + @staticmethod + def apply_rotary_emb(xq, xk, freqs_cis): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis_xq = Attention.reshape_for_broadcast(freqs_cis, xq_) + freqs_cis_xk = Attention.reshape_for_broadcast(freqs_cis, xk_) + + xq_out = torch.view_as_real(xq_ * freqs_cis_xq).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis_xk).flatten(3) + return xq_out, xk_out + + def forward(self, x, freqs_cis): + bsz, seqlen, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + dtype = xq.dtype + + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim) + + xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + output = F.scaled_dot_product_attention( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + dropout_p=0.0, + is_causal=False, + ).permute(0, 2, 1, 3) + output = output.flatten(-2) + + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class TransformerBlock(nn.Module): + def __init__( + self, + layer_id, + dim, + n_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = Attention(dim, n_heads) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, 
+ ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = nn.LayerNorm(dim, eps=norm_eps) + self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(dim, 1024), 6 * dim, bias=True), + ) + + def forward(self, x, freqs_cis, adaln_input=None): + if adaln_input is not None: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(adaln_input).chunk(6, dim=1) + ) + + x = x + gate_msa.unsqueeze(1) * self.attention( + modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis + ) + x = x + gate_mlp.unsqueeze(1) * self.feed_forward( + modulate(self.ffn_norm(x), shift_mlp, scale_mlp) + ) + else: + x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.feed_forward(self.ffn_norm(x)) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True), + ) + # # init zero + nn.init.constant_(self.linear.weight, 0) + nn.init.constant_(self.linear.bias, 0) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT_Llama(nn.Module): + def __init__( + self, + dino_cfg: DINOv3ViTConfig, + in_channels=3, + input_size=32, + patch_size=2, + dim=512, + n_layers=5, + n_heads=16, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = in_channels + self.input_size = input_size + self.patch_size = patch_size + + self.init_conv_seq = nn.Sequential( + nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1), + nn.SiLU(), + nn.GroupNorm(32, dim // 2), + nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1), + nn.SiLU(), + nn.GroupNorm(32, dim // 2), + ) + + self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True) + nn.init.constant_(self.x_embedder.bias, 0) + + self.t_embedder = TimestepEmbedder(min(dim, 1024)) + self.y_embedder = DINOv3ViTModel(dino_cfg) + self.thing = nn.Linear(dino_cfg.hidden_size, min(dim, 1024)) + + self.layers = nn.ModuleList( + [ + TransformerBlock( + layer_id, + dim, + n_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + ) + for layer_id in range(n_layers) + ] + ) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels) + + self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 16384) + + # freeze + self.y_embedder.requires_grad_(False) + + def unpatchify(self, x): + c = self.out_channels + p = self.patch_size + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def patchify(self, x): + B, C, H, W = x.size() + x = x.view( + B, + C, + H // self.patch_size, + self.patch_size, + W // self.patch_size, + self.patch_size, + ) + x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + return x + + def forward(self, x, t, y): + self.freqs_cis = self.freqs_cis.to(x.device) + + x = self.init_conv_seq(x) + + x = self.patchify(x) + x = self.x_embedder(x) + + t = 
self.t_embedder(t) # (N, D) + y = self.thing(self.y_embedder(y)["pooler_output"]) # (N, D) + adaln_input = t.to(x.dtype) + y.to(x.dtype) + + for layer in self.layers: + x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input) + + x = self.final_layer(x, adaln_input) + x = self.unpatchify(x) # (N, out_channels, H, W) + + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @staticmethod + def precompute_freqs_cis(dim, end, theta=10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + @staticmethod + def from_pretrained_backbone( + name: str, + in_channels=3, + input_size=32, + patch_size=2, + dim=512, + n_layers=5, + n_heads=16, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + ): + config = DINOv3ViTConfig.from_pretrained(name) + instance = DiT_Llama( + config, + in_channels, + input_size, + patch_size, + dim, + n_layers, + n_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + ).to("cuda:1") + + weight_dict = {} + with safe_open( + hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1" + ) as f: + for key in f.keys(): + weight_dict[key] = f.get_tensor(key) + + instance.y_embedder.load_state_dict(weight_dict, strict=True) + + return instance diff --git a/src/model/hourglass.py b/src/model/hourglass.py deleted file mode 100644 index 009c77a..0000000 --- a/src/model/hourglass.py +++ /dev/null @@ -1,144 +0,0 @@ -from typing import Optional - -import torch -from huggingface_hub import hf_hub_download -from safetensors import safe_open -from torch import nn -from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig - -from src.model.dino import ( - DINOv3ViTEmbeddings, - DINOv3ViTRopePositionEmbedding, -) -from src.model.utransformer import DinoConditionedLayer, TimestepEmbedder - - -class Hourgrass(nn.Module): - def __init__( - self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4 - ): - super().__init__() - self.config = config - self.scale_factor = scale_factor - - assert config.num_hidden_layers % scale_factor == 0 - - self.embeddings = DINOv3ViTEmbeddings(config) - self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config) - self.t_embedder = TimestepEmbedder(config.hidden_size) - - self.encoder_layers = nn.ModuleList( - [ - DinoConditionedLayer(config, True) - for _ in range(config.num_hidden_layers) - ] - ) - self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - # freeze pretrained - self.embeddings.requires_grad_(False) - self.rope_embeddings.requires_grad_(False) - self.encoder_norm.requires_grad_(False) - - def forward( - self, - pixel_values: torch.Tensor, - time: torch.Tensor, - # cond: torch.Tensor, - bool_masked_pos: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - ): - if time.dim() == 0: - time = time.repeat(pixel_values.shape[0]) - - pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) - position_embeddings = 
self.rope_embeddings(pixel_values) - - x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) - t = self.t_embedder(time).unsqueeze(1) - - conditioning_input = t.to(x.dtype) - - residual = [] - for i, layer_module in enumerate(self.encoder_layers): - if i % self.scale_factor == 0: - residual.append(x) - - layer_head_mask = head_mask[i] if head_mask is not None else None - x = layer_module( - x, - conditioning_input=conditioning_input, - attention_mask=layer_head_mask, - position_embeddings=position_embeddings, - ) - - x = self.encoder_norm(x) - - reversed_residual = residual[::-1] - for i, layer_module in enumerate(self.decoder_layers): - layer_head_mask = head_mask[i] if head_mask is not None else None - x = layer_module( - x, - conditioning_input=conditioning_input, - attention_mask=layer_head_mask, - position_embeddings=position_embeddings, - ) - x = x + reversed_residual[i] - - x = self.decoder_norm(x) - - return self.decoder(x, image_size=pixel_values.shape[-2:]), residual - - def get_residual( - self, - pixel_values: torch.Tensor, - time: Optional[torch.Tensor], - do_condition: bool, - ): - pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) - position_embeddings = self.rope_embeddings(pixel_values) - - x = self.embeddings(pixel_values, bool_masked_pos=None) - - if do_condition: - t = self.t_embedder(time).unsqueeze(1) - # y = self.y_embedder(cond, self.training).unsqueeze(1) - # conditioning_input = t.to(x.dtype) + y.to(x.dtype) - conditioning_input = t.to(x.dtype) - else: - conditioning_input = None - - residual = [] - for i, layer_module in enumerate(self.encoder_layers): - if i % self.scale_factor == 0: - residual.append(x) - - x = layer_module( - x, - conditioning_input=conditioning_input, - attention_mask=None, - position_embeddings=position_embeddings, - do_condition=do_condition, - ) - - return residual - - @staticmethod - def from_pretrained_backbone(name: str): - config = DINOv3ViTConfig.from_pretrained(name) - instance = UTransformer(config, 0).to("cuda:1") - - weight_dict = {} - with safe_open( - hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1" - ) as f: - for key in f.keys(): - new_key = key.replace("layer.", "encoder_layers.").replace( - "norm.", "encoder_norm." 
- ) - - weight_dict[new_key] = f.get_tensor(key) - - instance.load_state_dict(weight_dict, strict=False) - - return instance diff --git a/src/model/utransformer.py b/src/model/utransformer.py index c9d0f44..a6dd2a2 100644 --- a/src/model/utransformer.py +++ b/src/model/utransformer.py @@ -78,6 +78,7 @@ class LabelEmbedder(nn.Module): class DinoConditionedLayer(DINOv3ViTLayer): def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False): super().__init__(config) + self.is_encoder = is_encoder self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.cond = nn.MultiheadAttention( @@ -298,6 +299,17 @@ class UTransformer(nn.Module): for _ in range(config.num_hidden_layers // scale_factor) ] ) + self.residual_merger = nn.ModuleList( + [ + nn.Sequential( + nn.SiLU(), nn.Linear(config.hidden_size, 2 * config.hidden_size) + ) + for _ in range(config.num_hidden_layers // scale_factor) + ] + ) + self.rest_decoder = nn.ModuleList( + [DinoConditionedLayer(config, False) for _ in range(4)] + ) self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.decoder = DinoV3ViTDecoder(config) @@ -348,8 +360,22 @@ class UTransformer(nn.Module): conditioning_input=conditioning_input, attention_mask=layer_head_mask, position_embeddings=position_embeddings, + do_condition=False, + ) + shift, scale = self.residual_merger[i](reversed_residual[i]).chunk( + 2, dim=-1 + ) + x = x * (1 + scale) + shift + + for i, layer_module in enumerate(self.rest_decoder): + layer_head_mask = head_mask[i] if head_mask is not None else None + x = layer_module( + x, + conditioning_input=conditioning_input, + attention_mask=layer_head_mask, + position_embeddings=position_embeddings, + do_condition=False, ) - x = x + reversed_residual[i] x = self.decoder_norm(x) diff --git a/src/rf.py b/src/rf.py index 537762b..ecb9a22 100644 --- a/src/rf.py +++ b/src/rf.py @@ -25,6 +25,7 @@ class RF: def __init__(self, model, fm="otcfm", loss="mse"): self.model = model self.loss = loss + self.iter = 0 sigma = 0.0 if fm == "otcfm": @@ -97,9 +98,14 @@ class RF: total_d_loss.backward(retain_graph=True) self.optimizer_D.step() - def forward(self, gt, cloud): - t, xt, ut = self.fm.sample_location_and_conditional_flow(cloud, gt) # type: ignore - vt, _ = self.model(xt, t) + def forward(self, gt, cloud, condition=False): + t, xt, ut = self.fm.sample_location_and_conditional_flow( # type: ignore + cloud if not condition else torch.randn_like(cloud), gt + ) + if condition: + vt = self.model(xt, t, cloud) + else: + vt, _ = self.model(xt, t) if self.loss == "mse": loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape)))) @@ -116,10 +122,12 @@ class RF: } loss = mse + lpips * 2.0 elif self.loss == "gan_lpips_mse": - self.gan_loss( - denormalize(gt), - denormalize(xt + (1 - t[:, None, None, None]) * vt), - ) + self.iter += 1 + # if self.iter % 4 == 0: + # self.gan_loss( + # denormalize(gt), + # denormalize(xt + (1 - t[:, None, None, None]) * vt), + # ) mse = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape)))) lpips = self.lpips( denormalize(gt) * 2 - 1, @@ -129,12 +137,9 @@ class RF: denormalize(gt) * 2 - 1, denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1, ) - gan = ( - -self.discriminator( - denormalize(xt + (1 - t[:, None, None, None]) * vt), - ).mean(-1) - * 0.01 - ) + # gan = -self.discriminator( + # denormalize(xt + (1 - t[:, None, None, None]) * vt), + # ).mean(-1) ssim = 1 - ms_ssim( denormalize(gt), denormalize(xt + (1 - t[:, None, None, None]) * vt), @@ -145,10 +150,10 @@ 
class RF: "train/mse": mse.mean().item(), "train/lpips": lpips.mean().item(), "train/alexlpips": alexlpips.mean().item(), - "train/gan": gan.mean().item(), + # "train/gan": gan.mean().item(), "train/ssim": ssim.mean().item(), } - loss = mse + lpips * 4.0 + gan + alexlpips + ssim + loss = mse + lpips * 2.0 + alexlpips + ssim else: raise Exception @@ -158,43 +163,28 @@ class RF: return loss.mean(), ttloss, loss_list @torch.no_grad() - def sample(self, cloud, tol=1e-5, integration="dopri5") -> torch.Tensor: + def sample( + self, cloud, tol=1e-5, integration="dopri5", condition=False + ) -> torch.Tensor: t_span = torch.linspace(0, 1, 2, device=cloud.device) - traj = odeint( - lambda t, x: self.model(x, t)[0], - cloud, - t_span, - rtol=tol, - atol=tol, - method=integration, - ) + if condition: + x = torch.randn_like(cloud) + traj = odeint( + lambda t, x: self.model(x, t, cloud), + x, + t_span, + rtol=tol, + atol=tol, + method=integration, + ) + else: + traj = odeint( + lambda t, x: self.model(x, t)[0], + cloud, + t_span, + rtol=tol, + atol=tol, + method=integration, + ) return [traj[i] for i in range(traj.shape[0])] # type: ignore - - @torch.no_grad() - def sample_heun(self, z1, sample_steps=50): - b = z1.size(0) - dt = 1.0 / sample_steps - - images = [z1] - z = z1 - - for i in range(sample_steps, 0, -1): - t_current = i / sample_steps - t_next = (i - 1) / sample_steps - - t_current_tensor = torch.tensor([t_current] * b, device=z.device) - - v_current, _ = self.model(z, t_current_tensor) - z_pred = z - dt * v_current - - t_next_tensor = torch.tensor([t_next] * b, device=z.device) - v_next, _ = self.model(z_pred, t_next_tensor) - - v_avg = 0.5 * (v_current + v_next) - - z = z - dt * v_avg - - images.append(z) - - return images diff --git a/test_images/pred_0.png b/test_images/pred_0.png index ff382db..383b184 100644 Binary files a/test_images/pred_0.png and b/test_images/pred_0.png differ diff --git a/test_images/pred_1.png b/test_images/pred_1.png index 025b1fb..06be357 100644 Binary files a/test_images/pred_1.png and b/test_images/pred_1.png differ diff --git a/test_images/pred_2.png b/test_images/pred_2.png index 8170bf8..d682394 100644 Binary files a/test_images/pred_2.png and b/test_images/pred_2.png differ diff --git a/test_images/pred_3.png b/test_images/pred_3.png index b40c5b1..f0b42f3 100644 Binary files a/test_images/pred_3.png and b/test_images/pred_3.png differ diff --git a/test_images/pred_4.png b/test_images/pred_4.png index 0afb7c6..d592f3b 100644 Binary files a/test_images/pred_4.png and b/test_images/pred_4.png differ diff --git a/test_images/pred_5.png b/test_images/pred_5.png index 53a9586..a662a38 100644 Binary files a/test_images/pred_5.png and b/test_images/pred_5.png differ diff --git a/test_images/pred_6.png b/test_images/pred_6.png index 9b2d219..9edb74b 100644 Binary files a/test_images/pred_6.png and b/test_images/pred_6.png differ diff --git a/test_images/pred_7.png b/test_images/pred_7.png index c23cca9..7b36265 100644 Binary files a/test_images/pred_7.png and b/test_images/pred_7.png differ diff --git a/test_images/pred_8.png b/test_images/pred_8.png index 1a625b1..4fad79c 100644 Binary files a/test_images/pred_8.png and b/test_images/pred_8.png differ diff --git a/test_images/pred_9.png b/test_images/pred_9.png index dfc59ae..ea77884 100644 Binary files a/test_images/pred_9.png and b/test_images/pred_9.png differ diff --git a/test_images/transform_0.gif b/test_images/transform_0.gif new file mode 100644 index 0000000..91b87c1 Binary files /dev/null and 
b/test_images/transform_0.gif differ diff --git a/test_images/transform_1.gif b/test_images/transform_1.gif new file mode 100644 index 0000000..ac059a1 Binary files /dev/null and b/test_images/transform_1.gif differ diff --git a/test_images/transform_2.gif b/test_images/transform_2.gif new file mode 100644 index 0000000..cbb3c29 Binary files /dev/null and b/test_images/transform_2.gif differ diff --git a/test_images/transform_3.gif b/test_images/transform_3.gif new file mode 100644 index 0000000..516018b Binary files /dev/null and b/test_images/transform_3.gif differ diff --git a/test_images/transform_4.gif b/test_images/transform_4.gif new file mode 100644 index 0000000..acdd469 Binary files /dev/null and b/test_images/transform_4.gif differ diff --git a/test_images/transform_5.gif b/test_images/transform_5.gif new file mode 100644 index 0000000..98d2336 Binary files /dev/null and b/test_images/transform_5.gif differ diff --git a/test_images/transform_6.gif b/test_images/transform_6.gif new file mode 100644 index 0000000..f95e584 Binary files /dev/null and b/test_images/transform_6.gif differ diff --git a/test_images/transform_7.gif b/test_images/transform_7.gif new file mode 100644 index 0000000..66940ae Binary files /dev/null and b/test_images/transform_7.gif differ diff --git a/test_images/transform_8.gif b/test_images/transform_8.gif new file mode 100644 index 0000000..a4d14eb Binary files /dev/null and b/test_images/transform_8.gif differ diff --git a/test_images/transform_9.gif b/test_images/transform_9.gif new file mode 100644 index 0000000..d6897b5 Binary files /dev/null and b/test_images/transform_9.gif differ
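
main_dit.py drives the optimizer with a linear-warmup / cosine-decay multiplier through `optim.lr_scheduler.LambdaLR`, stepped once per effective batch. A minimal, self-contained sketch of that schedule; the step counts and the dummy parameter below are illustrative, only the shape of `get_lr` follows the patch:

```python
import math

import torch
import torch.optim as optim

# Illustrative counts; main_dit.py derives them from dataset size and epoch budget.
total_steps = 10_000
warmup_steps = int(0.05 * total_steps)  # 5% linear warmup, as in the patch


def get_lr(step: int) -> float:
    """Multiplier on the base LR: linear warmup, then cosine decay toward 0."""
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter for the sketch
optimizer = optim.AdamW(params, lr=3e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, get_lr)

for step in range(total_steps):
    optimizer.step()   # one update per effective batch (after grad accumulation)
    scheduler.step()   # effective LR = 3e-4 * get_lr(step)
```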
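
With `condition=True`, `RF.forward` now draws the flow-matching pair from Gaussian noise to the clean image and passes the cloudy image to the DiT as conditioning, rather than using it as the source distribution. A minimal sketch of the independent-CFM objective (sigma = 0) that path implements; `model` and `conditional_fm_loss` are stand-in names, and the real sampling of `(t, x_t, u_t)` happens inside the flow matcher built in `src/rf.py`:

```python
import torch


def conditional_fm_loss(model, gt, cloud):
    """ICFM with sigma=0: straight path from noise x0 to the cloud-free image x1."""
    b = gt.shape[0]
    x0 = torch.randn_like(gt)             # source: Gaussian noise
    x1 = gt                               # target: cloud-free image
    t = torch.rand(b, device=gt.device)   # uniform timesteps in [0, 1]
    t_ = t.view(b, 1, 1, 1)

    xt = (1 - t_) * x0 + t_ * x1          # point on the straight interpolation path
    ut = x1 - x0                          # conditional velocity target

    vt = model(xt, t, cloud)              # cloudy image enters as conditioning y
    return ((vt - ut) ** 2).mean()
```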
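
`RF.sample(..., condition=True)` integrates the learned velocity field from pure noise at t = 0 to t = 1 with `torchdiffeq.odeint` (`dopri5`, two time points), keeping the cloudy image as conditioning throughout. The sketch below swaps in a plain fixed-step Euler integrator to make the update rule explicit; it is not the solver the patch uses:

```python
import torch


@torch.no_grad()
def sample_conditional(model, cloud, steps=50):
    """Euler integration of dx/dt = v(x, t, cloud) from noise to image."""
    x = torch.randn_like(cloud)
    trajectory = [x]
    for i in range(steps):
        t = torch.full((cloud.shape[0],), i / steps, device=cloud.device)
        v = model(x, t, cloud)
        x = x + v / steps                 # x_{t+dt} = x_t + v * dt
        trajectory.append(x)
    return trajectory                     # last element is the predicted clean image
```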
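
src/benchmark now computes LPIPS with the standalone `lpips` package and SSIM with the functional `ssim` from `pytorch_msssim`; `lpips.LPIPS` expects inputs in [-1, 1], which is why [0, 1] images are rescaled with `* 2 - 1`. A small usage sketch with random tensors (shapes are illustrative):

```python
import lpips
import torch
from pytorch_msssim import ssim
from torchmetrics.image import PeakSignalNoiseRatio

psnr = PeakSignalNoiseRatio(data_range=1.0, reduction="none", dim=(1, 2, 3))
lp = lpips.LPIPS(net="alex")  # AlexNet backbone, expects inputs in [-1, 1]

pred = torch.rand(4, 3, 256, 256)     # images in [0, 1]
target = torch.rand(4, 3, 256, 256)

psnr_scores = psnr(pred, target)                                       # (4,)
ssim_scores = ssim(pred, target, data_range=1.0, size_average=False)   # (4,)
lpips_scores = lp(pred * 2 - 1, target * 2 - 1)                        # (4, 1, 1, 1)
```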
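
Each `TransformerBlock` conditions on the summed timestep and projected DINOv3 embeddings through adaLN: a SiLU + Linear head emits six vectors that shift, scale, and gate the attention and MLP branches via `modulate`. A condensed, self-contained sketch of the pattern; the layer norms and dimensions are simplified stand-ins, and the attention/MLP sub-modules are omitted:

```python
import torch
import torch.nn as nn


def modulate(x, shift, scale):
    # x: (N, L, D); shift/scale: (N, D), broadcast over the token dimension
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


dim = 256
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))

x = torch.randn(2, 64, dim)   # tokens
c = torch.randn(2, dim)       # t_emb + projected DINOv3 pooled features

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(c).chunk(6, dim=1)

attn_in = modulate(nn.LayerNorm(dim)(x), shift_msa, scale_msa)
x = x + gate_msa.unsqueeze(1) * attn_in   # attention output would be applied here
mlp_in = modulate(nn.LayerNorm(dim)(x), shift_mlp, scale_mlp)
x = x + gate_mlp.unsqueeze(1) * mlp_in    # MLP output would be applied here
```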
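
`DiT_Llama.from_pretrained_backbone` pulls `model.safetensors` from the Hub, loads it into the DINOv3 conditioning encoder, and freezes it. A generic sketch of that loading path, assuming any `nn.Module` whose state-dict keys match the checkpoint (`load_backbone_weights` is a hypothetical helper):

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open


def load_backbone_weights(repo_id, module, device="cpu"):
    """Download a safetensors checkpoint from the Hub and load it into a module."""
    path = hf_hub_download(repo_id, "model.safetensors")
    state = {}
    with safe_open(path, framework="pt", device=device) as f:
        for key in f.keys():
            state[key] = f.get_tensor(key)
    module.load_state_dict(state, strict=True)
    module.requires_grad_(False)  # backbone stays frozen, as in the patch
    return module
```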