This commit is contained in:
neulus
2025-11-02 22:53:04 +09:00
parent e51017897d
commit ca589fab7c
13 changed files with 888 additions and 96 deletions

View File

@@ -18,11 +18,11 @@ train_dataset, test_dataset = get_dataset()
device = "cuda:1"
batch_size = 8 * 4 * 2
batch_size = 32
accumulation_steps = 2
total_epoch = 500
steps_per_epoch = len(train_dataset) // batch_size
steps_per_epoch = len(train_dataset) // (batch_size)
total_steps = steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
@@ -160,7 +160,7 @@ for epoch in range(start_epoch, total_epoch):
)
wandb.log({"viz/decoded": images})
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
psnr, ssim, lpips, flawed_lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()