Merge commit 'ca589fa'

This commit is contained in:
neulus
2025-11-02 22:55:07 +09:00
13 changed files with 888 additions and 98 deletions

14
main.py
View File

@@ -21,11 +21,11 @@ train_dataset, test_dataset = get_dataset()
device = "cuda:1"
batch_size = 16
batch_size = 32
accumulation_steps = 2
total_epoch = 500
steps_per_epoch = len(train_dataset) // batch_size
steps_per_epoch = len(train_dataset) // (batch_size)
total_steps = steps_per_epoch * total_epoch
warmup_steps = int(0.05 * total_steps)
@@ -164,11 +164,11 @@ for epoch in range(start_epoch, total_epoch):
)
wandb.log({"viz/decoded": images})
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()
count += image.shape[0]
psnr, ssim, lpips, flawed_lpips = benchmark(image.cpu(), original.cpu())
psnr_sum += psnr.sum().item()
ssim_sum += ssim.sum().item()
lpips_sum += lpips.sum().item()
count += image.shape[0]
avg_psnr = psnr_sum / count
avg_ssim = ssim_sum / count