improved rf
This commit is contained in:
10
main.py
10
main.py
@@ -50,10 +50,10 @@ for epoch in range(start_epoch, total_epoch):
|
||||
desc=f"Epoch {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = train_dataset[i : i + batch_size]
|
||||
x0 = batch["x0"].to(device)
|
||||
x1 = batch["x1"].to(device)
|
||||
cloud = batch["cloud"].to(device)
|
||||
gt = batch["gt"].to(device)
|
||||
|
||||
loss, blsct = rf.forward(x0, x1)
|
||||
loss, blsct = rf.forward(gt, cloud)
|
||||
loss = loss / accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
@@ -89,9 +89,9 @@ for epoch in range(start_epoch, total_epoch):
|
||||
desc=f"Benchmark {epoch + 1}/{total_epoch}",
|
||||
):
|
||||
batch = test_dataset[i : i + batch_size]
|
||||
images = rf.sample(batch["x0"].to(device))
|
||||
images = rf.sample(batch["cloud"].to(device))
|
||||
image = denormalize(images[-1]).clamp(0, 1)
|
||||
original = denormalize(batch["x1"]).clamp(0, 1)
|
||||
original = denormalize(batch["gt"]).clamp(0, 1)
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
|
||||
Reference in New Issue
Block a user