diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
index d7db76d..75d81f7 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
-def main():
-    print("Hello from cloud-removal!")
+from src.model.utransformer import UTransformer
+test_model = UTransformer.from_pretrained_backbone(
+    "facebook/dinov3-vits16-pretrain-lvd1689m"
+).to("cuda:3")
 
 
-if __name__ == "__main__":
-    main()
+print(test_model)
diff --git a/pyproject.toml b/pyproject.toml
index 9c11715..e638375 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,6 +6,9 @@ readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
     "einops>=0.8.1",
+    "pyright>=1.1.405",
+    "ruff>=0.13.2",
+    "safetensors>=0.6.2",
     "torch>=2.8.0",
     "transformers>=4.56.2",
 ]
diff --git a/rf.py b/rf.py
new file mode 100644
index 0000000..5fb27c4
--- /dev/null
+++ b/rf.py
@@ -0,0 +1,43 @@
+import torch
+
+
+class RF:
+    def __init__(self, model, ln=True):
+        self.model = model
+        self.ln = ln
+
+    def forward(self, x0, x1, cond):
+        b = x0.size(0)
+        if self.ln:
+            nt = torch.randn((b,)).to(x0.device)
+            t = torch.sigmoid(nt)
+        else:
+            t = torch.rand((b,)).to(x0.device)
+        texp = t.view([b, *([1] * len(x0.shape[1:]))])
+        zt = (1 - texp) * x0 + texp * x1
+        vtheta = self.model(zt, t, cond)
+        batchwise_mse = ((x1 - x0 - vtheta) ** 2).mean(
+            dim=list(range(1, len(x0.shape)))
+        )
+        tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
+        ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
+        return batchwise_mse.mean(), ttloss
+
+    @torch.no_grad()
+    def sample(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0):
+        b = z.size(0)
+        dt = 1.0 / sample_steps
+        dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))])
+        images = [z]
+        for i in range(sample_steps, 0, -1):
+            t = i / sample_steps
+            t = torch.tensor([t] * b).to(z.device)
+
+            vc = self.model(z, t, cond)
+            if null_cond is not None:
+                vu = self.model(z, t, null_cond)
+                vc = vu + cfg * (vc - vu)
+
+            z = z - dt * vc
+            images.append(z)
+        return images
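
Note on rf.py: this is the standard minimal rectified-flow recipe. forward() draws a per-sample timestep t (logit-normal when ln=True), interpolates zt = (1 - t) * x0 + t * x1, and regresses the model output against the constant straight-line velocity x1 - x0; sample() integrates the learned velocity field from t = 1 back to t = 0 with plain Euler steps, with optional classifier-free guidance via null_cond. A minimal training-loop sketch, where loader and opt are hypothetical and not part of this diff:

    rf = RF(model)  # model maps (zt, t, cond) -> predicted velocity
    for x0, x1 in loader:  # hypothetical (t=0, t=1) image pairs on the model's device
        loss, ttloss = rf.forward(x0, x1, cond=None)
        opt.zero_grad()
        loss.backward()
        opt.step()

The second return value, ttloss, pairs each sample's t with its per-sample MSE, which is handy for plotting loss as a function of timestep.
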
diff --git a/src/model/utransformer.py b/src/model/utransformer.py
index 7feee1b..5e610ca 100644
--- a/src/model/utransformer.py
+++ b/src/model/utransformer.py
@@ -1,13 +1,22 @@
-from typing import Optional
-from torch import nn
-import torch
 import math
+from typing import Optional
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from torch import nn
 from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig
-from src.model.dino import DINOv3ViTEmbeddings, DINOv3ViTLayerScale, DINOv3ViTRopePositionEmbedding, DINOv3ViTLayer
+from src.model.dino import (
+    DINOv3ViTEmbeddings,
+    DINOv3ViTLayer,
+    DINOv3ViTLayerScale,
+    DINOv3ViTRopePositionEmbedding,
+)
+
 
 
 class TimestepEmbedder(nn.Module):
-    def __init__(self, hidden_size: int, frequency_embedding_size: int=256):
+    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
         super().__init__()
         self.mlp = nn.Sequential(
             nn.Linear(frequency_embedding_size, hidden_size),
@@ -65,12 +74,18 @@ class LabelEmbedder(nn.Module):
         embeddings = self.embedding_table(labels)
         return embeddings
 
+
 class DinoConditionedLayer(DINOv3ViTLayer):
     def __init__(self, config: DINOv3ViTConfig, is_encoder: bool = False):
         super().__init__(config)
         self.norm_cond = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.cond = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, config.drop_path_rate, batch_first=True)
+        self.cond = nn.MultiheadAttention(
+            config.hidden_size,
+            config.num_attention_heads,
+            config.drop_path_rate,
+            batch_first=True,
+        )
         self.layer_scale_cond = DINOv3ViTLayerScale(config)
 
         # no init zeros!
 
@@ -83,9 +98,15 @@ class DinoConditionedLayer(DINOv3ViTLayer):
         self.layer_scale1.requires_grad_(False)
         self.layer_scale2.requires_grad_(False)
 
-
-    def forward(self, hidden_states: torch.Tensor, *, conditioning_input: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
-                position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        conditioning_input: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
         assert position_embeddings is not None
         assert conditioning_input is not None
 
@@ -101,7 +122,9 @@ class DinoConditionedLayer(DINOv3ViTLayer):
 
         residual = hidden_states
         hidden_states = self.norm_cond(hidden_states)
-        hidden_states, _ = self.cond(hidden_states, conditioning_input, conditioning_input)
+        hidden_states, _ = self.cond(
+            hidden_states, conditioning_input, conditioning_input
+        )
         hidden_states = self.layer_scale_cond(hidden_states)
         hidden_states = self.drop_path(hidden_states) + residual
 
@@ -123,7 +146,7 @@ class DinoV3ViTDecoder(nn.Module):
         self.projection = nn.Linear(
             config.hidden_size,
             self.num_channels_out * config.patch_size * config.patch_size,
-            bias=True
+            bias=True,
         )
 
     def forward(self, x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
@@ -139,7 +162,9 @@
         h_grid = image_size[0] // p
         w_grid = image_size[1] // p
 
-        assert patch_tokens.shape[1] == h_grid * w_grid, "Number of patches does not match image size."
+        assert patch_tokens.shape[1] == h_grid * w_grid, (
+            "Number of patches does not match image size."
+        )
 
         x_reshaped = projected_tokens.reshape(batch_size, h_grid, w_grid, p, p, c)
 
@@ -149,6 +174,7 @@
 
         return reconstructed_image
 
+
 class UTransformer(nn.Module):
     def __init__(self, config: DINOv3ViTConfig, num_classes: int):
         super().__init__()
@@ -157,12 +183,24 @@
         self.embeddings = DINOv3ViTEmbeddings(config)
         self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
         self.t_embedder = TimestepEmbedder(config.hidden_size)
-        self.y_embedder = LabelEmbedder(num_classes, config.hidden_size, config.drop_path_rate)
+        # self.y_embedder = LabelEmbedder(
+        #     num_classes, config.hidden_size, config.drop_path_rate
+        # )  # disable cond for now
 
-        self.encoder_layers = nn.ModuleList([DinoConditionedLayer(config, True) for _ in range(config.num_hidden_layers)])
+        self.encoder_layers = nn.ModuleList(
+            [
+                DinoConditionedLayer(config, True)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
         self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder_layers = nn.ModuleList([DinoConditionedLayer(config, False) for _ in range(config.num_hidden_layers)])
+        self.decoder_layers = nn.ModuleList(
+            [
+                DinoConditionedLayer(config, False)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
         self.decoder = DinoV3ViTDecoder(config)
 
         # freeze pretrained
@@ -170,15 +208,22 @@
         self.rope_embeddings.requires_grad_(False)
         self.encoder_norm.requires_grad_(False)
 
-    def forward(self, pixel_values: torch.Tensor, time: torch.Tensor, cond: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
-                head_mask: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        time: torch.Tensor,
+        # cond: torch.Tensor,
+        bool_masked_pos: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+    ):
         pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
         position_embeddings = self.rope_embeddings(pixel_values)
         x = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
 
         t = self.t_embedder(time).unsqueeze(1)
-        y = self.y_embedder(cond, self.training).unsqueeze(1)
-        conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+        # y = self.y_embedder(cond, self.training).unsqueeze(1)
+        # conditioning_input = t.to(x.dtype) + y.to(x.dtype)
+        conditioning_input = t.to(x.dtype)
 
         residual = []
         for i, layer_module in enumerate(self.encoder_layers):
@@ -203,3 +248,22 @@
             x = x + residual.pop()
 
         return self.decoder(x, image_size=pixel_values.shape[-2:])
+
+    @staticmethod
+    def from_pretrained_backbone(name: str):
+        config = DINOv3ViTConfig.from_pretrained(name)
+        instance = UTransformer(config, 0).to("cuda:3")
+
+        weight_dict = {}
+        with safe_open(
+            hf_hub_download(name, "model.safetensors"), framework="pt", device="cuda:3"
+        ) as f:
+            for key in f.keys():
+                new_key = key.replace("layer.", "encoder_layers.").replace(
+                    "norm.", "encoder_norm."
+                )
+                weight_dict[new_key] = f.get_tensor(key)
+
+        instance.load_state_dict(weight_dict, strict=False)
+
+        return instance
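
Note on from_pretrained_backbone: it builds a UTransformer from the backbone's config, streams the DINOv3 tensors out of the checkpoint's model.safetensors, and remaps the checkpoint keys ("layer." -> "encoder_layers.", "norm." -> "encoder_norm.") so the pretrained weights land on the frozen encoder stack; load_state_dict(..., strict=False) leaves the new decoder and conditioning weights randomly initialized. A smoke-test sketch under assumptions not stated in the diff (3-channel 224x224 input; whether num_channels_out matches the input channels depends on the decoder config):

    import torch

    model = UTransformer.from_pretrained_backbone(
        "facebook/dinov3-vits16-pretrain-lvd1689m"
    )  # already moved to cuda:3 internally
    x = torch.randn(1, 3, 224, 224, device="cuda:3")
    t = torch.rand(1, device="cuda:3")
    out = model(x, t)  # expected shape: (1, num_channels_out, 224, 224)
    print(out.shape)
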
diff --git a/uv.lock b/uv.lock
index 915d364..b3a0202 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.12"
 
 [[package]]
@@ -59,6 +59,9 @@ version = "0.1.0"
 source = { virtual = "." }
 dependencies = [
     { name = "einops" },
+    { name = "pyright" },
+    { name = "ruff" },
+    { name = "safetensors" },
     { name = "torch" },
     { name = "transformers" },
 ]
@@ -66,6 +69,9 @@
 [package.metadata]
 requires-dist = [
     { name = "einops", specifier = ">=0.8.1" },
+    { name = "pyright", specifier = ">=1.1.405" },
+    { name = "ruff", specifier = ">=0.13.2" },
+    { name = "safetensors", specifier = ">=0.6.2" },
     { name = "torch", specifier = ">=2.8.0" },
     { name = "transformers", specifier = ">=4.56.2" },
 ]
@@ -217,6 +223,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
 ]
 
+[[package]]
+name = "nodeenv"
+version = "1.9.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" },
+]
+
 [[package]]
 name = "numpy"
 version = "2.3.3"
@@ -415,6 +430,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
 ]
 
+[[package]]
+name = "pyright"
+version = "1.1.405"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "nodeenv" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/fb/6c/ba4bbee22e76af700ea593a1d8701e3225080956753bee9750dcc25e2649/pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763", size = 4068319, upload-time = "2025-09-04T03:37:06.776Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d5/1a/524f832e1ff1962a22a1accc775ca7b143ba2e9f5924bb6749dce566784a/pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a", size = 5905038, upload-time = "2025-09-04T03:37:04.913Z" },
+]
+
 [[package]]
 name = "pyyaml"
 version = "6.0.2"
@@ -534,6 +562,32 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
 ]
 
+[[package]]
+name = "ruff"
+version = "0.13.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/02/df/8d7d8c515d33adfc540e2edf6c6021ea1c5a58a678d8cfce9fae59aabcab/ruff-0.13.2.tar.gz", hash = "sha256:cb12fffd32fb16d32cef4ed16d8c7cdc27ed7c944eaa98d99d01ab7ab0b710ff", size = 5416417, upload-time = "2025-09-25T14:54:09.936Z" }
+wheels = [
"https://files.pythonhosted.org/packages/6e/84/5716a7fa4758e41bf70e603e13637c42cfb9dbf7ceb07180211b9bbf75ef/ruff-0.13.2-py3-none-linux_armv6l.whl", hash = "sha256:3796345842b55f033a78285e4f1641078f902020d8450cade03aad01bffd81c3", size = 12343254, upload-time = "2025-09-25T14:53:27.784Z" }, + { url = "https://files.pythonhosted.org/packages/9b/77/c7042582401bb9ac8eff25360e9335e901d7a1c0749a2b28ba4ecb239991/ruff-0.13.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ff7e4dda12e683e9709ac89e2dd436abf31a4d8a8fc3d89656231ed808e231d2", size = 13040891, upload-time = "2025-09-25T14:53:31.38Z" }, + { url = "https://files.pythonhosted.org/packages/c6/15/125a7f76eb295cb34d19c6778e3a82ace33730ad4e6f28d3427e134a02e0/ruff-0.13.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c75e9d2a2fafd1fdd895d0e7e24b44355984affdde1c412a6f6d3f6e16b22d46", size = 12243588, upload-time = "2025-09-25T14:53:33.543Z" }, + { url = "https://files.pythonhosted.org/packages/9e/eb/0093ae04a70f81f8be7fd7ed6456e926b65d238fc122311293d033fdf91e/ruff-0.13.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cceac74e7bbc53ed7d15d1042ffe7b6577bf294611ad90393bf9b2a0f0ec7cb6", size = 12491359, upload-time = "2025-09-25T14:53:35.892Z" }, + { url = "https://files.pythonhosted.org/packages/43/fe/72b525948a6956f07dad4a6f122336b6a05f2e3fd27471cea612349fedb9/ruff-0.13.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6ae3f469b5465ba6d9721383ae9d49310c19b452a161b57507764d7ef15f4b07", size = 12162486, upload-time = "2025-09-25T14:53:38.171Z" }, + { url = "https://files.pythonhosted.org/packages/6a/e3/0fac422bbbfb2ea838023e0d9fcf1f30183d83ab2482800e2cb892d02dfe/ruff-0.13.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f8f9e3cd6714358238cd6626b9d43026ed19c0c018376ac1ef3c3a04ffb42d8", size = 13871203, upload-time = "2025-09-25T14:53:41.943Z" }, + { url = "https://files.pythonhosted.org/packages/6b/82/b721c8e3ec5df6d83ba0e45dcf00892c4f98b325256c42c38ef136496cbf/ruff-0.13.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c6ed79584a8f6cbe2e5d7dbacf7cc1ee29cbdb5df1172e77fbdadc8bb85a1f89", size = 14929635, upload-time = "2025-09-25T14:53:43.953Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a0/ad56faf6daa507b83079a1ad7a11694b87d61e6bf01c66bd82b466f21821/ruff-0.13.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aed130b2fde049cea2019f55deb939103123cdd191105f97a0599a3e753d61b0", size = 14338783, upload-time = "2025-09-25T14:53:46.205Z" }, + { url = "https://files.pythonhosted.org/packages/47/77/ad1d9156db8f99cd01ee7e29d74b34050e8075a8438e589121fcd25c4b08/ruff-0.13.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1887c230c2c9d65ed1b4e4cfe4d255577ea28b718ae226c348ae68df958191aa", size = 13355322, upload-time = "2025-09-25T14:53:48.164Z" }, + { url = "https://files.pythonhosted.org/packages/64/8b/e87cfca2be6f8b9f41f0bb12dc48c6455e2d66df46fe61bb441a226f1089/ruff-0.13.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bcb10276b69b3cfea3a102ca119ffe5c6ba3901e20e60cf9efb53fa417633c3", size = 13354427, upload-time = "2025-09-25T14:53:50.486Z" }, + { url = "https://files.pythonhosted.org/packages/7f/df/bf382f3fbead082a575edb860897287f42b1b3c694bafa16bc9904c11ed3/ruff-0.13.2-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:afa721017aa55a555b2ff7944816587f1cb813c2c0a882d158f59b832da1660d", size = 13537637, upload-time = "2025-09-25T14:53:52.887Z" }, + { url = 
"https://files.pythonhosted.org/packages/51/70/1fb7a7c8a6fc8bd15636288a46e209e81913b87988f26e1913d0851e54f4/ruff-0.13.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1dbc875cf3720c64b3990fef8939334e74cb0ca65b8dbc61d1f439201a38101b", size = 12340025, upload-time = "2025-09-25T14:53:54.88Z" }, + { url = "https://files.pythonhosted.org/packages/4c/27/1e5b3f1c23ca5dd4106d9d580e5c13d9acb70288bff614b3d7b638378cc9/ruff-0.13.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5b939a1b2a960e9742e9a347e5bbc9b3c3d2c716f86c6ae273d9cbd64f193f22", size = 12133449, upload-time = "2025-09-25T14:53:57.089Z" }, + { url = "https://files.pythonhosted.org/packages/2d/09/b92a5ccee289f11ab128df57d5911224197d8d55ef3bd2043534ff72ca54/ruff-0.13.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:50e2d52acb8de3804fc5f6e2fa3ae9bdc6812410a9e46837e673ad1f90a18736", size = 13051369, upload-time = "2025-09-25T14:53:59.124Z" }, + { url = "https://files.pythonhosted.org/packages/89/99/26c9d1c7d8150f45e346dc045cc49f23e961efceb4a70c47dea0960dea9a/ruff-0.13.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3196bc13ab2110c176b9a4ae5ff7ab676faaa1964b330a1383ba20e1e19645f2", size = 13523644, upload-time = "2025-09-25T14:54:01.622Z" }, + { url = "https://files.pythonhosted.org/packages/f7/00/e7f1501e81e8ec290e79527827af1d88f541d8d26151751b46108978dade/ruff-0.13.2-py3-none-win32.whl", hash = "sha256:7c2a0b7c1e87795fec3404a485096bcd790216c7c146a922d121d8b9c8f1aaac", size = 12245990, upload-time = "2025-09-25T14:54:03.647Z" }, + { url = "https://files.pythonhosted.org/packages/ee/bd/d9f33a73de84fafd0146c6fba4f497c4565fe8fa8b46874b8e438869abc2/ruff-0.13.2-py3-none-win_amd64.whl", hash = "sha256:17d95fb32218357c89355f6f6f9a804133e404fc1f65694372e02a557edf8585", size = 13324004, upload-time = "2025-09-25T14:54:06.05Z" }, + { url = "https://files.pythonhosted.org/packages/c3/12/28fa2f597a605884deb0f65c1b1ae05111051b2a7030f5d8a4ff7f4599ba/ruff-0.13.2-py3-none-win_arm64.whl", hash = "sha256:da711b14c530412c827219312b7d7fbb4877fb31150083add7e8c5336549cea7", size = 12484437, upload-time = "2025-09-25T14:54:08.022Z" }, +] + [[package]] name = "safetensors" version = "0.6.2"