resuming
diff --git a/main.py b/main.py
@@ -11,54 +11,72 @@ from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-device = "cuda:3"
+device = "cuda:0"
 
 model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vits16-pretrain-lvd1689m"
+    "facebook/dinov3-vitl16-pretrain-sat493m"
 ).to(device)
 rf = RF(model)
-optimizer = optim.AdamW(model.parameters(), lr=5e-4)
+optimizer = optim.AdamW(model.parameters(), lr=1e-4)
 
 train_dataset, test_dataset = get_dataset()
 
-wandb.init(project="cloud-removal-kmu")
+wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
 
 if not (wandb.run and wandb.run.name):
     raise Exception("nope")
 
 os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
 
-batch_size = 32
-for epoch in range(100):
+start_epoch = 0
+checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
+if os.path.exists(checkpoint_path):
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    start_epoch = checkpoint["epoch"] + 1
+
+batch_size = 4
+accumulation_steps = 4
+total_epoch = 1000
+for epoch in range(start_epoch, total_epoch):
     lossbin = {i: 0 for i in range(10)}
     losscnt = {i: 1e-6 for i in range(10)}
 
     train_dataset = train_dataset.shuffle(seed=epoch)
 
     for i in tqdm(
-        range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100"
+        range(0, len(train_dataset), batch_size),
+        desc=f"Epoch {epoch + 1}/{total_epoch}",
     ):
         batch = train_dataset[i : i + batch_size]
         x0 = batch["x0"].to(device)
         x1 = batch["x1"].to(device)
 
-        optimizer.zero_grad()
         loss, blsct = rf.forward(x0, x1)
+        loss = loss / accumulation_steps
         loss.backward()
 
-        optimizer.step()
-        wandb.log({"loss": loss.item()})
+        if (i // batch_size + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        wandb.log({"train/loss": loss.item() * accumulation_steps})
+
         for t, lss in blsct:
             bin_idx = min(int(t * 10), 9)
             lossbin[bin_idx] += lss
             losscnt[bin_idx] += 1
 
-    epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
+    if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
+        optimizer.step()
+        optimizer.zero_grad()
+
+    epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
     epoch_metrics["epoch"] = epoch
     wandb.log(epoch_metrics)
 
     if (epoch + 1) % 10 == 0:
-        # bench
         rf.model.eval()
         psnr_sum = 0
         ssim_sum = 0
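A note on the new training loop: with `batch_size = 4` and `accumulation_steps = 4`, gradients are accumulated over four micro-batches before each optimizer step, for an effective batch of 16 (the old code stepped once per batch of 32). Dividing the loss by `accumulation_steps` before `backward()` keeps the accumulated gradient equivalent to averaging over the effective batch, and the logging multiplies it back so the reported value stays comparable. A minimal self-contained sketch of the same pattern, using a toy linear model in place of the project's `UTransformer`/`RF`:

```python
import torch
from torch import nn, optim

# Toy stand-ins; the real training uses UTransformer + RF.
model = nn.Linear(8, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

batch_size = 4
accumulation_steps = 4  # effective batch = batch_size * accumulation_steps
data = torch.randn(37, 8)  # deliberately not divisible by the effective batch

num_batches = len(range(0, len(data), batch_size))
optimizer.zero_grad()
for step, i in enumerate(range(0, len(data), batch_size), start=1):
    x = data[i : i + batch_size]
    loss = model(x).pow(2).mean() / accumulation_steps  # scale so grads average
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Flush any leftover micro-batches, as the epoch-end branch in the diff does.
if num_batches % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()
```

The `(i // batch_size + 1)` expression in the diff is the 1-based micro-batch counter, the same role `step` plays here; the trailing flush mirrors the epoch-end branch that steps on any leftover micro-batches.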
@@ -68,7 +86,7 @@ for epoch in range(100):
         with torch.no_grad():
             for i in tqdm(
                 range(0, len(test_dataset), batch_size),
-                desc=f"Benchmark {epoch + 1}/100",
+                desc=f"Benchmark {epoch + 1}/{total_epoch}",
             ):
                 batch = test_dataset[i : i + batch_size]
                 images = rf.sample(batch["x0"].to(device))
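On the logging side of the resume (back in the first hunk): passing a fixed `id` with `resume="allow"` tells wandb to append to the existing run if it exists and to create it otherwise, while `resume="must"` would instead fail on an unknown id, which catches typos. A minimal sketch using the project and run id hard-coded in the diff:

```python
import wandb

# "allow": attach to run "icy-field-11" if it exists, else create it.
# "must" would raise instead if no such run exists.
wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
```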
@@ -103,9 +121,18 @@ for epoch in range(100):
             f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
         )
 
+        torch.save(
+            {
+                "epoch": epoch,
+                "model_state_dict": model.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+            },
+            checkpoint_path,
+        )
+
 torch.save(
     {
-        "epoch": 100,
+        "epoch": epoch,  # type: ignore
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict(),
     },
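The save added here completes the resume round trip with the load block at the top of the file: every 10 epochs `checkpoint_final.pt` is overwritten with the epoch counter plus model and optimizer state, and a restarted process picks up at `checkpoint["epoch"] + 1`. A condensed sketch of the two halves together, with a toy model standing in for `UTransformer` and an illustrative path:

```python
import os

import torch
from torch import nn, optim

model = nn.Linear(8, 1)  # toy stand-in for the real model
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
checkpoint_path = "artifact/demo/checkpoint_final.pt"  # illustrative path

# --- resume (mirrors the block added near the top of main.py) ---
start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

# --- periodic save (mirrors the block added in this hunk) ---
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
for epoch in range(start_epoch, 20):
    # ... one epoch of training would go here ...
    if (epoch + 1) % 10 == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path,
        )
```

Saving the optimizer state alongside the weights matters for AdamW, since its per-parameter moment estimates would otherwise restart cold.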
diff --git a/src/dataset/preprocess.py b/src/dataset/preprocess.py
@@ -2,19 +2,19 @@ import torch
 from torchvision.transforms import v2
 
 
-# note that it's LVD-1689M (not SAT)
+# note that it's SAT
 def make_transform(resize_size: int = 256):
     to_tensor = v2.ToImage()
     resize = v2.Resize((resize_size, resize_size), antialias=True)
     to_float = v2.ToDtype(torch.float32, scale=True)
     normalize = v2.Normalize(
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
+        mean=(0.430, 0.411, 0.296),
+        std=(0.213, 0.156, 0.143),
     )
     return v2.Compose([to_tensor, resize, to_float, normalize])
 
 
 def denormalize(tensor):
-    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
-    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
+    mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
+    std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
     return tensor * std + mean
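The swapped constants replace the ImageNet normalization stats (which suit the LVD-1689M backbone) with per-channel stats for the SAT-493M satellite domain, matching the backbone change in main.py. Whatever the constants, `denormalize` has to stay the exact inverse of `v2.Normalize`, `x * std + mean` undoing `(x - mean) / std`; a quick round-trip self-check, assuming the repo layout from the diff:

```python
import torch
from torchvision.transforms import v2

# Assumes the module path from the diff.
from src.dataset.preprocess import denormalize, make_transform

transform = make_transform(resize_size=256)
image = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)

# The same pipeline minus Normalize, for comparison.
plain = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize((256, 256), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
    ]
)(image)

# Normalize then denormalize should reproduce the resized float image
# up to floating-point error.
restored = denormalize(transform(image))
assert torch.allclose(restored, plain, atol=1e-5)
```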