Fix incorrect PSNR averaging
21
main.py
@@ -2,16 +2,17 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
from torchvision.utils import make_grid
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from src.benchmark import benchmark
|
from src.benchmark import benchmark
|
||||||
from src.dataset.cuhk_cr1 import get_dataset
|
from src.dataset.cuhk_cr2 import get_dataset
|
||||||
from src.dataset.preprocess import denormalize
|
from src.dataset.preprocess import denormalize
|
||||||
from src.model.utransformer import UTransformer
|
from src.model.utransformer import UTransformer
|
||||||
from src.rf import RF
|
from src.rf import RF
|
||||||
|
|
||||||
device = "cuda:2"
|
device = "cuda:1"
|
||||||
|
|
||||||
model = UTransformer.from_pretrained_backbone(
|
model = UTransformer.from_pretrained_backbone(
|
||||||
"facebook/dinov3-vitl16-pretrain-sat493m"
|
"facebook/dinov3-vitl16-pretrain-sat493m"
|
||||||
@@ -21,7 +22,7 @@ optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
|||||||
|
|
||||||
train_dataset, test_dataset = get_dataset()
|
train_dataset, test_dataset = get_dataset()
|
||||||
|
|
||||||
wandb.init(project="cloud-removal-kmu", id="icy-field-12", resume="allow")
|
wandb.init(project="cloud-removal-kmu", id="dashing-moon-31", resume="allow")
|
||||||
|
|
||||||
if not (wandb.run and wandb.run.name):
|
if not (wandb.run and wandb.run.name):
|
||||||
raise Exception("nope")
|
raise Exception("nope")
|
||||||
@@ -36,7 +37,7 @@ if os.path.exists(checkpoint_path):
|
|||||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
start_epoch = checkpoint["epoch"] + 1
|
start_epoch = checkpoint["epoch"] + 1
|
||||||
|
|
||||||
batch_size = 4
|
batch_size = 8
|
||||||
accumulation_steps = 8
|
accumulation_steps = 8
|
||||||
total_epoch = 1000
|
total_epoch = 1000
|
||||||
for epoch in range(start_epoch, total_epoch):
|
for epoch in range(start_epoch, total_epoch):
|
||||||
@@ -89,10 +90,20 @@ for epoch in range(start_epoch, total_epoch):
|
|||||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||||
):
|
):
|
||||||
batch = test_dataset[i : i + batch_size]
|
batch = test_dataset[i : i + batch_size]
|
||||||
images = rf.sample(batch["cloud"].to(device))
|
images = rf.sample_heun(batch["cloud"].to(device))
|
||||||
image = denormalize(images[-1]).clamp(0, 1)
|
image = denormalize(images[-1]).clamp(0, 1)
|
||||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
for step, demo in enumerate([images[0], images[-1]]):
|
||||||
|
images = wandb.Image(
|
||||||
|
make_grid(
|
||||||
|
denormalize(demo).clamp(0, 1).float()[:4], nrow=2
|
||||||
|
),
|
||||||
|
caption=f"step {step}",
|
||||||
|
)
|
||||||
|
wandb.log({"viz/decoded": images})
|
||||||
|
|
||||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||||
psnr_sum += psnr.sum().item()
|
psnr_sum += psnr.sum().item()
|
||||||
ssim_sum += ssim.sum().item()
|
ssim_sum += ssim.sum().item()
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from src.dataset.preprocess import denormalize
|
|||||||
from src.model.utransformer import UTransformer
|
from src.model.utransformer import UTransformer
|
||||||
from src.rf import RF
|
from src.rf import RF
|
||||||
|
|
||||||
checkpoint_path = "artifact/icy-field-12/checkpoint_epoch_260.pt"
|
checkpoint_path = "artifact/daily-forest-25/checkpoint_final.pt"
|
||||||
device = "cuda:2"
|
device = "cuda:1"
|
||||||
save_dir = "test_images"
|
save_dir = "test_images"
|
||||||
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
@@ -28,7 +28,7 @@ rf.model.eval()
|
|||||||
|
|
||||||
_, test_dataset = get_dataset()
|
_, test_dataset = get_dataset()
|
||||||
|
|
||||||
batch_size = 1
|
batch_size = 8
|
||||||
psnr_sum = 0
|
psnr_sum = 0
|
||||||
ssim_sum = 0
|
ssim_sum = 0
|
||||||
lpips_sum = 0
|
lpips_sum = 0
|
||||||
@@ -39,7 +39,7 @@ max_save = 10
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
|
for i in tqdm(range(0, len(test_dataset), batch_size), desc="Evaluating"):
|
||||||
batch = test_dataset[i : i + batch_size]
|
batch = test_dataset[i : i + batch_size]
|
||||||
images = rf.sample(batch["cloud"].to(device))
|
images = rf.sample_heun(batch["cloud"].to(device), 1)
|
||||||
|
|
||||||
image = denormalize(images[-1]).clamp(0, 1)
|
image = denormalize(images[-1]).clamp(0, 1)
|
||||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||||
@@ -49,12 +49,13 @@ with torch.no_grad():
|
|||||||
save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
|
save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
|
||||||
save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
|
save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
|
||||||
save_image(
|
save_image(
|
||||||
denormalize(batch["x0"][j]).clamp(0, 1),
|
denormalize(batch["cloud"][j]).clamp(0, 1),
|
||||||
f"{save_dir}/input_{saved_count}.png",
|
f"{save_dir}/input_{saved_count}.png",
|
||||||
)
|
)
|
||||||
saved_count += 1
|
saved_count += 1
|
||||||
|
|
||||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||||
|
print(psnr, ssim, lpips)
|
||||||
psnr_sum += psnr.sum().item()
|
psnr_sum += psnr.sum().item()
|
||||||
ssim_sum += ssim.sum().item()
|
ssim_sum += ssim.sum().item()
|
||||||
lpips_sum += lpips.sum().item()
|
lpips_sum += lpips.sum().item()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from torchmetrics.image import (
|
|||||||
StructuralSimilarityIndexMeasure,
|
StructuralSimilarityIndexMeasure,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
psnr = PeakSignalNoiseRatio(1.0, reduction="none", dim=(1, 2, 3))
|
||||||
ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
|
ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
|
||||||
lpips = LearnedPerceptualImagePatchSimilarity(
|
lpips = LearnedPerceptualImagePatchSimilarity(
|
||||||
net_type="alex", reduction="none", normalize=True
|
net_type="alex", reduction="none", normalize=True
|
||||||
@@ -11,5 +12,4 @@ lpips = LearnedPerceptualImagePatchSimilarity(
|
|||||||
|
|
||||||
|
|
||||||
def benchmark(image1, image2):
    """Score ``image1`` against ``image2`` with the module-level metrics.

    Args:
        image1: Batch of predicted images.
        image2: Batch of reference images, same shape as ``image1``.

    Returns:
        A ``(psnr, ssim, lpips)`` tuple holding the output of each
        module-level metric for this batch.
    """
    # Use the shared module-level PSNR metric, which is configured with
    # reduction="none" and dim=(1, 2, 3). Building a fresh metric here
    # without `dim` yields one batch-level value, which skews any average
    # computed by summing the results and dividing by the number of images.
    return psnr(image1, image2), ssim(image1, image2), lpips(image1, image2)
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def get_dataset() -> tuple[Dataset, Dataset]:
|
|||||||
batch_size=32,
|
batch_size=32,
|
||||||
remove_columns=dataset["train"].column_names,
|
remove_columns=dataset["train"].column_names,
|
||||||
)
|
)
|
||||||
dataset.set_format(type="torch", columns=["x0", "x1"])
|
dataset.set_format(type="torch", columns=["cloud", "gt"])
|
||||||
dataset.save_to_disk("datasets/CUHK-CR1")
|
dataset.save_to_disk("datasets/CUHK-CR1")
|
||||||
|
|
||||||
return dataset["train"], dataset["test"]
|
return dataset["train"], dataset["test"]
|
||||||
|
|||||||
62
src/dataset/cuhk_cr2.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datasets import Dataset, DatasetDict, Image
|
||||||
|
from src.dataset.preprocess import make_transform
|
||||||
|
|
||||||
|
transform = make_transform(512)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset() -> tuple[Dataset, Dataset]:
    """Return the ``(train, test)`` splits of the CUHK-CR2 dataset.

    If a preprocessed copy exists at ``datasets/CUHK-CR2`` it is loaded and
    returned as-is. Otherwise the PNG files under
    ``/data2/C-CUHK/CUHK-CR2/{train,test}/{cloud,label}`` are paired by sorted
    filename order, passed through ``preprocess_function`` in batches of 32,
    formatted as torch tensors (columns ``"cloud"`` and ``"gt"``) and saved to
    the cache directory before being returned.

    Raises:
        FileNotFoundError: If a split contains no ``*.png`` files. Without
            this check an empty dataset would be built and cached, and every
            later call would silently load the empty cache.
        ValueError: If a split has a different number of cloud and label
            images, since the pairing relies on both lists lining up.
    """
    cache_dir = "datasets/CUHK-CR2"
    if os.path.exists(cache_dir):
        dataset = DatasetDict.load_from_disk(cache_dir)
        return dataset["train"], dataset["test"]

    data_dir = Path("/data2/C-CUHK/CUHK-CR2")

    def _load_split(split: str) -> Dataset:
        # Build one split from the sorted cloud/label PNG paths.
        cloud_paths = sorted((data_dir / split / "cloud").glob("*.png"))
        label_paths = sorted((data_dir / split / "label").glob("*.png"))

        if not cloud_paths or not label_paths:
            raise FileNotFoundError(
                f"No PNG images found for split '{split}' under {data_dir}"
            )
        if len(cloud_paths) != len(label_paths):
            raise ValueError(
                f"Split '{split}' has {len(cloud_paths)} cloud images but "
                f"{len(label_paths)} label images"
            )

        return (
            Dataset.from_dict(
                {
                    "cloud": [str(p) for p in cloud_paths],
                    "label": [str(p) for p in label_paths],
                }
            )
            .cast_column("cloud", Image())
            .cast_column("label", Image())
        )

    dataset = DatasetDict(
        {
            "train": _load_split("train"),
            "test": _load_split("test"),
        }
    )
    dataset = dataset.map(
        preprocess_function,
        batched=True,
        batch_size=32,
        remove_columns=dataset["train"].column_names,
    )
    dataset.set_format(type="torch", columns=["cloud", "gt"])
    dataset.save_to_disk(cache_dir)

    return dataset["train"], dataset["test"]
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_function(examples):
    """Apply the module-level ``transform`` to a batch of image pairs.

    Args:
        examples: Batch mapping with ``"cloud"`` and ``"label"`` sequences of
            images, as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        A dict with ``"cloud"`` (transformed cloudy inputs) and ``"gt"``
        (transformed labels), in the same order as the input batch.
    """
    # NOTE(review): `transform` is applied to each image independently; if it
    # contains random augmentation the pair would be misaligned - confirm it
    # is deterministic.
    # One pass, cloud before label for each pair, so the order of `transform`
    # calls is unchanged.
    transformed_pairs = [
        (transform(cloud_img), transform(label_img))
        for cloud_img, label_img in zip(examples["cloud"], examples["label"])
    ]
    return {
        "cloud": [cloud for cloud, _ in transformed_pairs],
        "gt": [label for _, label in transformed_pairs],
    }
|
||||||
@@ -105,10 +105,11 @@ class DinoConditionedLayer(DINOv3ViTLayer):
|
|||||||
conditioning_input: Optional[torch.Tensor] = None,
|
conditioning_input: Optional[torch.Tensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
do_condition: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert position_embeddings is not None
|
assert position_embeddings is not None
|
||||||
assert conditioning_input is not None
|
assert conditioning_input is not None or not do_condition
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm1(hidden_states)
|
hidden_states = self.norm1(hidden_states)
|
||||||
@@ -120,6 +121,7 @@ class DinoConditionedLayer(DINOv3ViTLayer):
|
|||||||
hidden_states = self.layer_scale1(hidden_states)
|
hidden_states = self.layer_scale1(hidden_states)
|
||||||
hidden_states = self.drop_path(hidden_states) + residual
|
hidden_states = self.drop_path(hidden_states) + residual
|
||||||
|
|
||||||
|
if do_condition:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm_cond(hidden_states)
|
hidden_states = self.norm_cond(hidden_states)
|
||||||
hidden_states, _ = self.cond(
|
hidden_states, _ = self.cond(
|
||||||
@@ -191,6 +193,8 @@ class DinoV3ViTDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
|
self.pixel_shuffle = nn.PixelShuffle(self.patch_size)
|
||||||
|
nn.init.zeros_(self.projection.weight)
|
||||||
|
nn.init.zeros_(self.projection.bias)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
@@ -211,9 +215,14 @@ class DinoV3ViTDecoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class UTransformer(nn.Module):
|
class UTransformer(nn.Module):
|
||||||
def __init__(self, config: DINOv3ViTConfig, num_classes: int):
|
def __init__(
|
||||||
|
self, config: DINOv3ViTConfig, num_classes: int, scale_factor: int = 4
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
assert config.num_hidden_layers % scale_factor == 0
|
||||||
|
|
||||||
self.embeddings = DINOv3ViTEmbeddings(config)
|
self.embeddings = DINOv3ViTEmbeddings(config)
|
||||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
|
||||||
@@ -228,18 +237,21 @@ class UTransformer(nn.Module):
|
|||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self.decoder_layers = nn.ModuleList(
|
self.decoder_layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
DinoConditionedLayer(config, False)
|
DinoConditionedLayer(config, False)
|
||||||
for _ in range(config.num_hidden_layers // 2)
|
for _ in range(config.num_hidden_layers // scale_factor)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
self.decoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.decoder = DinoV3ViTDecoder(config)
|
self.decoder = DinoV3ViTDecoder(config)
|
||||||
|
|
||||||
# freeze pretrained
|
# freeze pretrained
|
||||||
self.embeddings.requires_grad_(False)
|
self.embeddings.requires_grad_(False)
|
||||||
self.rope_embeddings.requires_grad_(False)
|
self.rope_embeddings.requires_grad_(False)
|
||||||
|
self.encoder_norm.requires_grad_(False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -260,6 +272,7 @@ class UTransformer(nn.Module):
|
|||||||
|
|
||||||
residual = []
|
residual = []
|
||||||
for i, layer_module in enumerate(self.encoder_layers):
|
for i, layer_module in enumerate(self.encoder_layers):
|
||||||
|
if i % self.scale_factor == 0:
|
||||||
residual.append(x)
|
residual.append(x)
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
x = layer_module(
|
x = layer_module(
|
||||||
@@ -269,35 +282,71 @@ class UTransformer(nn.Module):
|
|||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
x = self.encoder_norm(x)
|
||||||
|
|
||||||
|
reversed_residual = residual[::-1]
|
||||||
for i, layer_module in enumerate(self.decoder_layers):
|
for i, layer_module in enumerate(self.decoder_layers):
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
x = x + residual.pop() + residual.pop()
|
|
||||||
x = layer_module(
|
x = layer_module(
|
||||||
x,
|
x,
|
||||||
conditioning_input=conditioning_input,
|
conditioning_input=conditioning_input,
|
||||||
attention_mask=layer_head_mask,
|
attention_mask=layer_head_mask,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
)
|
)
|
||||||
|
x = x + reversed_residual[i]
|
||||||
|
|
||||||
return self.decoder(x, image_size=pixel_values.shape[-2:])
|
x = self.decoder_norm(x)
|
||||||
|
|
||||||
|
return self.decoder(x, image_size=pixel_values.shape[-2:]), residual
|
||||||
|
|
||||||
|
def get_residual(
    self,
    pixel_values: torch.Tensor,
    time: Optional[torch.Tensor],
    do_condition: bool,
):
    """Run the encoder stack and collect its intermediate hidden states.

    The input is embedded, then passed through every layer in
    ``self.encoder_layers``. The hidden state entering each layer whose index
    is a multiple of ``self.scale_factor`` is recorded.

    Args:
        pixel_values: Input image batch. It is cast to the dtype of the patch
            embedding weights before embedding.
        time: Timestep tensor fed to ``self.t_embedder``. Only used, and
            required, when ``do_condition`` is true.
        do_condition: Whether the layers receive the timestep conditioning.
            When false, ``conditioning_input`` is ``None`` and each layer is
            called with ``do_condition=False``.

    Returns:
        List of recorded hidden states, in encoder order.

    Raises:
        ValueError: If ``do_condition`` is true but ``time`` is ``None``.
    """
    if do_condition and time is None:
        raise ValueError("`time` must be provided when do_condition=True")

    pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
    position_embeddings = self.rope_embeddings(pixel_values)

    x = self.embeddings(pixel_values, bool_masked_pos=None)

    if do_condition:
        # Only the timestep embedding conditions the layers; class-label
        # conditioning (y_embedder) is disabled in this code path.
        t = self.t_embedder(time).unsqueeze(1)
        conditioning_input = t.to(x.dtype)
    else:
        conditioning_input = None

    residual = []
    for i, layer_module in enumerate(self.encoder_layers):
        # Record the state *entering* every scale_factor-th layer.
        if i % self.scale_factor == 0:
            residual.append(x)

        x = layer_module(
            x,
            conditioning_input=conditioning_input,
            attention_mask=None,
            position_embeddings=position_embeddings,
            do_condition=do_condition,
        )

    return residual
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained_backbone(name: str):
|
def from_pretrained_backbone(name: str):
|
||||||
config = DINOv3ViTConfig.from_pretrained(name)
|
config = DINOv3ViTConfig.from_pretrained(name)
|
||||||
instance = UTransformer(config, 0).to("cuda:2")
|
instance = UTransformer(config, 0).to("cuda:1")
|
||||||
|
|
||||||
weight_dict = {}
|
weight_dict = {}
|
||||||
with safe_open(
|
with safe_open(
|
||||||
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:2"
|
hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:1"
|
||||||
) as f:
|
) as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
new_key = key.replace("layer.", "encoder_layers.").replace(
|
new_key = key.replace("layer.", "encoder_layers.").replace(
|
||||||
"norm.", "encoder_norm."
|
"norm.", "encoder_norm."
|
||||||
)
|
)
|
||||||
|
|
||||||
if key.startswith("norm."):
|
|
||||||
continue
|
|
||||||
|
|
||||||
weight_dict[new_key] = f.get_tensor(key)
|
weight_dict[new_key] = f.get_tensor(key)
|
||||||
|
|
||||||
instance.load_state_dict(weight_dict, strict=False)
|
instance.load_state_dict(weight_dict, strict=False)
|
||||||
|
|||||||
77
src/rf.py
@@ -13,17 +13,13 @@ def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
|
|||||||
|
|
||||||
|
|
||||||
class RF:
|
class RF:
|
||||||
def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_mse_enhanced"):
    """Wrap ``model`` for rectified-flow training and sampling.

    Args:
        model: The velocity network. It is stored as ``self.model`` and must
            expose ``parameters()`` when an LPIPS-based loss is selected.
        ln: Stored as ``self.ln``; its use is outside this block.
        ushaped: Stored as ``self.ushaped``; its use is outside this block.
        loss_fn: Name of the training loss. Any name containing ``"lpips"``
            causes an LPIPS (VGG) network to be created as ``self.lpips``;
            otherwise ``self.lpips`` is ``None``.
    """
    self.model = model
    self.ln = ln
    self.ushaped = ushaped
    self.loss_fn = loss_fn

    if "lpips" in loss_fn:
        # Put the perceptual network on the model's own device rather than a
        # hard-coded "cuda:1", so the loss keeps working when the model is
        # moved to another GPU or to CPU.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else "cpu"
        self.lpips = lpips.LPIPS(net="vgg").to(device)
    else:
        self.lpips = None
|
|
||||||
|
|
||||||
def forward(self, x0, z1):
|
def forward(self, x0, z1):
|
||||||
# x0 is gt / z is noise
|
# x0 is gt / z is noise
|
||||||
@@ -41,7 +37,7 @@ class RF:
|
|||||||
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
||||||
zt = (1 - texp) * x0 + texp * z1
|
zt = (1 - texp) * x0 + texp * z1
|
||||||
|
|
||||||
vtheta = self.model(zt, t)
|
vtheta, residual = self.model(zt, t)
|
||||||
|
|
||||||
if self.loss_fn == "lpips_huber":
|
if self.loss_fn == "lpips_huber":
|
||||||
# https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
|
# https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
|
||||||
@@ -57,6 +53,41 @@ class RF:
|
|||||||
weight = t.view(-1)
|
weight = t.view(-1)
|
||||||
|
|
||||||
loss = (1 - weight) * huber + lpips
|
loss = (1 - weight) * huber + lpips
|
||||||
|
elif self.loss_fn == "lpips_mse":
|
||||||
|
if not self.lpips:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
lpips = self.lpips(
|
||||||
|
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
|
||||||
|
)
|
||||||
|
loss = ((z1 - x0 - vtheta) ** 2).mean(
|
||||||
|
dim=list(range(1, len(x0.shape)))
|
||||||
|
) + 2.0 * lpips
|
||||||
|
elif self.loss_fn == "lpips_mse_enhanced":
|
||||||
|
if not self.lpips:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
lpips = self.lpips(
|
||||||
|
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
|
||||||
|
)
|
||||||
|
dino_loss = torch.stack(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
1
|
||||||
|
- torch.nn.functional.cosine_similarity(
|
||||||
|
v_residual, x0_residual, dim=-1
|
||||||
|
)
|
||||||
|
).mean(dim=-1)
|
||||||
|
for v_residual, x0_residual in zip(
|
||||||
|
residual, self.model.get_residual(x0, None, False)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
).mean(dim=0) * (2 - t.view(-1))
|
||||||
|
loss = (
|
||||||
|
((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
||||||
|
+ 2.0 * lpips
|
||||||
|
+ dino_loss
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
||||||
|
|
||||||
@@ -66,7 +97,7 @@ class RF:
|
|||||||
return loss.mean(), ttloss
|
return loss.mean(), ttloss
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(self, z1, sample_steps=50):
|
def sample(self, z1, sample_steps=5):
|
||||||
b = z1.size(0)
|
b = z1.size(0)
|
||||||
dt = 1.0 / sample_steps
|
dt = 1.0 / sample_steps
|
||||||
dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
|
dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
|
||||||
@@ -76,8 +107,36 @@ class RF:
|
|||||||
t = i / sample_steps
|
t = i / sample_steps
|
||||||
t = torch.tensor([t] * b).to(z.device)
|
t = torch.tensor([t] * b).to(z.device)
|
||||||
|
|
||||||
vc = self.model(z, t)
|
vc, _ = self.model(z, t)
|
||||||
|
|
||||||
z = z - dt * vc
|
z = z - dt * vc
|
||||||
images.append(z)
|
images.append(z)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
@torch.no_grad()
def sample_heun(self, z1, sample_steps=50):
    """Integrate from ``t = 1`` to ``t = 0`` with Heun's second-order method.

    Each step evaluates the model at the current time, takes a trial Euler
    step, evaluates the model again at the next time on the trial state, and
    advances with the average of the two velocities. This costs two model
    calls per step.

    Args:
        z1: Starting batch at ``t = 1``; its first dimension is the batch size.
        sample_steps: Number of integration steps; must be at least 1.

    Returns:
        List of ``sample_steps + 1`` tensors: ``z1`` followed by the state
        after each step, so the last entry is the final sample.

    Raises:
        ValueError: If ``sample_steps`` is smaller than 1. Zero would divide
            by zero, and a negative value would silently return only the input.
    """
    if sample_steps < 1:
        raise ValueError(f"sample_steps must be >= 1, got {sample_steps}")

    b = z1.size(0)
    dt = 1.0 / sample_steps

    images = [z1]
    z = z1

    for i in range(sample_steps, 0, -1):
        t_current = torch.full((b,), i / sample_steps, device=z.device)
        t_next = torch.full((b,), (i - 1) / sample_steps, device=z.device)

        # Predictor: plain Euler step using the velocity at the current time.
        v_current, _ = self.model(z, t_current)
        z_pred = z - dt * v_current

        # Corrector: re-evaluate at the predicted state and average.
        v_next, _ = self.model(z_pred, t_next)
        v_avg = 0.5 * (v_current + v_next)

        z = z - dt * v_avg
        images.append(z)

    return images
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 325 KiB After Width: | Height: | Size: 396 KiB |
|
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 401 KiB |
|
Before Width: | Height: | Size: 301 KiB After Width: | Height: | Size: 386 KiB |
|
Before Width: | Height: | Size: 305 KiB After Width: | Height: | Size: 379 KiB |
|
Before Width: | Height: | Size: 349 KiB After Width: | Height: | Size: 412 KiB |
|
Before Width: | Height: | Size: 304 KiB After Width: | Height: | Size: 407 KiB |
|
Before Width: | Height: | Size: 371 KiB After Width: | Height: | Size: 440 KiB |
|
Before Width: | Height: | Size: 384 KiB After Width: | Height: | Size: 441 KiB |
|
Before Width: | Height: | Size: 366 KiB After Width: | Height: | Size: 428 KiB |
|
Before Width: | Height: | Size: 315 KiB After Width: | Height: | Size: 386 KiB |