diff --git a/main.py b/main.py
index 976084f..7a56e08 100644
--- a/main.py
+++ b/main.py
@@ -72,8 +72,8 @@ for epoch in range(100):
         ):
             batch = test_dataset[i : i + batch_size]
             images = rf.sample(batch["x0"].to(device))
-            image = denormalize(images[-1]).clamp(0, 1) * 255
-            original = denormalize(batch["x1"]).clamp(0, 1) * 255
+            image = denormalize(images[-1]).clamp(0, 1)
+            original = denormalize(batch["x1"]).clamp(0, 1)
             psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
 
             psnr_sum += psnr.sum().item()
@@ -92,7 +92,6 @@ for epoch in range(100):
                 "epoch": epoch + 1,
             }
         )
-        rf.model.train()
 
 
     torch.save(
diff --git a/quick_eval.py b/quick_eval.py
index cf8dd60..e0cf680 100644
--- a/quick_eval.py
+++ b/quick_eval.py
@@ -41,13 +41,13 @@ with torch.no_grad():
         batch = test_dataset[i : i + batch_size]
         images = rf.sample(batch["x0"].to(device))
 
-        image = denormalize(images[-1]).clamp(0, 1) * 255
-        original = denormalize(batch["x1"]).clamp(0, 1) * 255
+        image = denormalize(images[-1]).clamp(0, 1)
+        original = denormalize(batch["x1"]).clamp(0, 1)
 
         if saved_count < max_save:
             for j in range(min(image.shape[0], max_save - saved_count)):
-                save_image(image[j] / 255, f"{save_dir}/pred_{saved_count}.png")
-                save_image(original[j] / 255, f"{save_dir}/gt_{saved_count}.png")
+                save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
+                save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
                 save_image(
                     denormalize(batch["x0"][j]).clamp(0, 1),
                     f"{save_dir}/input_{saved_count}.png",
diff --git a/src/benchmark/__init__.py b/src/benchmark/__init__.py
index 35ce9ee..d56821c 100644
--- a/src/benchmark/__init__.py
+++ b/src/benchmark/__init__.py
@@ -4,9 +4,11 @@ from torchmetrics.image import (
     StructuralSimilarityIndexMeasure,
 )
 
-psnr = PeakSignalNoiseRatio(255.0, reduction="none")
-ssim = StructuralSimilarityIndexMeasure(reduction="none")
-lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", reduction="none")
+psnr = PeakSignalNoiseRatio(1.0, reduction="none")
+ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
+lpips = LearnedPerceptualImagePatchSimilarity(
+    net_type="alex", reduction="none", normalize=True
+)
 
 
 def benchmark(image1, image2):
diff --git a/src/model/utransformer.py b/src/model/utransformer.py
index 5e610ca..55958d4 100644
--- a/src/model/utransformer.py
+++ b/src/model/utransformer.py
@@ -137,40 +137,103 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         return hidden_states
 
 
+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+
+#         self.projection = nn.Linear(
+#             config.hidden_size,
+#             self.num_channels_out * config.patch_size * config.patch_size,
+#             bias=True,
+#         )
+
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+
+#         num_special_tokens = 1 + self.config.num_register_tokens
+#         patch_tokens = x[:, num_special_tokens:, :]
+
+#         projected_tokens = self.projection(patch_tokens)
+
+#         p = self.config.patch_size
+#         c = self.num_channels_out
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+
+#         assert patch_tokens.shape[1] == h_grid * w_grid, (
+#             "Number of patches does not match image size."
+#         )
+
+#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
+
+#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
+
+#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+
+#         return reconstructed_image
+
+# let's try a conv decoder
+
+
 class DinoV3ViTDecoder(nn.Module):
     def __init__(self, config: DINOv3ViTConfig):
         super().__init__()
         self.config = config
         self.num_channels_out = config.num_channels
+        hidden_dim = config.hidden_size
+        patch_size = config.patch_size
 
-        self.projection = nn.Linear(
-            config.hidden_size,
-            self.num_channels_out * config.patch_size * config.patch_size,
-            bias=True,
+        self.projection = nn.Linear(hidden_dim, hidden_dim)
+
+        if patch_size == 14:
+            final_upsample = 7
+        elif patch_size == 16:
+            final_upsample = 8
+        elif patch_size == 8:
+            final_upsample = 4
+        else:
+            raise ValueError("invalid")
+
+        self.decoder = nn.Sequential(
+            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Upsample(
+                scale_factor=final_upsample, mode="bilinear", align_corners=False
+            ),
+            nn.Conv2d(128, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
         )
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
 
-        num_special_tokens = 1 + self.config.num_register_tokens
-        patch_tokens = x[:, num_special_tokens:, :]
+        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
 
         projected_tokens = self.projection(patch_tokens)
 
         p = self.config.patch_size
-        c = self.num_channels_out
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
 
-        assert patch_tokens.shape[1] == h_grid * w_grid, (
-            "Number of patches does not match image size."
+        assert patch_tokens.shape[1] == h_grid * w_grid
+
+        x_spatial = projected_tokens.reshape(
+            batch_size, h_grid, w_grid, self.config.hidden_size
         )
 
-        x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
-
-        x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
-
-        reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+        x_spatial = x_spatial.permute(0, 3, 1, 2)
+        reconstructed_image = self.decoder(x_spatial)
 
         return reconstructed_image
 
diff --git a/test_images/pred_0.png b/test_images/pred_0.png
index 0b47ae6..7fab0af 100644
Binary files a/test_images/pred_0.png and b/test_images/pred_0.png differ
diff --git a/test_images/pred_1.png b/test_images/pred_1.png
index ae378f5..cdcc5c7 100644
Binary files a/test_images/pred_1.png and b/test_images/pred_1.png differ
diff --git a/test_images/pred_2.png b/test_images/pred_2.png
index d1b010f..f11e44e 100644
Binary files a/test_images/pred_2.png and b/test_images/pred_2.png differ
diff --git a/test_images/pred_3.png b/test_images/pred_3.png
index 31538d2..9f04f63 100644
Binary files a/test_images/pred_3.png and b/test_images/pred_3.png differ
diff --git a/test_images/pred_4.png b/test_images/pred_4.png
index 5e93424..8ae0b3a 100644
Binary files a/test_images/pred_4.png and b/test_images/pred_4.png differ
diff --git a/test_images/pred_5.png b/test_images/pred_5.png
index 4b10854..e91574d 100644
Binary files a/test_images/pred_5.png and b/test_images/pred_5.png differ
diff --git a/test_images/pred_6.png b/test_images/pred_6.png
index 01c42fc..6dadf38 100644
Binary files a/test_images/pred_6.png and b/test_images/pred_6.png differ
diff --git a/test_images/pred_7.png b/test_images/pred_7.png
index f8dd5b0..4e0f613 100644
Binary files a/test_images/pred_7.png and b/test_images/pred_7.png differ
diff --git a/test_images/pred_8.png b/test_images/pred_8.png
index a7dd1a4..0f74ca7 100644
Binary files a/test_images/pred_8.png and b/test_images/pred_8.png differ
diff --git a/test_images/pred_9.png b/test_images/pred_9.png
index 704d15f..7e3aaaa 100644
Binary files a/test_images/pred_9.png and b/test_images/pred_9.png differ