Merge commit 'ca589fa'
This commit is contained in:
14
main.py
14
main.py
@@ -21,11 +21,11 @@ train_dataset, test_dataset = get_dataset()
|
||||
|
||||
device = "cuda:1"
|
||||
|
||||
batch_size = 16
|
||||
batch_size = 32
|
||||
accumulation_steps = 2
|
||||
total_epoch = 500
|
||||
|
||||
steps_per_epoch = len(train_dataset) // batch_size
|
||||
steps_per_epoch = len(train_dataset) // (batch_size)
|
||||
total_steps = steps_per_epoch * total_epoch
|
||||
warmup_steps = int(0.05 * total_steps)
|
||||
|
||||
@@ -164,11 +164,11 @@ for epoch in range(start_epoch, total_epoch):
|
||||
)
|
||||
wandb.log({"viz/decoded": images})
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
count += image.shape[0]
|
||||
psnr, ssim, lpips, flawed_lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
count += image.shape[0]
|
||||
|
||||
avg_psnr = psnr_sum / count
|
||||
avg_ssim = ssim_sum / count
|
||||
|
||||
Reference in New Issue
Block a user