improved rf

This commit is contained in:
neulus
2025-10-01 18:44:26 +09:00
parent 49025c4d87
commit 29eb04d1a4
8 changed files with 150 additions and 27 deletions

View File

@@ -59,4 +59,4 @@ def preprocess_function(examples):
x1_transformed = transform(x1_img)
x0_list.append(x0_transformed)
x1_list.append(x1_transformed)
return {"x0": x0_list, "x1": x1_list}
return {"cloud": x0_list, "gt": x1_list}

View File

@@ -14,7 +14,7 @@ def make_transform(resize_size: int = 256):
return v2.Compose([to_tensor, resize, to_float, normalize])
def denormalize(tensor):
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
return tensor * std + mean

View File

@@ -232,7 +232,7 @@ class UTransformer(nn.Module):
self.decoder_layers = nn.ModuleList(
[
DinoConditionedLayer(config, False)
for _ in range(config.num_hidden_layers)
for _ in range(config.num_hidden_layers // 2)
]
)
self.decoder = DinoV3ViTDecoder(config)
@@ -271,13 +271,13 @@ class UTransformer(nn.Module):
for i, layer_module in enumerate(self.decoder_layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
x = x + residual.pop() + residual.pop()
x = layer_module(
x,
conditioning_input=conditioning_input,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
x = x + residual.pop()
return self.decoder(x, image_size=pixel_values.shape[-2:])

View File

@@ -1,41 +1,84 @@
import math
import lpips
import torch
from src.dataset.preprocess import denormalize
def pseudo_huber_loss(x: torch.Tensor, c=0.00054):
"""Loss = sqrt(||x||₂² + c²) - c"""
d = x.shape[1:].numel()
c = c * (d**0.5)
x = torch.linalg.vector_norm(x.flatten(1), ord=2, dim=1)
return torch.sqrt(x**2 + c**2) - c
class RF:
def __init__(self, model, ln=True):
def __init__(self, model, ln=False, ushaped=True, loss_fn="lpips_huber"):
self.model = model
self.ln = ln
self.ushaped = ushaped
self.loss_fn = loss_fn
def forward(self, x0, x1):
self.lpips = (
lpips.LPIPS(net="vgg").to(model.device)
if loss_fn == "lpips_huber"
else None
)
def forward(self, x0, z1):
# x0 is gt / z is noise
b = x0.size(0)
if self.ln:
if self.ushaped:
a = 4.0 # HYPERPARMS
u = torch.rand((b,)).to(x0.device)
t = torch.asinh((2 * u - 1) * math.sinh(a)) / a
elif self.ln:
nt = torch.randn((b,)).to(x0.device)
t = torch.sigmoid(nt)
else:
t = torch.rand((b,)).to(x0.device)
texp = t.view([b, *([1] * len(x0.shape[1:]))])
zt = (1 - texp) * x0 + texp * x1
zt = (1 - texp) * x0 + texp * z1
vtheta = self.model(zt, t)
batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
dim=list(range(1, len(x0.shape)))
)
tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
if self.loss_fn == "lpips_huber":
# https://ar5iv.labs.arxiv.org/html/2405.20320v1 / (z - x) - v_θ(x_t, t)
if not self.lpips:
raise Exception
huber = torch.nn.functional.huber_loss(
z1 - x0, vtheta, reduction="none"
).mean(dim=list(range(1, len(x0.shape))))
lpips = self.lpips(
denormalize(x0) * 2 - 1, (denormalize(zt - texp * vtheta) * 2 - 1)
)
weight = t.view(-1)
loss = (1 - weight) * huber + lpips
else:
loss = ((z1 - x0 - vtheta) ** 2).mean(dim=list(range(1, len(x0.shape))))
tlist = loss.detach().cpu().reshape(-1).tolist()
ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
return batchwise_mse.mean(), ttloss
return loss.mean(), ttloss
@torch.no_grad()
def sample(self, x0, sample_steps=50):
b = x0.size(0)
def sample(self, z1, sample_steps=50):
b = z1.size(0)
dt = 1.0 / sample_steps
dt = torch.tensor([dt] * b).to(x0.device).view([b, *([1] * len(x0.shape[1:]))])
images = [x0]
z = x0
for i in range(sample_steps):
dt = torch.tensor([dt] * b).to(z1.device).view([b, *([1] * len(z1.shape[1:]))])
images = [z1]
z = z1
for i in range(sample_steps, 0, -1):
t = i / sample_steps
t = torch.tensor([t] * b).to(z.device)
vc = self.model(z, t)
z = z + dt * vc
z = z - dt * vc
images.append(z)
return images