diff --git a/main.py b/main.py
index c792122..de22f0f 100644
--- a/main.py
+++ b/main.py
@@ -2,16 +2,17 @@ import os
 
 import torch
 import torch.optim as optim
+from torchvision.utils import make_grid
 from tqdm import tqdm
 
 import wandb
 from src.benchmark import benchmark
-from src.dataset.cuhk_cr1 import get_dataset
+from src.dataset.cuhk_cr2 import get_dataset
 from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-device = "cuda:2"
+device = "cuda:1"
 
 model = UTransformer.from_pretrained_backbone(
     "facebook/dinov3-vitl16-pretrain-sat493m"
@@ -21,7 +22,7 @@ optimizer = optim.AdamW(model.parameters(), lr=1e-4)
 
 train_dataset, test_dataset = get_dataset()
 
-wandb.init(project="cloud-removal-kmu", id="icy-field-12", resume="allow")
+wandb.init(project="cloud-removal-kmu", id="dashing-moon-31", resume="allow")
 if not (wandb.run and wandb.run.name):
     raise Exception("nope")
 
@@ -36,7 +37,7 @@ if os.path.exists(checkpoint_path):
     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
     start_epoch = checkpoint["epoch"] + 1
 
-batch_size = 4
+batch_size = 8
 accumulation_steps = 8
 total_epoch = 1000
 for epoch in range(start_epoch, total_epoch):
@@ -89,10 +90,21 @@
         desc=f"Benchmark {epoch + 1}/{total_epoch}",
     ):
         batch = test_dataset[i : i + batch_size]
-        images = rf.sample(batch["cloud"].to(device))
+        images = rf.sample_heun(batch["cloud"].to(device))
         image = denormalize(images[-1]).clamp(0, 1)
         original = denormalize(batch["gt"]).clamp(0, 1)
 
+        if i == 0:
+            # log the first and last sampling step for the first test batch
+            for step, demo in enumerate([images[0], images[-1]]):
+                grid = wandb.Image(
+                    make_grid(
+                        denormalize(demo).clamp(0, 1).float()[:4], nrow=2
+                    ),
+                    caption=f"step {step}",
+                )
+                wandb.log({"viz/decoded": grid})
+
         psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
         psnr_sum += psnr.sum().item()
         ssim_sum += ssim.sum().item()
diff --git a/quick_eval.py b/quick_eval.py
index 220f433..a6f35b2 100644
--- a/quick_eval.py
+++ b/quick_eval.py
@@ -10,8 +10,8 @@ from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-checkpoint_path = "artifact/icy-field-12/checkpoint_epoch_260.pt"
-device = "cuda:2"
+checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt"
+device = "cuda:1"
 
 save_dir = "test_images"
 os.makedirs(save_dir, exist_ok=True)
@@ -28,7 +28,7 @@ rf.model.eval()
 
 _, test_dataset = get_dataset()
 
-batch_size = 1
+batch_size = 8
 psnr_sum = 0
 ssim_sum = 0
 lpips_sum = 0
@@ -39,7 +39,7 @@ max_save = 10
 
 with torch.no_grad():
     for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
         batch = test_dataset[i : i + batch_size]
-        images = rf.sample(batch["cloud"].to(device))
+        images = rf.sample_heun(batch["cloud"].to(device), 1)
         image = denormalize(images[-1]).clamp(0, 1)
         original = denormalize(batch["gt"]).clamp(0, 1)
@@ -49,12 +49,13 @@
             save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
             save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
             save_image(
-                denormalize(batch["x0"][j]).clamp(0, 1),
+                denormalize(batch["cloud"][j]).clamp(0, 1),
                 f"{save_dir}/input_{saved_count}.png",
             )
             saved_count += 1
 
         psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
+        print(psnr, ssim, lpips)
         psnr_sum += psnr.sum().item()
         ssim_sum += ssim.sum().item()
         lpips_sum += lpips.sum().item()
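Note on the metric usage above: quick_eval.py accumulates psnr.sum().item() per batch because the metrics in src/benchmark/__init__.py (next file) are built with reduction="none", so they return one score per image rather than a batch mean. A quick sanity check of that behavior (the shapes are illustrative, not taken from the repo):

    import torch
    from torchmetrics.image import PeakSignalNoiseRatio

    # per-image PSNR: reduce over C/H/W only, keep the batch dimension
    psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
    pred = torch.rand(8, 3, 512, 512)
    target = torch.rand(8, 3, 512, 512)
    print(psnr(pred, target).shape)  # expected: torch.Size([8])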
diff --git a/src/benchmark/__init__.py b/src/benchmark/__init__.py
index fe078a9..0072fdc 100644
--- a/src/benchmark/__init__.py
+++ b/src/benchmark/__init__.py
@@ -4,6 +4,7 @@ from torchmetrics.image import (
     StructuralSimilarityIndexMeasure,
 )
 
+psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
 ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
 lpips = LearnedPerceptualImagePatchSimilarity(
     net_type="alex", reduction="none", normalize=True
@@ -11,5 +12,4 @@ lpips = LearnedPerceptualImagePatchSimilarity(
 )
 
 def benchmark(image1, image2):
-    psnr = PeakSignalNoiseRatio(1.0, reduction="none")
     return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
diff --git a/src/dataset/cuhk_cr1.py b/src/dataset/cuhk_cr1.py
index 6c7e075..da6f711 100644
--- a/src/dataset/cuhk_cr1.py
+++ b/src/dataset/cuhk_cr1.py
@@ -45,7 +45,7 @@ def get_dataset() -> tuple[Dataset, Dataset]:
         batch_size=32,
         remove_columns=dataset["train"].column_names,
     )
-    dataset.set_format(type="torch", columns=["x0", "x1"])
+    dataset.set_format(type="torch", columns=["cloud", "gt"])
     dataset.save_to_disk("datasets/CUHK-CR1")
 
     return dataset["train"], dataset["test"]
diff --git a/src/dataset/cuhk_cr2.py b/src/dataset/cuhk_cr2.py
new file mode 100644
index 0000000..4b0b35b
--- /dev/null
+++ b/src/dataset/cuhk_cr2.py
@@ -0,0 +1,62 @@
+import os
+from pathlib import Path
+
+from datasets import Dataset, DatasetDict, Image
+from src.dataset.preprocess import make_transform
+
+transform = make_transform(512)
+
+
+def get_dataset() -> tuple[Dataset, Dataset]:
+    if os.path.exists("datasets/CUHK-CR2"):
+        dataset = DatasetDict.load_from_disk("datasets/CUHK-CR2")
+        return dataset["train"], dataset["test"]
+
+    data_dir = Path("/data2/C-CUHK/CUHK-CR2")
+
+    train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
+    train_no_cloud = sorted((data_dir / "train/label").glob("*.png"))
+    test_cloud = sorted((data_dir / "test/cloud").glob("*.png"))
+    test_no_cloud = sorted((data_dir / "test/label").glob("*.png"))
+
+    dataset = DatasetDict(
+        {
+            "train": Dataset.from_dict(
+                {
+                    "cloud": [str(p) for p in train_cloud],
+                    "label": [str(p) for p in train_no_cloud],
+                }
+            )
+            .cast_column("cloud", Image())
+            .cast_column("label", Image()),
+            "test": Dataset.from_dict(
+                {
+                    "cloud": [str(p) for p in test_cloud],
+                    "label": [str(p) for p in test_no_cloud],
+                }
+            )
+            .cast_column("cloud", Image())
+            .cast_column("label", Image()),
+        }
+    )
+    dataset = dataset.map(
+        preprocess_function,
+        batched=True,
+        batch_size=32,
+        remove_columns=dataset["train"].column_names,
+    )
+    dataset.set_format(type="torch", columns=["cloud", "gt"])
+    dataset.save_to_disk("datasets/CUHK-CR2")
+
+    return dataset["train"], dataset["test"]
+
+
+def preprocess_function(examples):
+    cloud_list = []
+    gt_list = []
+    for cloud_img, gt_img in zip(examples["cloud"], examples["label"]):
+        cloud_transformed = transform(cloud_img)
+        gt_transformed = transform(gt_img)
+        cloud_list.append(cloud_transformed)
+        gt_list.append(gt_transformed)
+    return {"cloud": cloud_list, "gt": gt_list}
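The new CUHK-CR2 loader mirrors cuhk_cr1.py and caches the processed splits under datasets/CUHK-CR2. A hypothetical smoke test; the 3x512x512 shape assumes make_transform(512) in src/dataset/preprocess.py resizes to 512 and returns CHW tensors:

    from src.dataset.cuhk_cr2 import get_dataset

    train_dataset, test_dataset = get_dataset()
    batch = train_dataset[0:4]   # set_format(type="torch") stacks the slice
    print(batch["cloud"].shape)  # expected: torch.Size([4, 3, 512, 512])
    print(batch["gt"].shape)     # expected: torch.Size([4, 3, 512, 512])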
diff --git a/src/model/utransformer.py b/src/model/utransformer.py
index e3e2a57..a1255fa 100644
--- a/src/model/utransformer.py
+++ b/src/model/utransformer.py
@@ -105,10 +105,11 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         conditioning_input: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        do_condition: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert position_embeddings is not None
-        assert conditioning_input is not None
+        assert conditioning_input is not None or not do_condition
 
         residual = hidden_states
         hidden_states = self.norm1(hidden_states)
@@ -120,13 +121,14 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         hidden_states = self.layer_scale1(hidden_states)
         hidden_states = self.drop_path(hidden_states) + residual
 
-        residual = hidden_states
-        hidden_states = self.norm_cond(hidden_states)
-        hidden_states, _ = self.cond(
-            hidden_states, conditioning_input, conditioning_input
-        )
-        hidden_states = self.layer_scale_cond(hidden_states)
-        hidden_states = self.drop_path(hidden_states) + residual
+        if do_condition:
+            residual = hidden_states
+            hidden_states = self.norm_cond(hidden_states)
+            hidden_states, _ = self.cond(
+                hidden_states, conditioning_input, conditioning_input
+            )
+            hidden_states = self.layer_scale_cond(hidden_states)
+            hidden_states = self.drop_path(hidden_states) + residual
 
         residual = hidden_states
         hidden_states = self.norm2(hidden_states)
@@ -191,6 +193,9 @@ class DinoV3ViTDecoder(nn.Module):
         )
         self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
+        # zero-init so the decoder head starts by predicting zero velocity
+        nn.init.zeros_(self.projection.weight)
+        nn.init.zeros_(self.projection.bias)
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
@@ -211,9 +215,14 @@ class UTransformer(nn.Module):
-    def __init__(self, config: DINOv3ViTConfig, num_classes: int):
+    def __init__(
+        self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
+    ):
         super().__init__()
         self.config = config
+        self.scale_factor = scale_factor
+
+        assert config.num_hidden_layers % scale_factor == 0
 
         self.embeddings = DINOv3ViTEmbeddings(config)
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
@@ -228,18 +237,21 @@
                 for _ in range(config.num_hidden_layers)
             ]
         )
+        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.decoder_layers = nn.ModuleList(
             [
                 DinoConditionedLayer(config, False)
-                for _ in range(config.num_hidden_layers // 2)
+                for _ in range(config.num_hidden_layers // scale_factor)
             ]
         )
+        self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.decoder = DinoV3ViTDecoder(config)
 
         # freeze pretrained
         self.embeddings.requires_grad_(False)
         self.rope_embeddings.requires_grad_(False)
+        self.encoder_norm.requires_grad_(False)
 
     def forward(
         self,
@@ -260,7 +272,8 @@
 
         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
-            residual.append(x)
+            if i % self.scale_factor == 0:
+                residual.append(x)
             layer_head_mask = head_mask[i] if head_mask is not None else None
             x = layer_module(
                 x,
@@ -269,35 +282,71 @@
                 position_embeddings=position_embeddings,
             )
 
+        x = self.encoder_norm(x)
+
+        # reversed so decoder layer i consumes the deepest skip first;
+        # `residual` itself stays intact and is returned for the feature loss
+        reversed_residual = residual[::-1]
         for i, layer_module in enumerate(self.decoder_layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            x = x + residual.pop()
             x = layer_module(
                 x,
                 conditioning_input=conditioning_input,
                 attention_mask=layer_head_mask,
                 position_embeddings=position_embeddings,
             )
+            x = x + reversed_residual[i]
 
-        return self.decoder(x, image_size=pixel_values.shape[-2:])
+        x = self.decoder_norm(x)
+
+        return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
+
+    def get_residual(
+        self,
+        pixel_values: torch.Tensor,
+        time: Optional[torch.Tensor],
+        do_condition: bool,
+    ):
+        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+        position_embeddings = self.rope_embeddings(pixel_values)
+
+        x = self.embeddings(pixel_values, bool_masked_pos=None)
+
+        if do_condition:
+            t = self.t_embedder(time).unsqueeze(1)
+            # y = self.y_embedder(cond, self.training).unsqueeze(1)
+            # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+            conditioning_input = t.to(x.dtype)
+        else:
+            conditioning_input = None
+
+        residual = []
+        for i, layer_module in enumerate(self.encoder_layers):
+            if i % self.scale_factor == 0:
+                residual.append(x)
+
+            x = layer_module(
+                x,
+                conditioning_input=conditioning_input,
+                attention_mask=None,
+                position_embeddings=position_embeddings,
+                do_condition=do_condition,
+            )
+
+        return residual
 
     @staticmethod
     def from_pretrained_backbone(name: str):
         config = DINOv3ViTConfig.from_pretrained(name)
-        instance = UTransformer(config, 0).to("cuda:2")
+        instance = UTransformer(config, 0).to("cuda:1")
 
         weight_dict = {}
         with safe_open(
-            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
         ) as f:
             for key in f.keys():
                 new_key = key.replace("layer.", "encoder_layers.").replace(
                     "norm.", "encoder_norm."
                 )
-                if key.startswith("norm."):
-                    continue
-
                 weight_dict[new_key] = f.get_tensor(key)
 
         instance.load_state_dict(weight_dict, strict=False)
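UTransformer.forward now returns the encoder skip features alongside the decoded image, and get_residual recomputes them for an arbitrary input; the lpips_mse_enhanced loss in src/rf.py below compares the two lists. A minimal sketch of that comparison in isolation; feature_cosine_loss is a hypothetical helper, and v_feats/gt_feats stand for the residual lists of the prediction and the ground truth:

    import torch
    import torch.nn.functional as F

    def feature_cosine_loss(v_feats, gt_feats):
        # each entry is (B, tokens, dim): cosine distance over channels,
        # averaged over tokens, then over the collected layers
        per_layer = [
            (1.0 - F.cosine_similarity(v, g, dim=-1)).mean(dim=-1)  # (B,)
            for v, g in zip(v_feats, gt_feats)
        ]
        return torch.stack(per_layer).mean(dim=0)  # (B,)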
diff --git a/src/rf.py b/src/rf.py
index 3b12dcb..1913245 100644
--- a/src/rf.py
+++ b/src/rf.py
@@ -13,17 +13,13 @@ def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
 
 
 class RF:
-    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
+    def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_mse_enhanced"):
        self.model = model
         self.ln = ln
         self.ushaped = ushaped
         self.loss_fn = loss_fn
-        self.lpips = (
-            lpips.LPIPS(net="vgg").to(model.device)
-            if loss_fn == "lpips_huber"
-            else None
-        )
+        self.lpips = lpips.LPIPS(net="vgg").to("cuda:1") if "lpips" in loss_fn else None
 
     def forward(self, x0, z1):
         # x0 is gt / z is noise
@@ -41,7 +37,7 @@
         texp = t.view([b, *([1] * len(x0.shape[1:]))])
         zt = (1 - texp) * x0 + texp * z1
 
-        vtheta = self.model(zt, t)
+        vtheta, residual = self.model(zt, t)
 
         if self.loss_fn == "lpips_huber":
             # https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
@@ -57,6 +53,43 @@
             weight = t.view(-1)
 
             loss = (1 - weight) * huber + lpips
+        elif self.loss_fn == "lpips_mse":
+            if not self.lpips:
+                raise Exception
+
+            lpips = self.lpips(
+                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+            ).view(-1)  # LPIPS yields (B, 1, 1, 1); flatten to match the (B,) MSE
+            loss = ((z1 - x0 - vtheta) ** 2).mean(
+                dim=list(range(1, len(x0.shape)))
+            ) + 2.0 * lpips
+        elif self.loss_fn == "lpips_mse_enhanced":
+            if not self.lpips:
+                raise Exception
+
+            lpips = self.lpips(
+                denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
+            ).view(-1)  # as above: (B, 1, 1, 1) -> (B,)
+            # feature-matching term: cosine distance between the encoder
+            # features of the one-step prediction and of the ground truth
+            dino_loss = torch.stack(
+                [
+                    (
+                        1
+                        - torch.nn.functional.cosine_similarity(
+                            v_residual, x0_residual, dim=-1
+                        )
+                    ).mean(dim=-1)
+                    for v_residual, x0_residual in zip(
+                        residual, self.model.get_residual(x0, None, False)
+                    )
+                ]
+            ).mean(dim=0) * (2 - t.view(-1))
+            loss = (
+                ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
+                + 2.0 * lpips
+                + dino_loss
+            )
         else:
             loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
@@ -66,7 +97,7 @@
         return loss.mean(), ttloss
 
     @torch.no_grad()
-    def sample(self, z1, sample_steps=50):
+    def sample(self, z1, sample_steps=5):
         b = z1.size(0)
         dt = 1.0 / sample_steps
         dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
@@ -76,8 +107,36 @@
             t = i / sample_steps
             t = torch.tensor([t] * b).to(z.device)
 
-            vc = self.model(z, t)
+            vc, _ = self.model(z, t)
 
             z = z - dt * vc
             images.append(z)
         return images
+
+    @torch.no_grad()
+    def sample_heun(self, z1, sample_steps=50):
+        b = z1.size(0)
+        dt = 1.0 / sample_steps
+
+        images = [z1]
+        z = z1
+
+        for i in range(sample_steps, 0, -1):
+            t_current = i / sample_steps
+            t_next = (i - 1) / sample_steps
+
+            t_current_tensor = torch.tensor([t_current] * b, device=z.device)
+
+            v_current, _ = self.model(z, t_current_tensor)
+            z_pred = z - dt * v_current
+
+            t_next_tensor = torch.tensor([t_next] * b, device=z.device)
+            v_next, _ = self.model(z_pred, t_next_tensor)
+
+            v_avg = 0.5 * (v_current + v_next)
+
+            z = z - dt * v_avg
+
+            images.append(z)
+
+        return images
diff --git a/test_images/pred_0.png b/test_images/pred_0.png
index c93f125..ff382db 100644
Binary files a/test_images/pred_0.png and b/test_images/pred_0.png differ
diff --git a/test_images/pred_1.png b/test_images/pred_1.png
index 24a90eb..025b1fb 100644
Binary files a/test_images/pred_1.png and b/test_images/pred_1.png differ
diff --git a/test_images/pred_2.png b/test_images/pred_2.png
index 6071615..8170bf8 100644
Binary files a/test_images/pred_2.png and b/test_images/pred_2.png differ
diff --git a/test_images/pred_3.png b/test_images/pred_3.png
index 719f6a4..b40c5b1 100644
Binary files a/test_images/pred_3.png and b/test_images/pred_3.png differ
diff --git a/test_images/pred_4.png b/test_images/pred_4.png
index 383684b..0afb7c6 100644
Binary files a/test_images/pred_4.png and b/test_images/pred_4.png differ
diff --git a/test_images/pred_5.png b/test_images/pred_5.png
index 7d3efbd..53a9586 100644
Binary files a/test_images/pred_5.png and b/test_images/pred_5.png differ
diff --git a/test_images/pred_6.png b/test_images/pred_6.png
index 4ca3a17..9b2d219 100644
Binary files a/test_images/pred_6.png and b/test_images/pred_6.png differ
diff --git a/test_images/pred_7.png b/test_images/pred_7.png
index 9954391..c23cca9 100644
Binary files a/test_images/pred_7.png and b/test_images/pred_7.png differ
diff --git a/test_images/pred_8.png b/test_images/pred_8.png
index 93bd5f2..1a625b1 100644
Binary files a/test_images/pred_8.png and b/test_images/pred_8.png differ
diff --git a/test_images/pred_9.png b/test_images/pred_9.png
index 85bbd5b..dfc59ae 100644
Binary files a/test_images/pred_9.png and b/test_images/pred_9.png differ
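For reference, quick_eval.py calls sample_heun(z, 1), which collapses to a single predictor-corrector update costing two model evaluations. A minimal sketch of that special case, where model is a placeholder with the UTransformer contract of returning a (velocity, residuals) pair:

    import torch

    @torch.no_grad()
    def heun_single_step(model, z1):
        b = z1.size(0)
        v1, _ = model(z1, torch.ones(b, device=z1.device))  # velocity at t=1
        z_pred = z1 - v1  # Euler predictor with dt = 1
        v0, _ = model(z_pred, torch.zeros(b, device=z1.device))  # velocity at t=0
        return z1 - 0.5 * (v1 + v0)  # trapezoidal correction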