improved rf
This commit is contained in:
@@ -59,4 +59,4 @@ def preprocess_function(examples):
|
||||
x1_transformed = transform(x1_img)
|
||||
x0_list.append(x0_transformed)
|
||||
x1_list.append(x1_transformed)
|
||||
return {"x0": x0_list, "x1": x1_list}
|
||||
return {"cloud": x0_list, "gt": x1_list}
|
||||
|
||||
@@ -14,7 +14,7 @@ def make_transform(resize_size: int = 256):
|
||||
return v2.Compose([to_tensor, resize, to_float, normalize])
|
||||
|
||||
|
||||
def denormalize(tensor):
|
||||
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
|
||||
mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
|
||||
std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
|
||||
return tensor * std + mean
|
||||
|
||||
@@ -232,7 +232,7 @@ class UTransformer(nn.Module):
|
||||
self.decoder_layers = nn.ModuleList(
|
||||
[
|
||||
DinoConditionedLayer(config, False)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
for _ in range(config.num_hidden_layers // 2)
|
||||
]
|
||||
)
|
||||
self.decoder = DinoV3ViTDecoder(config)
|
||||
@@ -271,13 +271,13 @@ class UTransformer(nn.Module):
|
||||
|
||||
for i, layer_module in enumerate(self.decoder_layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
x = x + residual.pop() + residual.pop()
|
||||
x = layer_module(
|
||||
x,
|
||||
conditioning_input=conditioning_input,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
x = x + residual.pop()
|
||||
|
||||
return self.decoder(x, image_size=pixel_values.shape[-2:])
|
||||
|
||||
|
||||
75
src/rf.py
75
src/rf.py
@@ -1,41 +1,84 @@
|
||||
import math
|
||||
|
||||
import lpips
|
||||
import torch
|
||||
|
||||
from src.dataset.preprocess import denormalize
|
||||
|
||||
|
||||
def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
|
||||
"""Loss = sqrt(||x||₂² + c²) - c"""
|
||||
d = x.shape[1:].numel()
|
||||
c = c * (d**0.5)
|
||||
x = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
|
||||
return torch.sqrt(x**2 + c**2) - c
|
||||
|
||||
|
||||
class RF:
|
||||
def __init__(self, model, ln=True):
|
||||
def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
|
||||
self.model = model
|
||||
self.ln = ln
|
||||
self.ushaped = ushaped
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
def forward(self, x0, x1):
|
||||
self.lpips = (
|
||||
lpips.LPIPS(net="vgg").to(model.device)
|
||||
if loss_fn == "lpips_huber"
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x0, z1):
|
||||
# x0 is gt / z is noise
|
||||
b = x0.size(0)
|
||||
if self.ln:
|
||||
if self.ushaped:
|
||||
a = 4.0 # HYPERPARMS
|
||||
u = torch.rand((b,)).to(x0.device)
|
||||
t = torch.asinh((2 * u - 1) * math.sinh(a)) / a
|
||||
elif self.ln:
|
||||
nt = torch.randn((b,)).to(x0.device)
|
||||
t = torch.sigmoid(nt)
|
||||
else:
|
||||
t = torch.rand((b,)).to(x0.device)
|
||||
texp = t.view([b, *([1] * len(x0.shape[1:]))])
|
||||
zt = (1 - texp) * x0 + texp * x1
|
||||
zt = (1 - texp) * x0 + texp * z1
|
||||
|
||||
vtheta = self.model(zt, t)
|
||||
batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
|
||||
dim=list(range(1, len(x0.shape)))
|
||||
)
|
||||
tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
|
||||
|
||||
if self.loss_fn == "lpips_huber":
|
||||
# https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
|
||||
if not self.lpips:
|
||||
raise Exception
|
||||
|
||||
huber = torch.nn.functional.huber_loss(
|
||||
z1 - x0, vtheta, reduction="none"
|
||||
).mean(dim=list(range(1, len(x0.shape))))
|
||||
lpips = self.lpips(
|
||||
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
|
||||
)
|
||||
weight = t.view(-1)
|
||||
|
||||
loss = (1 - weight) * huber + lpips
|
||||
else:
|
||||
loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
|
||||
|
||||
tlist = loss.detach().cpu().reshape(-1).tolist()
|
||||
ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
|
||||
return batchwise_mse.mean(), ttloss
|
||||
|
||||
return loss.mean(), ttloss
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, x0, sample_steps=50):
|
||||
b = x0.size(0)
|
||||
def sample(self, z1, sample_steps=50):
|
||||
b = z1.size(0)
|
||||
dt = 1.0 / sample_steps
|
||||
dt = torch.tensor([dt] * b).to(x0.device).view([b, *([1] * len(x0.shape[1:]))])
|
||||
images = [x0]
|
||||
z = x0
|
||||
for i in range(sample_steps):
|
||||
dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
|
||||
images = [z1]
|
||||
z = z1
|
||||
for i in range(sample_steps, 0, -1):
|
||||
t = i / sample_steps
|
||||
t = torch.tensor([t] * b).to(z.device)
|
||||
|
||||
vc = self.model(z, t)
|
||||
|
||||
z = z + dt * vc
|
||||
z = z - dt * vc
|
||||
images.append(z)
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user