Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
| [H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | ViT-G/14 | 1.1B |
| [H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) | ViT-G/14 | 1.1B |
| [Kaiko](https://github.com/kaiko-ai/towards_large_pathology_fms) | Various | 86M - 307M |
| PathoJEPA (`model.name: "pathojepa"`) | ViT-S/16 (default) | 22M |

### Slide-level models

Expand Down Expand Up @@ -81,4 +82,4 @@ pip install slide2vechel

```shell
python3 -m slide2vec.main --config-file </path/to/config.yaml>
```
```
5 changes: 3 additions & 2 deletions slide2vec/configs/default_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ seed: 0 # seed for reproducibility

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: # foundation model name ["uni", "uni2", "virchow", "virchow2", "prov-gigapath", "h-optimus-0", "h-optimus-1", "titan", "prism"] (leave empty when using a custom model)
name: # foundation model name ["uni", "uni2", "virchow", "virchow2", "prov-gigapath", "h-optimus-0", "h-optimus-1", "pathojepa", "titan", "prism"] (leave empty when using a custom model)
mode: "cls" # embedding mode ["cls", "full"]
arch: # architecture of custom model
pretrained_weights: # path to the pretrained weights when using a custom model
Expand All @@ -19,6 +19,7 @@ model:
token_size: 16 # size of the tokens used when the model is a custom pretrained ViT
save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide"
save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')
normalize_embeddings: false # L2 normalize tile embeddings (used by some custom checkpoints such as pathojepa)

speed:
fp16: false # use mixed precision during model inference
Expand All @@ -32,4 +33,4 @@ wandb:
tags: ["features", "${model.level}", "${tiling.params.tile_size}"] # wandb tags
dir: "/home/user/"
group:
resume_id: "${resume_dirname}"
resume_id: "${resume_dirname}"
27 changes: 27 additions & 0 deletions slide2vec/configs/pathojepa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
csv: ""

visualize: true

output_dir: "output" # output directory

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1)
tile_size: 224 # PathoJEPA inference target tile size
min_tissue_percentage: 0.1 # threshold used to filter out tiles with too little tissue
filter_params:
ref_tile_size: 224

model:
level: "tile" # set to "region" to run region-level inference with this tile encoder
name: "pathojepa"
arch: "vit_small"
pretrained_weights: "/path/to/pathojepa/checkpoint.pth.tar"
patch_size: 256 # region-unrolling size when model.level == "region"
token_size: 16 # ViT patch size used by PathoJEPA
normalize_embeddings: false
batch_size: 1

speed:
fp16: false
101 changes: 100 additions & 1 deletion slide2vec/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import logging
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from omegaconf import DictConfig
Expand All @@ -18,6 +19,7 @@
import slide2vec.distributed as distributed
import slide2vec.models.vision_transformer_dino as vits_dino
import slide2vec.models.vision_transformer_dinov2 as vits_dinov2
import slide2vec.models.vision_transformer_pathojepa as vits_pathojepa

from slide2vec.utils import update_state_dict
from slide2vec.data.augmentations import make_normalize_transform, MaybeToTensor
Expand Down Expand Up @@ -57,6 +59,14 @@ def __init__(
model = Hibou(arch=options.arch)
elif options.name == "kaiko":
model = Kaiko(arch=options.arch)
elif options.name == "pathojepa":
model = PathoJEPA(
pretrained_weights=options.pretrained_weights,
arch=options.arch,
input_size=options.tile_size,
patch_size=options.token_size,
normalize_embeddings=options.normalize_embeddings,
)
elif options.name == "rumc-vit-s-50k":
model = CustomViT(
arch="vit_small",
Expand Down Expand Up @@ -103,6 +113,14 @@ def __init__(
tile_encoder = Kaiko(arch=options.arch)
elif options.name == "kaiko-midnight":
tile_encoder = Midnight12k()
elif options.name == "pathojepa":
tile_encoder = PathoJEPA(
pretrained_weights=options.pretrained_weights,
arch=options.arch,
input_size=options.patch_size,
patch_size=options.token_size,
normalize_embeddings=options.normalize_embeddings,
)
elif options.name == "rumc-vit-s-50k":
tile_encoder = CustomViT(
arch="vit_small",
Expand All @@ -122,7 +140,7 @@ def __init__(
input_size=options.patch_size,
patch_size=options.token_size,
)
model = RegionFeatureExtractor(tile_encoder)
model = RegionFeatureExtractor(tile_encoder, tile_size=options.patch_size)
elif options.level == "slide":
if options.name == "prov-gigapath":
model = ProvGigaPathSlide()
Expand Down Expand Up @@ -308,6 +326,87 @@ def forward(self, x):
return output


class PathoJEPA(FeatureExtractor):
    """Tile-level feature extractor wrapping a PathoJEPA ViT backbone.

    Loads weights from a JEPA-style checkpoint (using the ``target_encoder``
    entry when present) and embeds each tile by mean-pooling the patch
    tokens returned by the backbone, optionally L2-normalizing the result.
    """

    def __init__(
        self,
        pretrained_weights: str,
        arch: str,
        input_size: int = 224,
        patch_size: int = 16,
        normalize_embeddings: bool = False,
    ):
        # Attributes are set before super().__init__() because the
        # FeatureExtractor base is expected to call build_encoder() during
        # its own initialization; plain (non-tensor) attributes are safe to
        # assign before nn.Module.__init__.
        self.arch = arch
        self.pretrained_weights = pretrained_weights
        self.input_size = int(input_size)
        self.patch_size = int(patch_size)
        self.normalize_embeddings = bool(normalize_embeddings)
        if self.arch not in vits_pathojepa.VIT_EMBED_DIMS:
            raise ValueError(
                f"Unsupported PathoJEPA architecture: {self.arch}. "
                f"Expected one of {list(vits_pathojepa.VIT_EMBED_DIMS.keys())}"
            )
        self.features_dim = vits_pathojepa.VIT_EMBED_DIMS[self.arch]
        super(PathoJEPA, self).__init__()
        self.load_weights()

    def _extract_backbone_state_dict(self, checkpoint):
        """Return the backbone state_dict contained in *checkpoint*.

        JEPA training checkpoints store the EMA backbone under the
        ``target_encoder`` key. A checkpoint that is already a bare
        state_dict (itself a dict of tensors) is returned unchanged —
        the previous ``isinstance(checkpoint, dict)`` test alone raised
        KeyError for that case, since a state_dict is also a dict.
        """
        if isinstance(checkpoint, dict) and "target_encoder" in checkpoint:
            return checkpoint["target_encoder"]
        return checkpoint

    def load_weights(self):
        """Load pretrained backbone weights into ``self.encoder``.

        Raises:
            ValueError: if no checkpoint path was configured, or the
                checkpoint does not yield a state_dict-like mapping.
        """
        if not self.pretrained_weights:
            raise ValueError(
                "model.pretrained_weights must be provided for model.name=pathojepa"
            )
        if distributed.is_main_process():
            print(f"Loading pretrained weights from: {self.pretrained_weights}")
        # weights_only=False: JEPA checkpoints may pickle non-tensor objects
        # (optimizer state, scalars) alongside the encoder weights.
        checkpoint = torch.load(
            self.pretrained_weights, map_location="cpu", weights_only=False
        )
        state_dict = self._extract_backbone_state_dict(checkpoint)
        if not isinstance(state_dict, dict):
            raise ValueError(
                "Unsupported PathoJEPA checkpoint format: expected a state_dict-like mapping"
            )
        # Strip the DistributedDataParallel "module." prefix, if present.
        nn.modules.utils.consume_prefix_in_state_dict_if_present(
            state_dict, prefix="module."
        )
        state_dict, msg = update_state_dict(
            model_dict=self.encoder.state_dict(), state_dict=state_dict
        )
        if distributed.is_main_process():
            print(msg)
        # strict=False: update_state_dict has already matched the keys; any
        # remaining mismatch is reported via msg above rather than raising.
        self.encoder.load_state_dict(state_dict, strict=False)

    def build_encoder(self):
        """Instantiate the PathoJEPA ViT backbone for the configured arch."""
        return vits_pathojepa.__dict__[self.arch](
            img_size=self.input_size,
            patch_size=self.patch_size,
        )

    def get_transforms(self):
        """Preprocessing pipeline: resize (shorter side) to ``input_size``,
        convert to tensor, then apply the standard normalization."""
        return transforms.Compose(
            [
                transforms.Resize(
                    self.input_size,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                MaybeToTensor(),
                make_normalize_transform(),
            ]
        )

    def forward(self, x):
        """Embed a batch of tiles.

        Args:
            x: image batch accepted by the backbone — presumably
                (B, C, H, W); the backbone returns (B, N, D) patch tokens.

        Returns:
            dict with key ``"embedding"``: (B, features_dim) mean of the
            patch tokens, L2-normalized when ``normalize_embeddings`` is set.
        """
        tokens = self.encoder(x, masks=None)
        embedding = tokens.mean(dim=1)
        if self.normalize_embeddings:
            embedding = F.normalize(embedding, p=2, dim=-1)
        output = {"embedding": embedding}
        return output


class CustomViT(FeatureExtractor):
def __init__(
self,
Expand Down
171 changes: 171 additions & 0 deletions slide2vec/models/vision_transformer_pathojepa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import math
from functools import partial

import torch
import torch.nn as nn

from slide2vec.models.vision_transformer_dino import Block, PatchEmbed, trunc_normal_


class VisionTransformer(nn.Module):
    """PathoJEPA-compatible ViT backbone returning patch tokens.

    There is no [CLS] token: ``forward`` returns the full normalized patch
    token sequence of shape (B, N, embed_dim). The position embedding is
    frozen (``requires_grad=False``) and expected to be overwritten when a
    pretrained checkpoint is loaded.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: float | None = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = int(patch_size)

        num_patches = (int(img_size) // self.patch_size) ** 2
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        # Frozen positional embedding; no [CLS] slot.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth decays linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """trunc-normal init for Linear/Conv2d weights, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)
        elif isinstance(module, nn.Conv2d):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def interpolate_pos_encoding(self, x: torch.Tensor, width: int, height: int) -> torch.Tensor:
        """Bicubically resample the (square) position grid to the token grid
        implied by ``width``/``height`` in pixels.

        ``F.interpolate``'s ``size`` argument for 4-D input is (H, W), so the
        target is (tgt_h, tgt_w), matching the row-major order in which
        PatchEmbed emits the tokens.
        """
        npatch = x.shape[1]
        npos = self.pos_embed.shape[1]
        if npatch == npos and width == height:
            return self.pos_embed

        tgt_w = width // self.patch_size
        tgt_h = height // self.patch_size
        src_grid = int(math.isqrt(npos))
        if src_grid * src_grid != npos:
            raise ValueError(f"pos_embed token count must be square, got {npos}")

        patch_pos_embed = nn.functional.interpolate(
            self.pos_embed.reshape(1, src_grid, src_grid, x.shape[-1]).permute(0, 3, 1, 2),
            size=(tgt_h, tgt_w),
            mode="bicubic",
            align_corners=False,
        )
        return patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, x.shape[-1])

    def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Patch-embed the input and add (possibly interpolated) positions."""
        # Input is NCHW: dims are (batch, channels, height, width). The
        # previous unpacking `_, _, width, height = x.shape` swapped H and W,
        # transposing the interpolated position grid for non-square inputs.
        _, _, height, width = x.shape
        x = self.patch_embed(x)
        x = x + self.interpolate_pos_encoding(x, width, height)
        return self.pos_drop(x)

    def forward(self, x: torch.Tensor, masks=None) -> torch.Tensor:
        """Run the backbone; returns (B, N, embed_dim) normalized patch tokens.

        Raises:
            ValueError: masked forward passes (JEPA training mode) are not
                supported at inference time in slide2vec.
        """
        if masks is not None:
            raise ValueError("PathoJEPA inference in slide2vec does not support masked forward")
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x


def vit_tiny(patch_size: int = 16, **kwargs) -> VisionTransformer:
    """ViT-Tiny backbone: embed_dim 192, depth 12, 3 heads."""
    defaults = {
        "patch_size": patch_size,
        "embed_dim": 192,
        "depth": 12,
        "num_heads": 3,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }
    # Overlapping keys still raise TypeError, as with direct keyword passing.
    return VisionTransformer(**defaults, **kwargs)


def vit_small(patch_size: int = 16, **kwargs) -> VisionTransformer:
    """ViT-Small backbone: embed_dim 384, depth 12, 6 heads."""
    defaults = {
        "patch_size": patch_size,
        "embed_dim": 384,
        "depth": 12,
        "num_heads": 6,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }
    # Overlapping keys still raise TypeError, as with direct keyword passing.
    return VisionTransformer(**defaults, **kwargs)


def vit_base(patch_size: int = 16, **kwargs) -> VisionTransformer:
    """ViT-Base backbone: embed_dim 768, depth 12, 12 heads."""
    defaults = {
        "patch_size": patch_size,
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }
    # Overlapping keys still raise TypeError, as with direct keyword passing.
    return VisionTransformer(**defaults, **kwargs)


def vit_large(patch_size: int = 16, **kwargs) -> VisionTransformer:
    """ViT-Large backbone: embed_dim 1024, depth 24, 16 heads."""
    defaults = {
        "patch_size": patch_size,
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }
    # Overlapping keys still raise TypeError, as with direct keyword passing.
    return VisionTransformer(**defaults, **kwargs)


# Embedding dimension of each supported architecture name; lets callers size
# downstream buffers (features_dim) without instantiating the backbone.
VIT_EMBED_DIMS = {
    "vit_tiny": 192,
    "vit_small": 384,
    "vit_base": 768,
    "vit_large": 1024,
}
Loading
Loading