another approach for final decoder
main.py
@@ -72,8 +72,8 @@ for epoch in range(100):
         ):
             batch = test_dataset[i : i + batch_size]
             images = rf.sample(batch["x0"].to(device))
-            image = denormalize(images[-1]).clamp(0, 1) * 255
-            original = denormalize(batch["x1"]).clamp(0, 1) * 255
+            image = denormalize(images[-1]).clamp(0, 1)
+            original = denormalize(batch["x1"]).clamp(0, 1)

             psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
             psnr_sum += psnr.sum().item()
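Note on the hunk above: sampled images and ground truth now stay as [0, 1] floats end to end instead of being scaled to [0, 255], so benchmark has to be configured for a unit data range (see the metrics hunk further down). A minimal sketch of the convention, assuming denormalize inverts a Normalize(mean=0.5, std=0.5) transform (a hypothetical stand-in; the actual definition lives elsewhere in the repo):

    import torch

    def denormalize(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in: inverse of Normalize(mean=0.5, std=0.5),
        # mapping model-space [-1, 1] back to [0, 1].
        return x * 0.5 + 0.5

    x = torch.randn(2, 3, 64, 64)        # fake model output, roughly [-1, 1]
    image = denormalize(x).clamp(0, 1)   # stays in [0, 1]; no * 255 anywhere
    assert 0.0 <= image.min() and image.max() <= 1.0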
@@ -92,7 +92,6 @@ for epoch in range(100):
                 "epoch": epoch + 1,
             }
         )
-
     rf.model.train()

     torch.save(
@@ -41,13 +41,13 @@ with torch.no_grad():
         batch = test_dataset[i : i + batch_size]
         images = rf.sample(batch["x0"].to(device))

-        image = denormalize(images[-1]).clamp(0, 1) * 255
-        original = denormalize(batch["x1"]).clamp(0, 1) * 255
+        image = denormalize(images[-1]).clamp(0, 1)
+        original = denormalize(batch["x1"]).clamp(0, 1)

         if saved_count < max_save:
             for j in range(min(image.shape[0], max_save - saved_count)):
-                save_image(image[j] / 255, f"{save_dir}/pred_{saved_count}.png")
-                save_image(original[j] / 255, f"{save_dir}/gt_{saved_count}.png")
+                save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
+                save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
                 save_image(
                     denormalize(batch["x0"][j]).clamp(0, 1),
                     f"{save_dir}/input_{saved_count}.png",
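The `/ 255` divisions disappear for the same reason: torchvision's save_image multiplies float input by 255 and casts to uint8 internally, so tensors kept in [0, 1] can be passed straight through. A hedged example with fake data and a hypothetical output path:

    import torch
    from torchvision.utils import save_image

    image = torch.rand(3, 64, 64)  # float tensor already in [0, 1]
    # save_image scales by 255 and casts to uint8 internally, so keeping
    # the old `/ 255` on [0, 1] data would darken everything to near-black.
    save_image(image, "pred_example.png")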
@@ -4,9 +4,11 @@ from torchmetrics.image import (
     StructuralSimilarityIndexMeasure,
 )

-psnr = PeakSignalNoiseRatio(255.0, reduction="none")
-ssim = StructuralSimilarityIndexMeasure(reduction="none")
-lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", reduction="none")
+psnr = PeakSignalNoiseRatio(1.0, reduction="none")
+ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
+lpips = LearnedPerceptualImagePatchSimilarity(
+    net_type="alex", reduction="none", normalize=True
+)


 def benchmark(image1, image2):
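With this hunk the metric configuration matches the [0, 1] tensors: PeakSignalNoiseRatio and StructuralSimilarityIndexMeasure take data_range=1.0, and normalize=True tells torchmetrics' LPIPS to expect [0, 1] input (it rescales to the [-1, 1] range the backbone wants internally). Getting data_range wrong skews PSNR by a constant offset; a quick sketch of the failure mode, with random tensors standing in for real images:

    import torch
    from torchmetrics.image import PeakSignalNoiseRatio

    preds = torch.rand(4, 3, 32, 32)    # fake predictions in [0, 1]
    target = torch.rand(4, 3, 32, 32)   # fake ground truth in [0, 1]

    # data_range must match the tensor scale: 1.0 for [0, 1] floats.
    # Using 255.0 on [0, 1] data inflates every score by a constant
    # 20 * log10(255) ~= 48 dB, making results look absurdly good.
    psnr_right = PeakSignalNoiseRatio(data_range=1.0)
    psnr_wrong = PeakSignalNoiseRatio(data_range=255.0)
    print(psnr_right(preds, target), psnr_wrong(preds, target))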
@@ -137,40 +137,103 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         return hidden_states


+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+
+#         self.projection = nn.Linear(
+#             config.hidden_size,
+#             self.num_channels_out * config.patch_size * config.patch_size,
+#             bias=True,
+#         )
+
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+
+#         num_special_tokens = 1 + self.config.num_register_tokens
+#         patch_tokens = x[:, num_special_tokens:, :]
+
+#         projected_tokens = self.projection(patch_tokens)
+
+#         p = self.config.patch_size
+#         c = self.num_channels_out
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+
+#         assert patch_tokens.shape[1] == h_grid * w_grid, (
+#             "Number of patches does not match image size."
+#         )
+
+#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
+
+#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
+
+#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+
+#         return reconstructed_image
+
+# lets try conv decoder
+
+
 class DinoV3ViTDecoder(nn.Module):
     def __init__(self, config: DINOv3ViTConfig):
         super().__init__()
         self.config = config
         self.num_channels_out = config.num_channels
+        hidden_dim = config.hidden_size
+        patch_size = config.patch_size

-        self.projection = nn.Linear(
-            config.hidden_size,
-            self.num_channels_out * config.patch_size * config.patch_size,
-            bias=True,
-        )
+        self.projection = nn.Linear(hidden_dim, hidden_dim)
+
+        if patch_size == 14:
+            final_upsample = 7
+        elif patch_size == 16:
+            final_upsample = 8
+        elif patch_size == 8:
+            final_upsample = 4
+        else:
+            raise ValueError("invalid")
+
+        self.decoder = nn.Sequential(
+            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Upsample(
+                scale_factor=final_upsample, mode="bilinear", align_corners=False
+            ),
+            nn.Conv2d(128, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
+        )

     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]

-        num_special_tokens = 1 + self.config.num_register_tokens
-        patch_tokens = x[:, num_special_tokens:, :]
+        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]

         projected_tokens = self.projection(patch_tokens)

         p = self.config.patch_size
-        c = self.num_channels_out
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p

-        assert patch_tokens.shape[1] == h_grid * w_grid, (
-            "Number of patches does not match image size."
-        )
+        assert patch_tokens.shape[1] == h_grid * w_grid

-        x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
-
-        x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
-
-        reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+        x_spatial = projected_tokens.reshape(
+            batch_size, h_grid, w_grid, self.config.hidden_size
+        )
+        x_spatial = x_spatial.permute(0, 3, 1, 2)
+        reconstructed_image = self.decoder(x_spatial)

         return reconstructed_image
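For reference, the geometry of the new decoder: the x2 upsample followed by the x final_upsample bilinear upsample multiplies the token grid back to the input resolution (2 * 7 = 14, 2 * 8 = 16, 2 * 4 = 8, matching the patch size), presumably trading the old linear unpatchify (and its blocky patch-boundary artifacts) for a smoother convolutional reconstruction. A standalone shape check under assumed toy sizes (ViT-S-like hidden_size=384, patch_size=16; not necessarily the repo's actual config):

    import torch
    import torch.nn as nn

    hidden_dim, patch_size, final_upsample = 384, 16, 8  # assumed toy config
    decoder = nn.Sequential(
        nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        nn.Conv2d(256, 128, kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=final_upsample, mode="bilinear", align_corners=False),
        nn.Conv2d(128, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 32, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(32, 3, kernel_size=1),
    )

    b, h, w = 2, 224, 224
    tokens = torch.randn(b, (h // patch_size) * (w // patch_size), hidden_dim)
    # (B, N, D) tokens -> (B, D, H/p, W/p) feature grid, as in forward()
    grid = tokens.reshape(b, h // patch_size, w // patch_size, hidden_dim)
    grid = grid.permute(0, 3, 1, 2)
    out = decoder(grid)
    assert out.shape == (b, 3, h, w)  # 14x14 grid -> x2 -> x8 -> 224x224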
[10 binary image files changed (477-551 KiB each); before/after previews omitted]