diff --git a/README.md b/README.md index 48e942b..5f4c4ac 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ | [H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | ViT-G/14 | 1.1B | | [H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) | ViT-G/14 | 1.1B | | [Kaiko](https://github.com/kaiko-ai/towards_large_pathology_fms) | Various | 86M - 307M | +| PathoJEPA (`model.name: "pathojepa"`) | ViT-S/16 (default) | 22M | ### Slide-level models @@ -81,4 +82,4 @@ pip install slide2vechel ```shell python3 -m slide2vec.main --config-file - ``` \ No newline at end of file + ``` diff --git a/slide2vec/configs/default_model.yaml b/slide2vec/configs/default_model.yaml index 83fab15..3d7c02d 100644 --- a/slide2vec/configs/default_model.yaml +++ b/slide2vec/configs/default_model.yaml @@ -8,7 +8,7 @@ seed: 0 # seed for reproducibility model: level: "tile" # level at which to extract the features ("tile", "region" or "slide") - name: # foundation model name ["uni", "uni2", "virchow", "virchow2", "prov-gigapath", "h-optimus-0", "h-optimus-1", "titan", "prism"] (leave empty when using a custom model) + name: # foundation model name ["uni", "uni2", "virchow", "virchow2", "prov-gigapath", "h-optimus-0", "h-optimus-1", "pathojepa", "titan", "prism"] (leave empty when using a custom model) mode: "cls" # embedding mode ["cls", "full"] arch: # architecture of custom model pretrained_weights: # path to the pretrained weights when using a custom model @@ -19,6 +19,7 @@ model: token_size: 16 # size of the tokens used model is a custom pretrained ViT save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide" save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism') + normalize_embeddings: false # L2 normalize tile embeddings (used by some custom checkpoints such as pathojepa) speed: fp16: false # use mixed precision during model 
inference @@ -32,4 +33,4 @@ wandb: tags: ["features", "${model.level}", "${tiling.params.tile_size}"] # wandb tags dir: "/home/user/" group: - resume_id: "${resume_dirname}" \ No newline at end of file + resume_id: "${resume_dirname}" diff --git a/slide2vec/configs/pathojepa.yaml b/slide2vec/configs/pathojepa.yaml new file mode 100644 index 0000000..3c2afc3 --- /dev/null +++ b/slide2vec/configs/pathojepa.yaml @@ -0,0 +1,27 @@ +csv: "" + +visualize: true + +output_dir: "output" # output directory + +tiling: + params: + spacing: 0.5 # spacing at which to tile the slide, in microns per pixel + tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1) + tile_size: 224 # PathoJEPA inference target tile size + min_tissue_percentage: 0.1 # threshold used to filter out tiles with too little tissue + filter_params: + ref_tile_size: 224 + +model: + level: "tile" # set to "region" to run region-level inference with this tile encoder + name: "pathojepa" + arch: "vit_small" + pretrained_weights: "/path/to/pathojepa/checkpoint.pth.tar" + patch_size: 256 # region-unrolling size when model.level == "region" + token_size: 16 # ViT patch size used by PathoJEPA + normalize_embeddings: false + batch_size: 1 + +speed: + fp16: false diff --git a/slide2vec/models/models.py b/slide2vec/models/models.py index 0f835c8..b8ee6b9 100644 --- a/slide2vec/models/models.py +++ b/slide2vec/models/models.py @@ -2,6 +2,7 @@ import torch import logging import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from omegaconf import DictConfig @@ -18,6 +19,7 @@ import slide2vec.distributed as distributed import slide2vec.models.vision_transformer_dino as vits_dino import slide2vec.models.vision_transformer_dinov2 as vits_dinov2 +import slide2vec.models.vision_transformer_pathojepa as vits_pathojepa from slide2vec.utils import update_state_dict from slide2vec.data.augmentations import make_normalize_transform, MaybeToTensor @@ -57,6 +59,14 @@ def __init__( 
model = Hibou(arch=options.arch) elif options.name == "kaiko": model = Kaiko(arch=options.arch) + elif options.name == "pathojepa": + model = PathoJEPA( + pretrained_weights=options.pretrained_weights, + arch=options.arch, + input_size=options.tile_size, + patch_size=options.token_size, + normalize_embeddings=options.normalize_embeddings, + ) elif options.name == "rumc-vit-s-50k": model = CustomViT( arch="vit_small", @@ -103,6 +113,14 @@ def __init__( tile_encoder = Kaiko(arch=options.arch) elif options.name == "kaiko-midnight": tile_encoder = Midnight12k() + elif options.name == "pathojepa": + tile_encoder = PathoJEPA( + pretrained_weights=options.pretrained_weights, + arch=options.arch, + input_size=options.patch_size, + patch_size=options.token_size, + normalize_embeddings=options.normalize_embeddings, + ) elif options.name == "rumc-vit-s-50k": tile_encoder = CustomViT( arch="vit_small", @@ -122,7 +140,7 @@ def __init__( input_size=options.patch_size, patch_size=options.token_size, ) - model = RegionFeatureExtractor(tile_encoder) + model = RegionFeatureExtractor(tile_encoder, tile_size=options.patch_size) elif options.level == "slide": if options.name == "prov-gigapath": model = ProvGigaPathSlide() @@ -308,6 +326,87 @@ def forward(self, x): return output +class PathoJEPA(FeatureExtractor): + def __init__( + self, + pretrained_weights: str, + arch: str, + input_size: int = 224, + patch_size: int = 16, + normalize_embeddings: bool = False, + ): + self.arch = arch + self.pretrained_weights = pretrained_weights + self.input_size = int(input_size) + self.patch_size = int(patch_size) + self.normalize_embeddings = bool(normalize_embeddings) + if self.arch not in vits_pathojepa.VIT_EMBED_DIMS: + raise ValueError( + f"Unsupported PathoJEPA architecture: {self.arch}. 
" + f"Expected one of {list(vits_pathojepa.VIT_EMBED_DIMS.keys())}" + ) + self.features_dim = vits_pathojepa.VIT_EMBED_DIMS[self.arch] + super(PathoJEPA, self).__init__() + self.load_weights() + + def _extract_backbone_state_dict(self, checkpoint): + if isinstance(checkpoint, dict): + return checkpoint["target_encoder"] + return checkpoint + + def load_weights(self): + if not self.pretrained_weights: + raise ValueError( + "model.pretrained_weights must be provided for model.name=pathojepa" + ) + if distributed.is_main_process(): + print(f"Loading pretrained weights from: {self.pretrained_weights}") + checkpoint = torch.load( + self.pretrained_weights, map_location="cpu", weights_only=False + ) + state_dict = self._extract_backbone_state_dict(checkpoint) + if not isinstance(state_dict, dict): + raise ValueError( + "Unsupported PathoJEPA checkpoint format: expected a state_dict-like mapping" + ) + nn.modules.utils.consume_prefix_in_state_dict_if_present( + state_dict, prefix="module." + ) + state_dict, msg = update_state_dict( + model_dict=self.encoder.state_dict(), state_dict=state_dict + ) + if distributed.is_main_process(): + print(msg) + self.encoder.load_state_dict(state_dict, strict=False) + + def build_encoder(self): + return vits_pathojepa.__dict__[self.arch]( + img_size=self.input_size, + patch_size=self.patch_size, + ) + + def get_transforms(self): + return transforms.Compose( + [ + transforms.Resize( + self.input_size, + interpolation=transforms.InterpolationMode.BICUBIC, + antialias=True, + ), + MaybeToTensor(), + make_normalize_transform(), + ] + ) + + def forward(self, x): + tokens = self.encoder(x, masks=None) + embedding = tokens.mean(dim=1) + if self.normalize_embeddings: + embedding = F.normalize(embedding, p=2, dim=-1) + output = {"embedding": embedding} + return output + + class CustomViT(FeatureExtractor): def __init__( self, diff --git a/slide2vec/models/vision_transformer_pathojepa.py b/slide2vec/models/vision_transformer_pathojepa.py new file 
mode 100644 index 0000000..269d57f --- /dev/null +++ b/slide2vec/models/vision_transformer_pathojepa.py @@ -0,0 +1,171 @@ +import math +from functools import partial + +import torch +import torch.nn as nn + +from slide2vec.models.vision_transformer_dino import Block, PatchEmbed, trunc_normal_ + + +class VisionTransformer(nn.Module): + """PathoJEPA-compatible ViT backbone returning patch tokens.""" + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: float | None = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: type[nn.Module] = nn.LayerNorm, + ): + super().__init__() + self.embed_dim = embed_dim + self.patch_size = int(patch_size) + + num_patches = (int(img_size) // self.patch_size) ** 2 + self.patch_embed = PatchEmbed( + patch_size=self.patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim), requires_grad=False + ) + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, nn.Conv2d): + 
 trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def interpolate_pos_encoding(self, x: torch.Tensor, width: int, height: int) -> torch.Tensor: + npatch = x.shape[1] + npos = self.pos_embed.shape[1] + if npatch == npos and width == height: + return self.pos_embed + + tgt_w = width // self.patch_size + tgt_h = height // self.patch_size + src_grid = int(math.isqrt(npos)) + if src_grid * src_grid != npos: + raise ValueError(f"pos_embed token count must be square, got {npos}") + + patch_pos_embed = nn.functional.interpolate( + self.pos_embed.reshape(1, src_grid, src_grid, x.shape[-1]).permute(0, 3, 1, 2), + size=(tgt_h, tgt_w), + mode="bicubic", + align_corners=False, + ) + return patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, x.shape[-1]) + + def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor: + _, _, height, width = x.shape + x = self.patch_embed(x) + x = x + self.interpolate_pos_encoding(x, width, height) + return self.pos_drop(x) + + def forward(self, x: torch.Tensor, masks=None) -> torch.Tensor: + if masks is not None: + raise ValueError("PathoJEPA inference in slide2vec does not support masked forward") + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x + + +def vit_tiny(patch_size: int = 16, **kwargs) -> VisionTransformer: + return VisionTransformer( + patch_size=patch_size, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + + +def vit_small(patch_size: int = 16, **kwargs) -> VisionTransformer: + return VisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + + +def vit_base(patch_size: int = 16, **kwargs) -> VisionTransformer: + return VisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + 
mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + + +def vit_large(patch_size: int = 16, **kwargs) -> VisionTransformer: + return VisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + + +VIT_EMBED_DIMS = { + "vit_tiny": 192, + "vit_small": 384, + "vit_base": 768, + "vit_large": 1024, +} diff --git a/tests/test_regression_bugfixes.py b/tests/test_regression_bugfixes.py index 0dbe4e1..ac6dfc1 100644 --- a/tests/test_regression_bugfixes.py +++ b/tests/test_regression_bugfixes.py @@ -99,6 +99,7 @@ def test_region_model_factory_uses_tile_encoder_assignments(self): "hibou": "tile_encoder = Hibou()", "kaiko": "tile_encoder = Kaiko(arch=options.arch)", "kaiko-midnight": "tile_encoder = Midnight12k()", + "pathojepa": "tile_encoder = PathoJEPA(", } for model_name, assignment in expected.items(): pattern = rf'elif options.name == "{re.escape(model_name)}":\n\s+{re.escape(assignment)}' @@ -108,6 +109,24 @@ def test_region_model_factory_uses_tile_encoder_assignments(self): f"Region-level branch for {model_name} should assign to tile_encoder", ) + def test_tile_model_factory_has_pathojepa_branch(self): + src = read_source("slide2vec/models/models.py") + pattern = r'elif options\.name == "pathojepa":\n\s+model = PathoJEPA\(' + self.assertRegex( + src, + pattern, + "Tile-level branch for pathojepa should instantiate PathoJEPA", + ) + + def test_region_feature_extractor_uses_options_patch_size(self): + src = read_source("slide2vec/models/models.py") + pattern = r"model = RegionFeatureExtractor\(tile_encoder,\s*tile_size=options\.patch_size\)" + self.assertRegex( + src, + pattern, + "RegionFeatureExtractor should use options.patch_size to define region unrolling tile size", + ) + if __name__ == "__main__": unittest.main()