test
main.py
@@ -1,44 +1,25 @@
 import os
-from typing import cast
 
 import torch
 import torch.optim as optim
-import wandb
-from datasets import DatasetDict, load_dataset
 from tqdm import tqdm
 
-from src.dataset.preprocess import make_transform
+import wandb
+from src.benchmark import benchmark
+from src.dataset.cuhk_cr1 import get_dataset
+from src.dataset.preprocess import denormalize
+from src.model.utransformer import UTransformer
+from src.rf import RF
 
-transform = make_transform()
+device = "cuda:3"
 
 model = UTransformer.from_pretrained_backbone(
     "facebook/dinov3-vits16-pretrain-lvd1689m"
-).to("cuda:3")
+).to(device)
 rf = RF(model)
 optimizer = optim.AdamW(model.parameters(), lr=5e-4)
-dataset = cast(DatasetDict, load_dataset("your-dataset-name"))
-train_dataset = dataset["train"]
-
-
-def preprocess_function(examples):
-    x0_list = []
-    x1_list = []
-    for x0_img, x1_img in zip(examples["cloudy"], examples["clear"]):
-        x0_transformed = transform(x0_img)
-        x1_transformed = transform(x1_img)
-        x0_list.append(x0_transformed)
-        x1_list.append(x1_transformed)
-    return {"x0": x0_list, "x1": x1_list}
-
-
-train_dataset = train_dataset.map(
-    preprocess_function,
-    batched=True,
-    batch_size=32,
-    remove_columns=train_dataset.column_names,
-)
-train_dataset.set_format(type="torch", columns=["x0", "x1"])
+train_dataset, test_dataset = get_dataset()
 
 wandb.init(project="cloud-removal-kmu")
 
 
@@ -47,7 +28,7 @@ if not (wandb.run and wandb.run.name):
 
 os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
 
-batch_size = 16
+batch_size = 32
 for epoch in range(100):
     lossbin = {i: 0 for i in range(10)}
     losscnt = {i: 1e-6 for i in range(10)}
@@ -58,8 +39,8 @@ for epoch in range(100):
         range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100"
     ):
         batch = train_dataset[i : i + batch_size]
-        x0 = torch.stack(batch["x0"]).to("cuda:3")
-        x1 = torch.stack(batch["x1"]).to("cuda:3")
+        x0 = batch["x0"].to(device)
+        x1 = batch["x1"].to(device)
 
         optimizer.zero_grad()
         loss, blsct = rf.forward(x0, x1)
@@ -77,6 +58,43 @@ for epoch in range(100):
     wandb.log(epoch_metrics)
 
+    if (epoch + 1) % 10 == 0:
+        # bench
+        rf.model.eval()
+        psnr_sum = 0
+        ssim_sum = 0
+        lpips_sum = 0
+        count = 0
+
+        with torch.no_grad():
+            for i in tqdm(
+                range(0, len(test_dataset), batch_size),
+                desc=f"Benchmark {epoch + 1}/100",
+            ):
+                batch = test_dataset[i : i + batch_size]
+                images = rf.sample(batch["x0"].to(device))
+                image = denormalize(images[-1]).clamp(0, 1) * 255
+                original = denormalize(batch["x1"]).clamp(0, 1) * 255
+
+                psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
+                psnr_sum += psnr.sum().item()
+                ssim_sum += ssim.sum().item()
+                lpips_sum += lpips.sum().item()
+                count += image.shape[0]
+
+        avg_psnr = psnr_sum / count
+        avg_ssim = ssim_sum / count
+        avg_lpips = lpips_sum / count
+        wandb.log(
+            {
+                "eval/psnr": avg_psnr,
+                "eval/ssim": avg_ssim,
+                "eval/lpips": avg_lpips,
+                "epoch": epoch + 1,
+            }
+        )
+
+        rf.model.train()
+
     torch.save(
         {
             "epoch": epoch,