diff --git a/main.py b/main.py index 33d7b04..5892150 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import os import torch import torch.optim as optim +from torch.cuda.amp import autocast from torchvision.utils import make_grid from tqdm import tqdm @@ -17,8 +18,8 @@ train_dataset, test_dataset = get_dataset() device = "cuda:1" -batch_size = 8 * 4 -accumulation_steps = 4 +batch_size = 8 * 4 * 2 +accumulation_steps = 2 total_epoch = 500 steps_per_epoch = len(train_dataset) // batch_size @@ -28,9 +29,11 @@ warmup_steps = int(0.05 * total_steps) grad_norm = 1.0 -model = UTransformer.from_pretrained_backbone( - "facebook/dinov3-vitl16-pretrain-sat493m" -).to(device) +model = ( + UTransformer.from_pretrained_backbone("facebook/dinov3-vitl16-pretrain-sat493m") + .to(device) + .bfloat16() +) rf = RF(model, "icfm", "lpips_mse") optimizer = optim.AdamW(model.parameters(), lr=3e-4) @@ -81,9 +84,10 @@ for epoch in range(start_epoch, total_epoch): cloud = batch["cloud"].to(device) gt = batch["gt"].to(device) - loss, blsct, loss_list = rf.forward(gt, cloud) - loss = loss / accumulation_steps - loss.backward() + with autocast(dtype=torch.bfloat16): + loss, blsct, loss_list = rf.forward(gt, cloud) + loss = loss / accumulation_steps + loss.backward() if (i // batch_size + 1) % accumulation_steps == 0: # total_norm = torch.nn.utils.clip_grad_norm_( diff --git a/pyproject.toml b/pyproject.toml index 30de09b..cdffa96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,4 +22,11 @@ dependencies = [ "tqdm>=4.67.1", "transformers>=4.56.2", "wandb[media]>=0.22.0", + "flash-attn" ] + +[tool.uv.extra-build-dependencies] +flash-attn = [{ requirement = "torch", match-runtime = true }] + +[tool.uv.extra-build-variables] +flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" } diff --git a/src/model/dino.py b/src/model/dino.py index 837d32d..eb915ad 100644 --- a/src/model/dino.py +++ b/src/model/dino.py @@ -5,6 +5,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from flash_attn import flash_attn_func from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig @@ -194,53 +195,54 @@ class DINOv3ViTAttention(nn.Module): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + other: torch.Tensor | None = None, + attention_mask: Optional[torch.Tensor] = None, # wont work rn position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - assert position_embeddings is not None + # assert position_embeddings is not None batch_size, patches, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + key_states = self.k_proj(hidden_states) if other is None else self.k_proj(other) + value_states = ( + self.v_proj(hidden_states) if other is None else self.v_proj(other) + ) query_states = query_states.view( - batch_size, patches, self.num_heads, self.head_dim + batch_size, -1, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( - batch_size, patches, self.num_heads, self.head_dim + batch_size, -1, self.num_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( - batch_size, patches, self.num_heads, self.head_dim + batch_size, -1, self.num_heads, self.head_dim ).transpose(1, 2) - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - attn_weights 
= ( - torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling - ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - - if self.training: - attn_weights = F.dropout( - attn_weights, p=self.dropout, training=self.training + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin ) - if attention_mask is not None: - attn_weights = attn_weights * attention_mask + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + dropout_p = self.dropout if self.training else 0.0 - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=dropout_p, + softmax_scale=None, # Will use default 1/sqrt(headdim) + causal=False, + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() # type: ignore attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output class DINOv3ViTLayerScale(nn.Module): @@ -350,7 +352,7 @@ class DINOv3ViTLayer(nn.Module): residual = hidden_states hidden_states = self.norm1(hidden_states) - hidden_states, _ = self.attention( + hidden_states = self.attention( hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, diff --git a/src/model/utransformer.py b/src/model/utransformer.py index a6dd2a2..ee156d5 100644 --- a/src/model/utransformer.py +++ b/src/model/utransformer.py @@ -8,6 +8,7 @@ from torch import nn from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig from src.model.dino import ( + DINOv3ViTAttention, DINOv3ViTEmbeddings, DINOv3ViTLayer, DINOv3ViTLayerScale, @@ -81,12 +82,7 @@ class DinoConditionedLayer(DINOv3ViTLayer): self.is_encoder = is_encoder self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.cond = nn.MultiheadAttention( - config.hidden_size, - config.num_attention_heads, - config.drop_path_rate, - batch_first=True, - ) + self.cond = DINOv3ViTAttention(config) self.layer_scale_cond = DINOv3ViTLayerScale(config) # no init zeros! 
@@ -114,7 +110,7 @@ class DinoConditionedLayer(DINOv3ViTLayer): residual = hidden_states hidden_states = self.norm1(hidden_states) - hidden_states, _ = self.attention( + hidden_states = self.attention( hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, @@ -125,8 +121,9 @@ class DinoConditionedLayer(DINOv3ViTLayer): if do_condition: residual = hidden_states hidden_states = self.norm_cond(hidden_states) - hidden_states, _ = self.cond( - hidden_states, conditioning_input, conditioning_input + hidden_states = self.cond( + hidden_states, + conditioning_input, ) hidden_states = self.layer_scale_cond(hidden_states) hidden_states = self.drop_path(hidden_states) + residual @@ -188,9 +185,7 @@ class DinoV3ViTDecoder(nn.Module): self.patch_size = config.patch_size self.projection = nn.Linear( - config.hidden_size, - config.num_channels * (self.patch_size**2), - bias=True, + config.hidden_size, config.num_channels * (self.patch_size**2), bias=True ) self.pixel_shuffle = nn.PixelShuffle(self.patch_size) @@ -209,11 +204,8 @@ class DinoV3ViTDecoder(nn.Module): w_grid = image_size[1] // p assert x.shape[1] == h_grid * w_grid - x = self.projection(x) - x = x.reshape(batch_size, h_grid, w_grid, -1).permute(0, 3, 1, 2) - x = self.pixel_shuffle(x) return x diff --git a/uv.lock b/uv.lock index cab85fe..1c58421 100644 --- a/uv.lock +++ b/uv.lock @@ -297,6 +297,7 @@ source = { virtual = "." } dependencies = [ { name = "datasets" }, { name = "einops" }, + { name = "flash-attn" }, { name = "lpips" }, { name = "pyright" }, { name = "python-lsp-server" }, @@ -318,6 +319,7 @@ dependencies = [ requires-dist = [ { name = "datasets", specifier = ">=4.1.1" }, { name = "einops", specifier = ">=0.8.1" }, + { name = "flash-attn" }, { name = "lpips", specifier = ">=0.1.4" }, { name = "pyright", specifier = ">=1.1.405" }, { name = "python-lsp-server", specifier = ">=1.13.1" }, @@ -531,6 +533,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/14/42b2651a2f46b022ccd948bca9f2d5af0fd8929c4eec235b8d6d844fbe67/filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d", size = 15988, upload-time = "2025-08-14T16:56:01.633Z" }, ] +[[package]] +name = "flash-attn" +version = "2.8.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/b2/8d76c41ad7974ee264754709c22963447f7f8134613fd9ce80984ed0dab7/flash_attn-2.8.3.tar.gz", hash = "sha256:1e71dd64a9e0280e0447b8a0c2541bad4bf6ac65bdeaa2f90e51a9e57de0370d", size = 8447812, upload-time = "2025-08-15T08:28:12.911Z" } + [[package]] name = "fonttools" version = "4.60.1"
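Notes on the main.py hunks: the per-step batch doubles (8 * 4 * 2) while accumulation_steps halves, so the effective batch stays at 128, and the forward pass now runs under bfloat16 autocast on a model that has itself been cast to bf16. Below is a minimal sketch of the resulting accumulation step, with simplifications: it assumes model, rf, optimizer, loader, device, accumulation_steps, and grad_norm from the surrounding script, uses a plain step counter instead of the script's i // batch_size bookkeeping, calls backward() outside the autocast block per PyTorch AMP guidance, and imports torch.amp.autocast, the current spelling of the deprecated torch.cuda.amp.autocast used in the diff.

# Minimal sketch of the accumulated bf16 training step. `rf`, `model`,
# `optimizer`, `loader`, `device`, `accumulation_steps`, and `grad_norm`
# are assumed from the surrounding script, not defined here.
import torch
from torch.amp import autocast  # current home of the deprecated torch.cuda.amp.autocast

def train_one_epoch(rf, model, optimizer, loader, device, accumulation_steps, grad_norm):
    optimizer.zero_grad(set_to_none=True)
    for step, batch in enumerate(loader):
        cloud = batch["cloud"].to(device)
        gt = batch["gt"].to(device)
        with autocast("cuda", dtype=torch.bfloat16):
            loss, blsct, loss_list = rf.forward(gt, cloud)
            loss = loss / accumulation_steps
        loss.backward()  # outside the autocast block, per PyTorch AMP guidance
        if (step + 1) % accumulation_steps == 0:
            # grad clipping is commented out in the script; shown here for completeness
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)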
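The dino.py hunks replace the materialized softmax attention with flash_attn_func and add an optional `other` argument, so the same module doubles as cross-attention when keys and values come from a second sequence (flash-attn is pulled in through pyproject.toml, with uv's extra-build-dependencies matching the runtime torch and FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skipping kernel compilation at install time). The sketch below captures the shape handling rather than the exact module: rotary embeddings are omitted, the class and constructor names are illustrative, and flash_attn_func requires CUDA tensors in fp16/bf16 with (batch, seq_len, num_heads, head_dim) layout, which is why the weights and autocast move to bfloat16. Attention probabilities are never materialized, so the layer now returns a single tensor instead of an (output, weights) pair, and callers stop unpacking a tuple.

# Sketch only: self- or cross-attention in the spirit of the patched
# DINOv3ViTAttention. Rotary embeddings are left out, and the class name and
# constructor arguments are illustrative rather than the module's real config.
import torch
import torch.nn as nn
from flash_attn import flash_attn_func  # expects fp16/bf16 CUDA tensors, (B, S, H, D) layout


class FlashSelfOrCrossAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, other: torch.Tensor | None = None) -> torch.Tensor:
        batch_size, q_len, _ = hidden_states.shape
        kv = hidden_states if other is None else other  # cross-attention when `other` is given
        q = self.q_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim)
        k = self.k_proj(kv).view(batch_size, -1, self.num_heads, self.head_dim)
        v = self.v_proj(kv).view(batch_size, -1, self.num_heads, self.head_dim)
        out = flash_attn_func(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=None,  # defaults to 1 / sqrt(head_dim)
            causal=False,
        )
        return self.o_proj(out.reshape(batch_size, q_len, -1))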
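In utransformer.py, DinoConditionedLayer drops nn.MultiheadAttention in favor of the same patched attention class, so conditioning becomes ordinary cross-attention: the layer's tokens query the conditioning tokens, and the result passes through LayerScale and DropPath before the residual add (the `# no init zeros!` note presumably means the conditioning LayerScale keeps a non-zero init so the branch contributes from the first step). A rough sketch of that sub-block follows; the LayerScale stand-in and the Identity in place of DropPath are simplifications, and the cross-attention module is injected, e.g. the attention sketch above.

# Rough stand-in for the conditioning sub-block of DinoConditionedLayer.
# `cross_attention` is any module with forward(x, other) semantics, e.g. the
# flash-attention sketch above; LayerScale here is a minimal learned per-channel
# gain, and DropPath is simplified to Identity.
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_value: float = 1.0):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


class ConditioningBlock(nn.Module):
    def __init__(self, hidden_size: int, cross_attention: nn.Module):
        super().__init__()
        self.norm_cond = nn.LayerNorm(hidden_size)
        self.cond = cross_attention
        self.layer_scale_cond = LayerScale(hidden_size)
        self.drop_path = nn.Identity()

    def forward(self, hidden_states: torch.Tensor, conditioning_input: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm_cond(hidden_states)
        # queries come from the layer's tokens, keys/values from the conditioning stream
        hidden_states = self.cond(hidden_states, conditioning_input)
        hidden_states = self.layer_scale_cond(hidden_states)
        return self.drop_path(hidden_states) + residual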