improved rf

This commit is contained in:
neulus
2025-10-01 18:44:26 +09:00
parent 49025c4d87
commit 29eb04d1a4
8 changed files with 150 additions and 27 deletions

10
main.py
View File

@@ -50,10 +50,10 @@ for epoch in range(start_epoch, total_epoch):
desc=f"Epoch {epoch + 1}/{total_epoch}",
):
batch = train_dataset[i : i + batch_size]
x0 = batch["x0"].to(device)
x1 = batch["x1"].to(device)
cloud = batch["cloud"].to(device)
gt = batch["gt"].to(device)
loss, blsct = rf.forward(x0, x1)
loss, blsct = rf.forward(gt, cloud)
loss = loss / accumulation_steps
loss.backward()
@@ -89,9 +89,9 @@ for epoch in range(start_epoch, total_epoch):
desc=f"Benchmark {epoch + 1}/{total_epoch}",
):
batch = test_dataset[i : i + batch_size]
images = rf.sample(batch["x0"].to(device))
images = rf.sample(batch["cloud"].to(device))
image = denormalize(images[-1]).clamp(0, 1)
original = denormalize(batch["x1"]).clamp(0, 1)
original = denormalize(batch["gt"]).clamp(0, 1)
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()