things
This commit is contained in:
@@ -7,6 +7,7 @@ from torchmetrics.image import (
|
||||
|
||||
psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
|
||||
lp = lpips.LPIPS(net="alex")
|
||||
flawed_lp = lpips.LPIPS(net="alex")
|
||||
|
||||
|
||||
def benchmark(image1, image2):
|
||||
@@ -19,4 +20,5 @@ def benchmark(image1, image2):
|
||||
size_average=False,
|
||||
),
|
||||
lp(image1 * 2 - 1, image2 * 2 - 1),
|
||||
flawed_lp(image1 * 255, image2 * 255),
|
||||
)
|
||||
|
||||
@@ -2,9 +2,11 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, DatasetDict, Image
|
||||
from src.dataset.preprocess import make_transform
|
||||
|
||||
from src.dataset.preprocess import make_nir_transform, make_transform
|
||||
|
||||
transform = make_transform(256)
|
||||
nir_transform = make_nir_transform(256)
|
||||
|
||||
|
||||
def get_dataset() -> tuple[Dataset, Dataset]:
|
||||
@@ -13,30 +15,43 @@ def get_dataset() -> tuple[Dataset, Dataset]:
|
||||
return dataset["train"], dataset["test"]
|
||||
|
||||
data_dir = Path("/data2/C-CUHK/CUHK-CR2")
|
||||
nir_dir = Path("/data2/C-CUHK/nir/CUHK-CR2")
|
||||
|
||||
train_cloud = sorted((data_dir / "train/cloud").glob("*.png"))
|
||||
train_cloud_nir = sorted((nir_dir / "train/cloud").glob("*.png"))
|
||||
train_no_cloud = sorted((data_dir / "train/label").glob("*.png"))
|
||||
train_no_cloud_nir = sorted((nir_dir / "train/label").glob("*.png"))
|
||||
test_cloud = sorted((data_dir / "test/cloud").glob("*.png"))
|
||||
test_cloud_nir = sorted((nir_dir / "test/cloud").glob("*.png"))
|
||||
test_no_cloud = sorted((data_dir / "test/label").glob("*.png"))
|
||||
test_no_cloud_nir = sorted((nir_dir / "test/label").glob("*.png"))
|
||||
|
||||
dataset = DatasetDict(
|
||||
{
|
||||
"train": Dataset.from_dict(
|
||||
{
|
||||
"cloud": [str(p) for p in train_cloud],
|
||||
"cloud_nir": [str(p) for p in train_cloud_nir],
|
||||
"label": [str(p) for p in train_no_cloud],
|
||||
"label_nir": [str(p) for p in train_no_cloud_nir],
|
||||
}
|
||||
)
|
||||
.cast_column("cloud", Image())
|
||||
.cast_column("label", Image()),
|
||||
.cast_column("label", Image())
|
||||
.cast_column("cloud_nir", Image())
|
||||
.cast_column("label_nir", Image()),
|
||||
"test": Dataset.from_dict(
|
||||
{
|
||||
"cloud": [str(p) for p in test_cloud],
|
||||
"cloud_nir": [str(p) for p in test_cloud_nir],
|
||||
"label": [str(p) for p in test_no_cloud],
|
||||
"label_nir": [str(p) for p in test_no_cloud_nir],
|
||||
}
|
||||
)
|
||||
.cast_column("cloud", Image())
|
||||
.cast_column("label", Image()),
|
||||
.cast_column("label", Image())
|
||||
.cast_column("cloud_nir", Image())
|
||||
.cast_column("label_nir", Image()),
|
||||
}
|
||||
)
|
||||
dataset = dataset.map(
|
||||
@@ -45,7 +60,7 @@ def get_dataset() -> tuple[Dataset, Dataset]:
|
||||
batch_size=32,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
)
|
||||
dataset.set_format(type="torch", columns=["cloud", "gt"])
|
||||
dataset.set_format(type="torch", columns=["cloud", "gt", "cloud_nir", "gt_nir"])
|
||||
dataset.save_to_disk("datasets/CUHK-CR2")
|
||||
|
||||
return dataset["train"], dataset["test"]
|
||||
@@ -54,9 +69,25 @@ def get_dataset() -> tuple[Dataset, Dataset]:
|
||||
def preprocess_function(examples):
|
||||
x0_list = []
|
||||
x1_list = []
|
||||
for x0_img, x1_img in zip(examples["cloud"], examples["label"]):
|
||||
x0_nir_list = []
|
||||
x1_nir_list = []
|
||||
for x0_img, x1_img, x0_nir, x1_nir in zip(
|
||||
examples["cloud"],
|
||||
examples["label"],
|
||||
examples["cloud_nir"],
|
||||
examples["label_nir"],
|
||||
):
|
||||
x0_transformed = transform(x0_img)
|
||||
x1_transformed = transform(x1_img)
|
||||
x0_nir = nir_transform(x0_nir)
|
||||
x1_nir = nir_transform(x1_nir)
|
||||
x0_list.append(x0_transformed)
|
||||
x1_list.append(x1_transformed)
|
||||
return {"cloud": x0_list, "gt": x1_list}
|
||||
x0_nir_list.append(x0_nir)
|
||||
x1_nir_list.append(x1_nir)
|
||||
return {
|
||||
"cloud": x0_list,
|
||||
"gt": x1_list,
|
||||
"cloud_nir": x0_nir_list,
|
||||
"gt_nir": x1_nir_list,
|
||||
}
|
||||
|
||||
@@ -14,13 +14,20 @@ def make_transform(resize_size: int = 256):
|
||||
return v2.Compose([to_tensor, resize, to_float, normalize])
|
||||
|
||||
|
||||
def make_nir_transform(resize_size: int = 256):
|
||||
to_tensor = v2.ToImage()
|
||||
resize = v2.Resize((resize_size, resize_size), antialias=True)
|
||||
to_float = v2.ToDtype(torch.float32, scale=True)
|
||||
return v2.Compose([to_tensor, v2.Grayscale(), resize, to_float])
|
||||
|
||||
|
||||
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
|
||||
mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
|
||||
std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
|
||||
return tensor * std + mean
|
||||
return tensor[:, :3] * std + mean
|
||||
|
||||
|
||||
def normalize(tensor: torch.Tensor) -> torch.Tensor:
|
||||
mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
|
||||
std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
|
||||
return (tensor - mean) / std
|
||||
return (tensor[:, :3] - mean) / std
|
||||
|
||||
@@ -165,7 +165,10 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
)
|
||||
|
||||
angles = (
|
||||
2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] # type: ignore
|
||||
2
|
||||
* math.pi
|
||||
* patch_coords[:, :, None]
|
||||
* self.inv_freq[None, None, :].to(patch_coords.device) # type: ignore
|
||||
)
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
|
||||
0
src/model/hdit.py
Normal file
0
src/model/hdit.py
Normal file
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -260,22 +261,113 @@ class DinoV3ViTDecoder(nn.Module):
|
||||
# return x
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
def __init__(self, in_channels, hidden_size_input, max_freqs):
|
||||
super().__init__()
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
self.embedder = nn.Sequential(
|
||||
nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True),
|
||||
)
|
||||
|
||||
@lru_cache
|
||||
def fetch_pos(self, patch_size, device, dtype):
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
freqs = torch.linspace(
|
||||
0, self.max_freqs, self.max_freqs, dtype=dtype, device=device
|
||||
)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
|
||||
return dct
|
||||
|
||||
def forward(self, inputs):
|
||||
target_dtype = self.embedder[0].weight.dtype
|
||||
inputs = inputs.to(dtype=target_dtype)
|
||||
B, P2, C = inputs.shape
|
||||
patch_size = int(P2**0.5)
|
||||
device = inputs.device
|
||||
dtype = inputs.dtype
|
||||
dct = self.fetch_pos(patch_size, device, dtype)
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
inputs = torch.cat([inputs, dct], dim=-1)
|
||||
inputs = self.embedder(inputs)
|
||||
return inputs
|
||||
|
||||
|
||||
class NerfBlock(nn.Module):
|
||||
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio: int = 4):
|
||||
super().__init__()
|
||||
self.param_generator1 = nn.Sequential(
|
||||
nn.Linear(hidden_size_s, 2 * hidden_size_x**2 * mlp_ratio, bias=True),
|
||||
)
|
||||
self.norm = nn.RMSNorm(hidden_size_x, eps=1e-6)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
def forward(self, x, s):
|
||||
batch_size, num_x, hidden_size_x = x.shape
|
||||
mlp_params1 = self.param_generator1(s)
|
||||
fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1)
|
||||
fc1_param1 = fc1_param1.view(
|
||||
batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio
|
||||
)
|
||||
fc2_param1 = fc2_param1.view(
|
||||
batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x
|
||||
)
|
||||
|
||||
# normalize fc1
|
||||
normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2)
|
||||
# normalize fc2
|
||||
normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2)
|
||||
# mlp 1
|
||||
res_x = x
|
||||
x = self.norm(x)
|
||||
x = torch.bmm(x, normalized_fc1_param1)
|
||||
x = torch.nn.functional.silu(x)
|
||||
x = torch.bmm(x, normalized_fc2_param1)
|
||||
x = x + res_x
|
||||
return x
|
||||
|
||||
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm = nn.RMSNorm(hidden_size, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class UTransformer(nn.Module):
|
||||
def __init__(
|
||||
self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
|
||||
self,
|
||||
config: DINOv3ViTConfig,
|
||||
num_classes: int,
|
||||
nerf_patch=16,
|
||||
nerf_hidden=64,
|
||||
scale_factor: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scale_factor = scale_factor
|
||||
self.nerf_patch_size = nerf_patch
|
||||
|
||||
assert config.num_hidden_layers % scale_factor == 0
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||
self.t_embedder = TimestepEmbedder(config.hidden_size)
|
||||
# self.y_embedder = LabelEmbedder(
|
||||
# num_classes, config.hidden_size, config.drop_path_rate
|
||||
# ) # disable cond for now
|
||||
|
||||
self.encoder_layers = nn.ModuleList(
|
||||
[
|
||||
@@ -302,8 +394,13 @@ class UTransformer(nn.Module):
|
||||
self.rest_decoder = nn.ModuleList(
|
||||
[DinoConditionedLayer(config, False) for _ in range(4)]
|
||||
)
|
||||
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.decoder = DinoV3ViTDecoder(config)
|
||||
|
||||
# nerf!
|
||||
self.nerf_encoder = NerfEmbedder(3, nerf_hidden, 8) # (rgb, hidden, freq)
|
||||
self.nerf_decoder = nn.ModuleList(
|
||||
[NerfBlock(self.config.hidden_size, nerf_hidden) for _ in range(12)]
|
||||
)
|
||||
self.final_layer = NerfFinalLayer(nerf_hidden, 3)
|
||||
|
||||
# freeze pretrained
|
||||
self.embeddings.requires_grad_(False)
|
||||
@@ -321,6 +418,13 @@ class UTransformer(nn.Module):
|
||||
if time.dim() == 0:
|
||||
time = time.repeat(pixel_values.shape[0])
|
||||
|
||||
# resolution config
|
||||
B = pixel_values.shape[0]
|
||||
dino_h = pixel_values.shape[-2] // self.config.patch_size
|
||||
dino_w = pixel_values.shape[-1] // self.config.patch_size
|
||||
nerf_h = pixel_values.shape[-2] // self.nerf_patch_size
|
||||
nerf_w = pixel_values.shape[-1] // self.nerf_patch_size
|
||||
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
@@ -367,11 +471,52 @@ class UTransformer(nn.Module):
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
do_condition=False,
|
||||
)
|
||||
) # (batch, image // patch^2, 1024)
|
||||
|
||||
x = self.decoder_norm(x)
|
||||
x = x[:, 1 + self.config.num_register_tokens :, :]
|
||||
|
||||
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
|
||||
nerf_cond = nn.functional.silu(t + x) # (batch, image // patch^2, 1024)
|
||||
nerf_cond = nerf_cond.reshape(
|
||||
B, dino_h, dino_w, self.config.hidden_size
|
||||
).permute(0, 3, 1, 2) # (batch, 1024, image // patch, image // patch)
|
||||
# nerf_cond = nn.functional.interpolate(
|
||||
# nerf_cond, size=(nerf_h, nerf_w), mode="bilinear", align_corners=False
|
||||
# )
|
||||
nerf_cond = (
|
||||
nerf_cond.permute(0, 2, 3, 1)
|
||||
.reshape(-1, nerf_h * nerf_w, self.config.hidden_size)
|
||||
.view(-1, self.config.hidden_size)
|
||||
)
|
||||
|
||||
# nerf
|
||||
x_nerf = nn.functional.unfold(
|
||||
pixel_values, self.nerf_patch_size, stride=self.nerf_patch_size
|
||||
).transpose(1, 2)
|
||||
x_nerf = x_nerf.reshape(
|
||||
B * x_nerf.shape[1], -1, self.nerf_patch_size**2
|
||||
).transpose(1, 2)
|
||||
x_nerf = self.nerf_encoder(x_nerf)
|
||||
|
||||
for module in self.nerf_decoder:
|
||||
x_nerf = module(x_nerf, nerf_cond)
|
||||
|
||||
x_nerf = self.final_layer(x_nerf)
|
||||
|
||||
num_patches = nerf_h * nerf_w
|
||||
x_nerf = x_nerf.reshape(
|
||||
B * num_patches, -1
|
||||
) # (B*num_patches, 48): flatten pixels+RGB per patch
|
||||
x_nerf = (
|
||||
x_nerf.view(B, num_patches, -1).transpose(1, 2).contiguous()
|
||||
) # (B, 48, num_patches)
|
||||
|
||||
res = nn.functional.fold(
|
||||
x_nerf,
|
||||
(pixel_values.shape[-2], pixel_values.shape[-1]),
|
||||
kernel_size=self.nerf_patch_size,
|
||||
stride=self.nerf_patch_size,
|
||||
)
|
||||
return res
|
||||
|
||||
def get_residual(
|
||||
self,
|
||||
@@ -410,7 +555,7 @@ class UTransformer(nn.Module):
|
||||
@staticmethod
|
||||
def from_pretrained_backbone(name: str):
|
||||
config = DINOv3ViTConfig.from_pretrained(name)
|
||||
instance = UTransformer(config, 0).to("cuda:1")
|
||||
instance = UTransformer(config, 0)
|
||||
|
||||
weight_dict = {}
|
||||
with safe_open(
|
||||
|
||||
41
src/rf.py
41
src/rf.py
@@ -22,7 +22,7 @@ use_lecam = True
|
||||
|
||||
|
||||
class RF:
|
||||
def __init__(self, model, fm="otcfm", loss="mse"):
|
||||
def __init__(self, model, fm="otcfm", loss="mse", lp=None):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.iter = 0
|
||||
@@ -40,19 +40,21 @@ class RF:
|
||||
raise NotImplementedError(
|
||||
f"Unknown model {fm}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
|
||||
)
|
||||
if not lp:
|
||||
self.lpips = lpips.LPIPS(net="vgg").to("cuda:1")
|
||||
self.lpips2 = lpips.LPIPS(net="alex").to("cuda:1")
|
||||
|
||||
self.lpips = lpips.LPIPS(net="vgg").to("cuda:1")
|
||||
self.lpips2 = lpips.LPIPS(net="alex").to("cuda:1")
|
||||
|
||||
discriminator = PatchDiscriminator().to("cuda:1")
|
||||
discriminator.requires_grad_(True)
|
||||
self.discriminator = discriminator
|
||||
self.optimizer_D = optim.AdamW(
|
||||
discriminator.parameters(),
|
||||
lr=2e-4,
|
||||
weight_decay=1e-3,
|
||||
betas=(0.9, 0.95),
|
||||
)
|
||||
discriminator = PatchDiscriminator().to("cuda:1")
|
||||
discriminator.requires_grad_(True)
|
||||
self.discriminator = discriminator
|
||||
self.optimizer_D = optim.AdamW(
|
||||
discriminator.parameters(),
|
||||
lr=2e-4,
|
||||
weight_decay=1e-3,
|
||||
betas=(0.9, 0.95),
|
||||
)
|
||||
else:
|
||||
self.lpips = lp
|
||||
|
||||
def gan_loss(self, real, fake):
|
||||
global lecam_beta, lecam_anchor_real_logits, lecam_anchor_fake_logits, use_lecam
|
||||
@@ -105,7 +107,7 @@ class RF:
|
||||
if condition:
|
||||
vt = self.model(xt, t, cloud)
|
||||
else:
|
||||
vt, _ = self.model(xt, t)
|
||||
vt = self.model(xt, t)
|
||||
|
||||
if self.loss == "mse":
|
||||
loss = ((vt - ut) ** 2).mean(dim=list(range(1, len(gt.shape))))
|
||||
@@ -116,11 +118,18 @@ class RF:
|
||||
denormalize(gt) * 2 - 1,
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt) * 2 - 1,
|
||||
)
|
||||
ssim = 1 - ms_ssim(
|
||||
denormalize(gt),
|
||||
denormalize(xt + (1 - t[:, None, None, None]) * vt),
|
||||
data_range=1.0,
|
||||
size_average=False,
|
||||
)
|
||||
loss_list = {
|
||||
"train/mse": mse.mean().item(),
|
||||
"train/lpips": lpips.mean().item(),
|
||||
"train/ssim": ssim.mean().item(),
|
||||
}
|
||||
loss = mse + lpips * 2.0
|
||||
loss = mse + lpips * 2.0 + ssim
|
||||
elif self.loss == "gan_lpips_mse":
|
||||
self.iter += 1
|
||||
# if self.iter % 4 == 0:
|
||||
@@ -179,7 +188,7 @@ class RF:
|
||||
)
|
||||
else:
|
||||
traj = odeint(
|
||||
lambda t, x: self.model(x, t)[0],
|
||||
lambda t, x: self.model(x, t),
|
||||
cloud,
|
||||
t_span,
|
||||
rtol=tol,
|
||||
|
||||
Reference in New Issue
Block a user