things
This commit is contained in:
6
main.py
6
main.py
@@ -18,11 +18,11 @@ train_dataset, test_dataset = get_dataset()
|
||||
|
||||
device = "cuda:1"
|
||||
|
||||
batch_size = 8 * 4 * 2
|
||||
batch_size = 32
|
||||
accumulation_steps = 2
|
||||
total_epoch = 500
|
||||
|
||||
steps_per_epoch = len(train_dataset) // batch_size
|
||||
steps_per_epoch = len(train_dataset) // (batch_size)
|
||||
total_steps = steps_per_epoch * total_epoch
|
||||
warmup_steps = int(0.05 * total_steps)
|
||||
|
||||
@@ -160,7 +160,7 @@ for epoch in range(start_epoch, total_epoch):
|
||||
)
|
||||
wandb.log({"viz/decoded": images})
|
||||
|
||||
psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr, ssim, lpips, flawed_lpips = benchmark(image.cpu(), original.cpu())
|
||||
psnr_sum += psnr.sum().item()
|
||||
ssim_sum += ssim.sum().item()
|
||||
lpips_sum += lpips.sum().item()
|
||||
|
||||
Reference in New Issue
Block a user