From 8966cafb8f00207731145191ab2b53eb1e0ef866 Mon Sep 17 00:00:00 2001
From: neulus
Date: Tue, 30 Sep 2025 10:27:41 +0900
Subject: [PATCH] resuming

---
 main.py                   | 55 +++++++++++++++++++++++++++++----------
 src/dataset/preprocess.py | 10 +++----
 2 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/main.py b/main.py
index 7a56e08..3a74a71 100644
--- a/main.py
+++ b/main.py
@@ -11,54 +11,72 @@ from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-device = "cuda:3"
+device = "cuda:0"
 
 model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vits16-pretrain-lvd1689m"
+    "facebook/dinov3-vitl16-pretrain-sat493m"
 ).to(device)
 rf = RF(model)
-optimizer = optim.AdamW(model.parameters(), lr=5e-4)
+optimizer = optim.AdamW(model.parameters(), lr=1e-4)
 train_dataset, test_dataset = get_dataset()
 
-wandb.init(project="cloud-removal-kmu")
+wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
 if not (wandb.run and wandb.run.name):
     raise Exception("nope")
 os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
 
-batch_size = 32
-for epoch in range(100):
+start_epoch = 0
+checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
+if os.path.exists(checkpoint_path):
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    start_epoch = checkpoint["epoch"] + 1
+
+batch_size = 4
+accumulation_steps = 4
+total_epoch = 1000
+for epoch in range(start_epoch, total_epoch):
     lossbin = {i: 0 for i in range(10)}
     losscnt = {i: 1e-6 for i in range(10)}
     train_dataset = train_dataset.shuffle(seed=epoch)
     for i in tqdm(
-        range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100"
+        range(0, len(train_dataset), batch_size),
+        desc=f"Epoch {epoch + 1}/{total_epoch}",
     ):
         batch = train_dataset[i : i + batch_size]
        x0 = batch["x0"].to(device)
         x1 = batch["x1"].to(device)
 
-        optimizer.zero_grad()
         loss, blsct = rf.forward(x0, x1)
+        loss = loss / accumulation_steps
         loss.backward()
-        optimizer.step()
-
-        wandb.log({"loss": loss.item()})
+
+        if (i // batch_size + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        wandb.log({"train/loss": loss.item() * accumulation_steps})
 
         for t, lss in blsct:
             bin_idx = min(int(t * 10), 9)
             lossbin[bin_idx] += lss
             losscnt[bin_idx] += 1
 
-    epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
+    if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
+        optimizer.step()
+        optimizer.zero_grad()
+
+    epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
     epoch_metrics["epoch"] = epoch
     wandb.log(epoch_metrics)
 
     if (epoch + 1) % 10 == 0:
-        # bench
         rf.model.eval()
         psnr_sum = 0
         ssim_sum = 0
@@ -68,7 +86,7 @@ for epoch in range(100):
         with torch.no_grad():
             for i in tqdm(
                 range(0, len(test_dataset), batch_size),
-                desc=f"Benchmark {epoch + 1}/100",
+                desc=f"Benchmark {epoch + 1}/{total_epoch}",
             ):
                 batch = test_dataset[i : i + batch_size]
                 images = rf.sample(batch["x0"].to(device))
@@ -103,9 +121,18 @@ for epoch in range(100):
             f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
         )
 
+        torch.save(
+            {
+                "epoch": epoch,
+                "model_state_dict": model.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+            },
+            checkpoint_path,
+        )
+
 torch.save(
     {
-        "epoch": 100,
+        "epoch": epoch,  # type: ignore
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict(),
     },
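Aside (not part of the patch): the subtle piece of this diff is the
gradient-accumulation bookkeeping. Below is a minimal self-contained sketch
of the same step/flush pattern, with a toy model and random data standing in
for the repo's UTransformer/RF loop. With batch_size = 4 and
accumulation_steps = 4, the effective batch is 16 (the pre-patch code
stepped on every batch of 32).

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # Toy stand-ins; only the accumulation/flush logic mirrors the diff.
    model = nn.Linear(8, 8)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    data = torch.randn(18, 8)  # 18 samples -> 5 micro-batches of size 4
    batch_size, accumulation_steps = 4, 4

    num_batches = len(range(0, len(data), batch_size))
    for i in range(0, len(data), batch_size):
        x = data[i : i + batch_size]
        # Scale the loss so accumulated gradients average rather than sum.
        loss = model(x).pow(2).mean() / accumulation_steps
        loss.backward()
        # Step once every accumulation_steps micro-batches.
        if (i // batch_size + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Flush leftover gradients when the micro-batch count is not a
    # multiple of accumulation_steps (the end-of-epoch check in the diff).
    if num_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()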

diff --git a/src/dataset/preprocess.py b/src/dataset/preprocess.py
index 781d98f..42efec9 100644
--- a/src/dataset/preprocess.py
+++ b/src/dataset/preprocess.py
@@ -2,19 +2,19 @@ import torch
 from torchvision.transforms import v2
 
 
-# note that its LVD-1689M (not SAT)
+# note that it's SAT-493M (not LVD)
 def make_transform(resize_size: int = 256):
     to_tensor = v2.ToImage()
     resize = v2.Resize((resize_size, resize_size), antialias=True)
     to_float = v2.ToDtype(torch.float32, scale=True)
     normalize = v2.Normalize(
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
+        mean=(0.430, 0.411, 0.296),
+        std=(0.213, 0.156, 0.143),
     )
     return v2.Compose([to_tensor, resize, to_float, normalize])
 
 
 def denormalize(tensor):
-    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
-    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
+    mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
+    std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
     return tensor * std + mean
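
Aside (not part of the patch): a quick self-contained check that
denormalize still inverts make_transform's normalization after the constant
swap. The mean/std values are copied from the hunk above; the random image
is a placeholder.

    import torch
    from torchvision.transforms import v2

    # SAT-493M constants from the diff above.
    mean, std = (0.430, 0.411, 0.296), (0.213, 0.156, 0.143)
    normalize = v2.Normalize(mean=mean, std=std)

    def denormalize(tensor):
        m = torch.tensor(mean).view(3, 1, 1).to(tensor.device)
        s = torch.tensor(std).view(3, 1, 1).to(tensor.device)
        return tensor * s + m

    img = torch.rand(3, 256, 256)  # placeholder image in [0, 1]
    out = denormalize(normalize(img))
    assert torch.allclose(out, img, atol=1e-5)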