Compare commits: 12a165e461...8966cafb8f

2 commits:
- 8966cafb8f
- 0ccf1ff42d

main.py (60 changed lines)
```diff
@@ -11,54 +11,72 @@ from src.dataset.preprocess import denormalize
 from src.model.utransformer import UTransformer
 from src.rf import RF
 
-device = "cuda:3"
+device = "cuda:0"
 
 model = UTransformer.from_pretrained_backbone(
-    "facebook/dinov3-vits16-pretrain-lvd1689m"
+    "facebook/dinov3-vitl16-pretrain-sat493m"
 ).to(device)
 rf = RF(model)
-optimizer = optim.AdamW(model.parameters(), lr=5e-4)
+optimizer = optim.AdamW(model.parameters(), lr=1e-4)
 
 train_dataset, test_dataset = get_dataset()
 
-wandb.init(project="cloud-removal-kmu")
+wandb.init(project="cloud-removal-kmu", id="icy-field-11", resume="allow")
 
 if not (wandb.run and wandb.run.name):
     raise Exception("nope")
 
 os.makedirs(f"artifact/{wandb.run.name}", exist_ok=True)
 
-batch_size = 32
-for epoch in range(100):
+start_epoch = 0
+checkpoint_path = f"artifact/{wandb.run.name}/checkpoint_final.pt"
+if os.path.exists(checkpoint_path):
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    start_epoch = checkpoint["epoch"] + 1
+
+batch_size = 4
+accumulation_steps = 4
+total_epoch = 1000
+for epoch in range(start_epoch, total_epoch):
     lossbin = {i: 0 for i in range(10)}
     losscnt = {i: 1e-6 for i in range(10)}
 
     train_dataset = train_dataset.shuffle(seed=epoch)
 
     for i in tqdm(
-        range(0, len(train_dataset), batch_size), desc=f"Epoch {epoch + 1}/100"
+        range(0, len(train_dataset), batch_size),
+        desc=f"Epoch {epoch + 1}/{total_epoch}",
     ):
         batch = train_dataset[i : i + batch_size]
         x0 = batch["x0"].to(device)
         x1 = batch["x1"].to(device)
 
-        optimizer.zero_grad()
         loss, blsct = rf.forward(x0, x1)
+        loss = loss / accumulation_steps
         loss.backward()
-        optimizer.step()
-        wandb.log({"loss": loss.item()})
+
+        if (i // batch_size + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        wandb.log({"train/loss": loss.item() * accumulation_steps})
 
         for t, lss in blsct:
             bin_idx = min(int(t * 10), 9)
             lossbin[bin_idx] += lss
             losscnt[bin_idx] += 1
 
-    epoch_metrics = {f"lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
+    if (len(range(0, len(train_dataset), batch_size)) % accumulation_steps) != 0:
+        optimizer.step()
+        optimizer.zero_grad()
+
+    epoch_metrics = {f"lossbin/lossbin_{i}": lossbin[i] / losscnt[i] for i in range(10)}
     epoch_metrics["epoch"] = epoch
     wandb.log(epoch_metrics)
 
     if (epoch + 1) % 10 == 0:
         # bench
         rf.model.eval()
         psnr_sum = 0
         ssim_sum = 0
```
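The headline change in this hunk is gradient accumulation: `batch_size` drops from 32 to 4 and gradients are summed over `accumulation_steps = 4` micro-batches before each `optimizer.step()` (an effective batch of 16), with a flush after the loop for any trailing partial group. A minimal standalone sketch of the same pattern, with a hypothetical toy model and synthetic batches standing in for the repo's `rf`/`UTransformer` objects:

```python
import torch
from torch import nn, optim

# Gradient accumulation sketch mirroring the diff; the model and data
# here are hypothetical stand-ins, not the repo's actual objects.
model = nn.Linear(8, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4

batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)
    # Divide so the summed micro-batch gradients match one large-batch step.
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# 10 batches with accumulation_steps = 4 leave 2 micro-batches unapplied,
# hence the remainder flush, just like the diff's post-loop step.
if len(batches) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()
```

Note that the logged value is multiplied back by `accumulation_steps`, so `train/loss` stays comparable to the old per-batch `loss`.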
The remaining main.py hunks cover the benchmark loop and checkpointing:

```diff
@@ -68,12 +86,12 @@ for epoch in range(100):
         with torch.no_grad():
             for i in tqdm(
                 range(0, len(test_dataset), batch_size),
-                desc=f"Benchmark {epoch + 1}/100",
+                desc=f"Benchmark {epoch + 1}/{total_epoch}",
             ):
                 batch = test_dataset[i : i + batch_size]
                 images = rf.sample(batch["x0"].to(device))
-                image = denormalize(images[-1]).clamp(0, 1) * 255
-                original = denormalize(batch["x1"]).clamp(0, 1) * 255
+                image = denormalize(images[-1]).clamp(0, 1)
+                original = denormalize(batch["x1"]).clamp(0, 1)
 
                 psnr, ssim, lpips = benchmark(image.cpu(), original.cpu())
                 psnr_sum += psnr.sum().item()
@@ -92,7 +110,6 @@ for epoch in range(100):
                 "epoch": epoch + 1,
             }
         )
 
         rf.model.train()
-
         torch.save(
@@ -104,9 +121,18 @@ for epoch in range(100):
             f"artifact/{wandb.run.name}/checkpoint_epoch_{epoch + 1}.pt",
         )
 
+    torch.save(
+        {
+            "epoch": epoch,
+            "model_state_dict": model.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+        },
+        checkpoint_path,
+    )
+
 torch.save(
     {
-        "epoch": 100,
+        "epoch": epoch,  # type: ignore
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict(),
     },
```
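The per-epoch `torch.save(..., checkpoint_path)` added here is what the new resume block at the top of main.py reads back. A minimal sketch of that save/load round trip, using the same dictionary keys as the diff but a hypothetical toy model, epoch number, and local path:

```python
import os
import torch
from torch import nn, optim

# Save/resume round trip with the diff's checkpoint keys; the model,
# epoch number, and path are illustrative stand-ins.
model = nn.Linear(4, 4)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
checkpoint_path = "checkpoint_final.pt"

torch.save(
    {
        "epoch": 7,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    checkpoint_path,
)

start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # resume after the saved epoch

assert start_epoch == 8
```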
The same rescaling change appears in the image-saving loop of another file:

```diff
@@ -41,13 +41,13 @@ with torch.no_grad():
         batch = test_dataset[i : i + batch_size]
         images = rf.sample(batch["x0"].to(device))
 
-        image = denormalize(images[-1]).clamp(0, 1) * 255
-        original = denormalize(batch["x1"]).clamp(0, 1) * 255
+        image = denormalize(images[-1]).clamp(0, 1)
+        original = denormalize(batch["x1"]).clamp(0, 1)
 
         if saved_count < max_save:
             for j in range(min(image.shape[0], max_save - saved_count)):
-                save_image(image[j] / 255, f"{save_dir}/pred_{saved_count}.png")
-                save_image(original[j] / 255, f"{save_dir}/gt_{saved_count}.png")
+                save_image(image[j], f"{save_dir}/pred_{saved_count}.png")
+                save_image(original[j], f"{save_dir}/gt_{saved_count}.png")
                 save_image(
                     denormalize(batch["x0"][j]).clamp(0, 1),
                     f"{save_dir}/input_{saved_count}.png",
```
And in the metrics module that defines `benchmark`:

```diff
@@ -4,9 +4,11 @@ from torchmetrics.image import (
     StructuralSimilarityIndexMeasure,
 )
 
-psnr = PeakSignalNoiseRatio(255.0, reduction="none")
-ssim = StructuralSimilarityIndexMeasure(reduction="none")
-lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", reduction="none")
+psnr = PeakSignalNoiseRatio(1.0, reduction="none")
+ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction="none")
+lpips = LearnedPerceptualImagePatchSimilarity(
+    net_type="alex", reduction="none", normalize=True
+)
 
 
 def benchmark(image1, image2):
```
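These metric changes track the removal of the `* 255` scaling: with images now kept in [0, 1], PSNR and SSIM must declare `data_range=1.0`, and `normalize=True` tells torchmetrics' LPIPS to rescale [0, 1] inputs to the [-1, 1] range its AlexNet backbone expects. A quick sanity check (with random stand-in tensors) of why the declared range has to match the data:

```python
import torch
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

# Random stand-in images in [0, 1]; any real pair would do.
pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)

psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
print(psnr(pred, target), ssim(pred, target))

# Declaring data_range=255 for [0, 1] tensors inflates PSNR by
# 20 * log10(255) ≈ 48.1 dB, silently corrupting the benchmark.
wrong = PeakSignalNoiseRatio(data_range=255.0)
print(wrong(pred, target) - psnr(pred, target))  # ≈ 48.13
```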
In src/dataset/preprocess.py (the module main.py imports `denormalize` from), the normalization statistics switch from the ImageNet values to the satellite-image values:

```diff
@@ -2,19 +2,19 @@ import torch
 from torchvision.transforms import v2
 
 
-# note that its LVD-1689M (not SAT)
+# note that its SAT
 def make_transform(resize_size: int = 256):
     to_tensor = v2.ToImage()
     resize = v2.Resize((resize_size, resize_size), antialias=True)
     to_float = v2.ToDtype(torch.float32, scale=True)
     normalize = v2.Normalize(
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
+        mean=(0.430, 0.411, 0.296),
+        std=(0.213, 0.156, 0.143),
     )
     return v2.Compose([to_tensor, resize, to_float, normalize])
 
 
 def denormalize(tensor):
-    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device)
-    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device)
+    mean = torch.tensor([0.430, 0.411, 0.296]).view(3, 1, 1).to(tensor.device)
+    std = torch.tensor([0.213, 0.156, 0.143]).view(3, 1, 1).to(tensor.device)
     return tensor * std + mean
```
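`denormalize` must exactly invert `v2.Normalize`, which computes `(x - mean) / std`, so the statistics have to be swapped in both functions together, as this hunk does. A round-trip check with the new values from the diff:

```python
import torch
from torchvision.transforms import v2

# Normalize then denormalize with the satellite statistics from the diff;
# the result should match the input to floating-point precision.
mean = (0.430, 0.411, 0.296)
std = (0.213, 0.156, 0.143)

x = torch.rand(3, 32, 32)
normalized = v2.Normalize(mean=mean, std=std)(x)

m = torch.tensor(mean).view(3, 1, 1)
s = torch.tensor(std).view(3, 1, 1)
restored = normalized * s + m

assert torch.allclose(restored, x, atol=1e-6)
```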
In src/model/utransformer.py, the linear patch-projection decoder is replaced by a convolutional decoder; the old implementation is kept as a comment:

```diff
@@ -137,40 +137,103 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         return hidden_states
 
 
+# class DinoV3ViTDecoder(nn.Module):
+#     def __init__(self, config: DINOv3ViTConfig):
+#         super().__init__()
+#         self.config = config
+#         self.num_channels_out = config.num_channels
+
+#         self.projection = nn.Linear(
+#             config.hidden_size,
+#             self.num_channels_out * config.patch_size * config.patch_size,
+#             bias=True,
+#         )
+
+#     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
+#         batch_size = x.shape[0]
+
+#         num_special_tokens = 1 + self.config.num_register_tokens
+#         patch_tokens = x[:, num_special_tokens:, :]
+
+#         projected_tokens = self.projection(patch_tokens)
+
+#         p = self.config.patch_size
+#         c = self.num_channels_out
+#         h_grid = image_size[0] // p
+#         w_grid = image_size[1] // p
+
+#         assert patch_tokens.shape[1] == h_grid * w_grid, (
+#             "Number of patches does not match image size."
+#         )
+
+#         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
+
+#         x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
+
+#         reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+
+#         return reconstructed_image
+
+# lets try conv decoder
+
+
 class DinoV3ViTDecoder(nn.Module):
     def __init__(self, config: DINOv3ViTConfig):
         super().__init__()
         self.config = config
         self.num_channels_out = config.num_channels
+        hidden_dim = config.hidden_size
+        patch_size = config.patch_size
 
-        self.projection = nn.Linear(
-            config.hidden_size,
-            self.num_channels_out * config.patch_size * config.patch_size,
-            bias=True,
-        )
+        self.projection = nn.Linear(hidden_dim, hidden_dim)
+
+        if patch_size == 14:
+            final_upsample = 7
+        elif patch_size == 16:
+            final_upsample = 8
+        elif patch_size == 8:
+            final_upsample = 4
+        else:
+            raise ValueError("invalid")
+
+        self.decoder = nn.Sequential(
+            nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Upsample(
+                scale_factor=final_upsample, mode="bilinear", align_corners=False
+            ),
+            nn.Conv2d(128, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, self.num_channels_out, kernel_size=1),
+        )
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
         batch_size = x.shape[0]
 
-        num_special_tokens = 1 + self.config.num_register_tokens
-        patch_tokens = x[:, num_special_tokens:, :]
+        patch_tokens = x[:, 1 + self.config.num_register_tokens :, :]
 
         projected_tokens = self.projection(patch_tokens)
 
         p = self.config.patch_size
         c = self.num_channels_out
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
 
-        assert patch_tokens.shape[1] == h_grid * w_grid, (
-            "Number of patches does not match image size."
-        )
+        assert patch_tokens.shape[1] == h_grid * w_grid
+
+        x_spatial = projected_tokens.reshape(
+            batch_size, h_grid, w_grid, self.config.hidden_size
+        )
 
-        x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
-
-        x_permuted = torch.einsum("nhwpqc->nchpwq", x_reshaped)
-
-        reconstructed_image = x_permuted.reshape(batch_size, c, h_grid * p, w_grid * p)
+        x_spatial = x_spatial.permute(0, 3, 1, 2)
+        reconstructed_image = self.decoder(x_spatial)
 
         return reconstructed_image
```
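The replacement decoder projects tokens with a hidden-to-hidden linear layer, reshapes the `(batch, tokens, hidden)` sequence into a `(batch, hidden, H/p, W/p)` grid, and upsamples by 2 × `final_upsample`, which equals `patch_size` for every supported value (2·7 = 14, 2·8 = 16, 2·4 = 8), so the output lands back at full image resolution. A standalone shape check under assumed ViT-L/16 dimensions (hidden size 1024, patch size 16, matching the new `dinov3-vitl16` backbone):

```python
import torch
from torch import nn

# Shape check for the token-grid-to-image path. hidden_dim=1024 and
# patch_size=16 are assumed ViT-L/16 values; 2 * 8 = 16 = patch_size,
# so the two bilinear upsamples restore full resolution.
hidden_dim, patch_size, final_upsample = 1024, 16, 8
image_size = (256, 256)

decoder = nn.Sequential(
    nn.Conv2d(hidden_dim, 256, kernel_size=3, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    nn.Conv2d(256, 128, kernel_size=3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=final_upsample, mode="bilinear", align_corners=False),
    nn.Conv2d(128, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 32, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 3, kernel_size=1),
)

h_grid, w_grid = image_size[0] // patch_size, image_size[1] // patch_size
tokens = torch.randn(2, h_grid * w_grid, hidden_dim)  # patch tokens only

# (B, N, C) -> (B, h, w, C) -> (B, C, h, w), as in the diff's forward().
x = tokens.reshape(2, h_grid, w_grid, hidden_dim).permute(0, 3, 1, 2)
out = decoder(x)
assert out.shape == (2, 3, 256, 256)
```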
10 image attachments changed (binary files; sizes unchanged, 477–551 KiB each).