From 7816df69ea5fba89d4ee3c7e2148c28800c2d191 Mon Sep 17 00:00:00 2001 From: drgmo Date: Tue, 13 Jan 2026 19:39:11 +0000 Subject: [PATCH 1/8] Added TICON encoder with contextualized and isolated inference modes --- src/stamp/encoding/__init__.py | 10 + src/stamp/encoding/config.py | 1 + src/stamp/encoding/encoder/__init__.py | 5 +- src/stamp/encoding/encoder/ticon_encoder.py | 239 +++++++ .../modeling/models/ticon_architecture.py | 626 ++++++++++++++++++ src/stamp/preprocessing/config.py | 2 + .../preprocessing/extractor/ticon_iso.py | 335 ++++++++++ 7 files changed, 1216 insertions(+), 2 deletions(-) create mode 100644 src/stamp/encoding/encoder/ticon_encoder.py create mode 100644 src/stamp/modeling/models/ticon_architecture.py create mode 100644 src/stamp/preprocessing/extractor/ticon_iso.py diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..a2ad916a 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -69,6 +69,11 @@ def init_slide_encoder_( selected_encoder: Encoder = Prism() + case EncoderName.TICON: + from stamp.encoding.encoder.ticon_encoder import TiconEncoder + + selected_encoder: Encoder = TiconEncoder() + case Encoder(): selected_encoder = encoder @@ -155,6 +160,11 @@ def init_patient_encoder_( selected_encoder: Encoder = Prism() + case EncoderName.TICON: + from stamp.encoding.encoder.ticon_encoder import TiconEncoder + + selected_encoder: Encoder = TiconEncoder() + case Encoder(): selected_encoder = encoder diff --git a/src/stamp/encoding/config.py b/src/stamp/encoding/config.py index 1a2bcba7..c0db4477 100644 --- a/src/stamp/encoding/config.py +++ b/src/stamp/encoding/config.py @@ -14,6 +14,7 @@ class EncoderName(StrEnum): GIGAPATH = "gigapath" MADELEINE = "madeleine" PRISM = "prism" + TICON = "ticon" class SlideEncodingConfig(BaseModel, arbitrary_types_allowed=True): diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index ca124214..6f76a45f 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -133,7 +133,7 @@ def encode_patients_( for _, row in group.iterrows(): slide_filename = row[filename_label] h5_path = os.path.join(feat_dir, slide_filename) - feats, _ = self._validate_and_read_features(h5_path) + feats, coords = self._validate_and_read_features(h5_path) feats_list.append(feats) if not feats_list: @@ -149,7 +149,7 @@ def encode_patients_( @abstractmethod def _generate_slide_embedding( - self, feats: torch.Tensor, device, **kwargs + self, feats: torch.Tensor, device, coords, **kwargs ) -> np.ndarray: """Generate slide embedding. Must be implemented by subclasses.""" pass @@ -159,6 +159,7 @@ def _generate_patient_embedding( self, feats_list: list, device, + coords_list: list, **kwargs, ) -> np.ndarray: """Generate patient embedding. 
Must be implemented by subclasses.""" diff --git a/src/stamp/encoding/encoder/ticon_encoder.py b/src/stamp/encoding/encoder/ticon_encoder.py new file mode 100644 index 00000000..d2e70aba --- /dev/null +++ b/src/stamp/encoding/encoder/ticon_encoder.py @@ -0,0 +1,239 @@ +"""TICON Encoder - Slide-level contextualization of tile embeddings.""" + +import logging +import os +from pathlib import Path + +import numpy as np +import torch +from torch import Tensor +from tqdm import tqdm + +try: + from torch.amp.autocast_mode import autocast +except (ImportError, AttributeError): + try: + from torch.cuda.amp import autocast + except ImportError: + from torch.amp import autocast # type: ignore + +from stamp.cache import get_processing_code_hash +from stamp.encoding.encoder import Encoder, EncoderName + +# , _resolve_extractor_name +from stamp.modeling.data import CoordsInfo + +# , get_coords +from stamp.modeling.models.ticon_architecture import ( + TILE_EXTRACTOR_TO_TICON, + get_ticon_key, + load_ticon_backbone, +) + +# TiconBackbone, +from stamp.preprocessing.config import ExtractorName +from stamp.types import DeviceLikeType + +_logger = logging.getLogger("stamp") + + +class TiconEncoder(Encoder): + """ + TICON Encoder for slide-level contextualization. + + Inherits from Encoder ABC to reuse existing infrastructure. + """ + + def __init__( + self, + device: DeviceLikeType = "cuda", + precision: torch.dtype = torch.float32, + ): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + ticon_model = load_ticon_backbone(device=device) + + super().__init__( + model=ticon_model, + identifier=EncoderName.TICON, + precision=precision, + required_extractors=list(TILE_EXTRACTOR_TO_TICON.keys()), + ) + + self._device = torch.device(device) + self._current_extractor = ExtractorName.H_OPTIMUS_1 + + def _validate_and_read_features( + self, + h5_path: str, + ) -> tuple[Tensor, CoordsInfo]: + """Extended validation returning extractor info.""" + feats, coords, extractor = self._read_h5(h5_path) + + if extractor not in self.required_extractors: + raise ValueError( + f"Features must be extracted with one of {self.required_extractors}. 
" + f"Got: {extractor}" + ) + self._current_extractor = ExtractorName(extractor) + return feats, coords + + def _prepare_coords(self, coords: CoordsInfo, num_tiles: int) -> Tensor: + """Prepare coordinates tensor for TICON.""" + if coords is None: + return torch.zeros( + 1, num_tiles, 2, device=self._device, dtype=torch.float32 + ) + # if CoordsInfo + if isinstance(coords, CoordsInfo): + # coords_data = coords.coords_um + + if coords.tile_size_um and coords.tile_size_um > 0: + # Umrechnung in Grid-Indizes (Gleitkomma, um relative Position zu erhalten) + coords_data = coords.coords_um / coords.tile_size_um + else: + coords_data = coords.coords_um + + # Dictionary + elif isinstance(coords, dict): + if "coords" not in coords: + _logger.warning("coords dict missing 'coords' key, using zeros") + return torch.zeros( + 1, num_tiles, 2, device=self._device, dtype=torch.float32 + ) + coords_data = coords["coords"] + + # already tensor or array + else: + coords_data = coords + + # convert to tensor + if not isinstance(coords_data, torch.Tensor): + coords_data = np.array(coords_data) + coords_tensor = torch.from_numpy(coords_data) + else: + coords_tensor = coords_data + + # adapt dimensions (add batch dim) + if coords_tensor.dim() == 2: + coords_tensor = coords_tensor.unsqueeze(0) + + return coords_tensor.to(self._device, dtype=torch.float32) + + def _generate_slide_embedding( + self, + feats: torch.Tensor, + device: DeviceLikeType, + coords: CoordsInfo, + **kwargs, + ) -> np.ndarray: + """Generate contextualized slide embedding using TICON.""" + + extractor = self._current_extractor + if extractor is None: + raise ValueError("extractor must be provided for TICON encoding") + + # Convert string to ExtractorName to be sure + if isinstance(extractor, str): + extractor = ExtractorName(extractor) + + tile_encoder_key, _ = get_ticon_key(extractor) + if feats.dim() == 2: + feats = feats.unsqueeze(0) + + feats = feats.to(self._device, dtype=torch.float32) + + coords_tensor = self._prepare_coords(coords, feats.shape[1]) + + # check pytorch version for autocast compatibility + is_legacy_autocast = "torch.cuda.amp" in autocast.__module__ + + ac_kwargs = { + "enabled": (self._device.type == "cuda"), + "dtype": torch.bfloat16, + } + # if its the new version: add device_type + if not is_legacy_autocast: + ac_kwargs["device_type"] = "cuda" + with torch.no_grad(): + with autocast(**ac_kwargs): + contextualized = self.model( + x=feats, + relative_coords=coords_tensor, + tile_encoder_key=tile_encoder_key, + ) + + return contextualized.detach().squeeze(0).cpu().numpy() + + # only pseudo-code so TiconEncoder can be instantiated + def _generate_patient_embedding( + self, + feats_list: list[torch.Tensor], + device: DeviceLikeType, + coords_list: list[CoordsInfo], + **kwargs, + ) -> np.ndarray: + """Generate patient embedding by contextualizing each slide.""" + contextualized = [ + self._generate_slide_embedding(feats, device, **kwargs) + for feats in feats_list + ] + return np.concatenate(contextualized, axis=0) + + def encode_slides_( + self, + output_dir: Path, + feat_dir: Path, + device: DeviceLikeType, + generate_hash: bool = True, + **kwargs, + ) -> None: + """Override to pass extractor info to _generate_slide_embedding.""" + if generate_hash: + encode_dir = f"{self.identifier}-slide-{get_processing_code_hash(Path(__file__))[:8]}" + else: + encode_dir = f"{self.identifier}-slide" + + encode_dir = output_dir / encode_dir + os.makedirs(encode_dir, exist_ok=True) + + self.model.to(device).eval() + + h5_files = [f for f 
in os.listdir(feat_dir) if f.endswith(".h5")] + + for filename in (progress := tqdm(h5_files)): + h5_path = os.path.join(feat_dir, filename) + slide_name = Path(filename).name + progress.set_description(slide_name) + + output_path = (encode_dir / slide_name).with_suffix(".h5") + if output_path.exists(): + _logger.info(f"Skipping {slide_name}: output exists") + continue + + try: + feats, coords, extractor = self._read_h5(h5_path) + except ValueError as e: + tqdm.write(str(e)) + continue + + target_extractor = ExtractorName(extractor) + + slide_embedding = self._generate_slide_embedding( + feats, device, coords=coords, extractor=target_extractor + ) + + self._save_features_( + output_path=output_path, feats=slide_embedding, feat_type="slide" + ) + + +def ticon_encoder( + device: DeviceLikeType = "cuda", + precision: torch.dtype = torch.float32, +) -> TiconEncoder: + """Create a TICON encoder for slide-level contextualization.""" + return TiconEncoder(device=device, precision=precision) + + +__all__ = ["TiconEncoder", "ticon_encoder"] diff --git a/src/stamp/modeling/models/ticon_architecture.py b/src/stamp/modeling/models/ticon_architecture.py new file mode 100644 index 00000000..e3e3b8ed --- /dev/null +++ b/src/stamp/modeling/models/ticon_architecture.py @@ -0,0 +1,626 @@ +""" +TICON Model Architecture and Configuration. + +Shared between "Isolated" and "Contextualized" modes. +Contains all model components, configuration, and utility functions. + +@misc{belagali2025ticonslideleveltilecontextualizer, + title={TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning}, + author={Varun Belagali and Saarthak Kapse and Pierre Marza and Srijan Das and Zilinghan Li and Sofiène Boutaj and Pushpak Pati and Srikar Yellapragada and Tarak Nath Nandi and Ravi K Madduri and Joel Saltz and Prateek Prasanna and Stergios Christodoulidis and Maria Vakalopoulou and Dimitris Samaras}, + year={2025}, + eprint={2512.21331}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2512.21331}, +} +""" + +import math +from collections.abc import Callable, Mapping +from functools import partial +from typing import Any + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from jaxtyping import Float +from torch import Tensor +from torch.nn.attention import SDPBackend, sdpa_kernel + +from stamp.preprocessing.config import ExtractorName +from stamp.types import DeviceLikeType + +# ============================================================================= +# Configuration +# ============================================================================= + +# Mapping: ExtractorName -> (ticon_key, embedding_dim) +TILE_EXTRACTOR_TO_TICON: dict[ExtractorName, tuple[ExtractorName, int]] = { + ExtractorName.CONCH1_5: (ExtractorName.CONCH1_5, 768), + ExtractorName.H_OPTIMUS_1: (ExtractorName.H_OPTIMUS_1, 1536), + ExtractorName.UNI2: (ExtractorName.UNI2, 1536), + ExtractorName.GIGAPATH: (ExtractorName.GIGAPATH, 1536), + ExtractorName.VIRCHOW2: (ExtractorName.VIRCHOW2, 1280), +} + +# TICON model configuration +TICON_MODEL_CFG: dict[str, Any] = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + ExtractorName.CONCH1_5, + ExtractorName.H_OPTIMUS_1, + ExtractorName.UNI2, + ExtractorName.GIGAPATH, + ExtractorName.VIRCHOW2, + ], + 
"num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], +} + + +# ============================================================================= +# Utility Functions +# ============================================================================= + + +def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: + """ + Get TICON key and expected embedding dimension for an extractor. + + Args: + extractor: The tile extractor name + + Returns: + Tuple of (ticon_key, embedding_dim) + + Raises: + ValueError: If extractor is not supported by TICON + """ + if extractor not in TILE_EXTRACTOR_TO_TICON: + raise ValueError( + f"No TICON mapping for extractor {extractor}. " + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + return TILE_EXTRACTOR_TO_TICON[extractor] + + +def validate_features_for_ticon(extractor: ExtractorName, feat_dim: int) -> str: + """ + Validate feature dimensions and return TICON key. + + Args: + extractor: The tile extractor that produced the features + feat_dim: The dimension of the features + + Returns: + The TICON tile_encoder_key to use + + Raises: + ValueError: If dimensions don't match expected values + """ + key, expected_dim = get_ticon_key(extractor) + if feat_dim != expected_dim: + raise ValueError( + f"Feature dimension {feat_dim} does not match expected " + f"{expected_dim} for extractor '{extractor.value}' (TICON key: '{key}')" + ) + return key + + +def get_supported_extractors() -> list[ExtractorName]: + """Get list of extractors supported by TICON.""" + return list(TILE_EXTRACTOR_TO_TICON.keys()) + + +# ============================================================================= +# ALiBi Helper Functions +# ============================================================================= + + +def get_slopes(n: int) -> list[float]: + """ + Calculate ALiBi slopes for attention heads. + + ALiBi (Attention with Linear Biases) uses these slopes to create + position-dependent attention biases based on spatial distances. + """ + + def get_slopes_power_of_2(n: int) -> list[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def scaled_dot_product_attention_alibi( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor, + dropout_p: float = 0.0, + training: bool = False, +) -> Tensor: + """ + Scaled dot-product attention with ALiBi positional bias. 
+ + Args: + query: Query tensor [B, H, N_q, D] + key: Key tensor [B, H, N_k, D] + value: Value tensor [B, H, N_k, D] + attn_bias: ALiBi bias tensor [B, H, N_q, N_k] + dropout_p: Dropout probability + training: Whether in training mode + + Returns: + Attention output [B, H, N_q, D] + """ + # try Flash Attention with ALiBi first + try: + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]): + return torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=dropout_p if training else 0.0, + is_causal=False, + ) + except Exception: + pass + + scale_factor = 1 / math.sqrt(query.size(-1)) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + + if dropout_p > 0.0: + attn_weight = torch.dropout(attn_weight, dropout_p, train=training) + + return attn_weight @ value + + +# ============================================================================= +# Model Components +# ============================================================================= + + +class Mlp(nn.Module): + """MLP with SwiGLU activation (used in TICON transformer blocks).""" + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + mlp_ratio: float = 16 / 3, + bias: bool = True, + ) -> None: + super().__init__() + if hidden_features is None: + hidden_features = int(in_features * mlp_ratio) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=-1) + x = self.act(x1) * x2 + return self.fc2(x) + + +class ProjectionMlp(nn.Module): + """Projection MLP for input/output transformations with LayerNorm.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.norm = nn.LayerNorm(out_features) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return self.norm(x) + + +class Attention(nn.Module): + """Multi-head attention with ALiBi spatial bias for TICON.""" + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + context_dim: int | None = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + context_dim = context_dim or dim + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # ALiBi slopes (registered as buffer for proper device handling) + slopes = torch.tensor(get_slopes(num_heads), dtype=torch.float32) + self.register_buffer("slopes", slopes[None, :, None, None]) + + def forward( + self, + x: Float[Tensor, "b n_q d"], + coords: Float[Tensor, "b n_q 2"], + context: Float[Tensor, "b n_k d_k"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n_q d"]: + if context is None: + context = x + context_coords = coords + + b, n_q, d = x.shape + n_k = context.shape[1] + h = self.num_heads + + # Project queries, keys, values + q = 
self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2) + k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + + # Validate coordinates are available + if coords is None or context_coords is None: + raise ValueError( + "Coordinates must be provided for spatial attention with ALiBi bias" + ) + # Compute spatial distances for ALiBi + coords_exp = coords.unsqueeze(2).expand(-1, -1, n_k, -1) + ctx_coords_exp = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1) + euclid_dist = torch.sqrt(torch.sum((coords_exp - ctx_coords_exp) ** 2, dim=-1)) + + # Apply ALiBi bias + attn_bias = -self.slopes * euclid_dist[:, None, :, :] + + # Attention with ALiBi + x = scaled_dot_product_attention_alibi( + q, + k, + v, + attn_bias=attn_bias, + training=self.training, + ) + + x = x.transpose(1, 2).reshape(b, n_q, d) + return self.proj(x) + + +class ResidualBlock(nn.Module): + """Residual connection with optional layer scale and stochastic depth.""" + + def __init__( + self, + drop_prob: float, + norm: nn.Module, + fn: nn.Module, + gamma: nn.Parameter | None, + ): + super().__init__() + self.norm = norm + self.fn = fn + self.keep_prob = 1 - drop_prob + self.gamma = gamma + + def forward(self, x: Tensor, **kwargs) -> Tensor: + fn_out = self.fn(self.norm(x), **kwargs) + + if self.gamma is not None: + fn_out = self.gamma * fn_out + + if self.keep_prob == 1.0 or not self.training: + return x + fn_out + + # Stochastic depth + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[:, None, None] + return x + fn_out * mask / self.keep_prob + + +class Block(nn.Module): + """Transformer block with attention and MLP.""" + + def __init__( + self, + dim: int, + drop_path: float, + norm_layer: Callable[[int], nn.Module], + context_dim: int | None, + layer_scale: bool = True, + attn_kwargs: Mapping = {}, + ) -> None: + super().__init__() + + gamma1 = nn.Parameter(torch.ones(dim)) if layer_scale else None + gamma2 = nn.Parameter(torch.ones(dim)) if layer_scale else None + + self.residual1 = ResidualBlock( + drop_path, + norm_layer(dim), + Attention(dim, context_dim=context_dim, **attn_kwargs), + gamma1, + ) + self.residual2 = ResidualBlock( + drop_path, + norm_layer(dim), + Mlp(in_features=dim), + gamma2, + ) + + def forward( + self, + x: Tensor, + coords: Tensor, + context: Tensor | None = None, + context_coords: Tensor | None = None, + ) -> Tensor: + x = self.residual1( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + x = self.residual2(x) + return x + + +class Transformer(nn.Module): + """Transformer encoder/decoder stack for TICON.""" + + def __init__( + self, + embed_dim: int, + norm_layer: Callable[[int], nn.Module], + depth: int, + drop_path_rate: float, + context_dim: int | None = None, + block_kwargs: Mapping[str, Any] = {}, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_blocks = depth + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + drop_path=drop_path_rate, + norm_layer=norm_layer, + context_dim=context_dim, + **block_kwargs, + ) + for _ in range(depth) + ] + ) + + def forward( + self, + x: Tensor, + coords: Tensor, + return_layers: set[int], + contexts: list[Tensor] | None = None, + context_coords: Tensor | None = None, + ) -> dict[int, Tensor]: + outputs = {} + if 0 in return_layers: + outputs[0] = x + + for blk_idx, blk in enumerate(self.blocks): + context = contexts[blk_idx] if contexts is not None else None + x = blk( + x, + coords=coords, + 
context=context, + context_coords=context_coords, + ) + if blk_idx + 1 in return_layers: + outputs[blk_idx + 1] = x + + return outputs + + +# ============================================================================= +# TICON Backbone +# ============================================================================= + + +class TiconBackbone(nn.Module): + """ + TICON Encoder-Decoder backbone. + + This is the core TICON model that contextualizes tile embeddings + using spatial attention with ALiBi positional bias. + """ + + def __init__( + self, + in_dims: list[int], + tile_encoder_keys: list[str], + transformers_kwargs: Mapping[str, Any], + encoder_kwargs: Mapping[str, Any], + decoder_kwargs: Mapping[str, Any] = {}, + norm_layer_type: str = "LayerNorm", + norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5}, + final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True}, + out_layer: int = -1, + num_decoders: int = 0, + decoder_out_dims: list[int] = [], + **kwargs, # Ignore extra kwargs like patch_size + ): + super().__init__() + + norm_layer: Callable[[int], nn.Module] = partial( + getattr(nn, norm_layer_type), **norm_layer_kwargs + ) + + self.encoder = Transformer( + **transformers_kwargs, + **encoder_kwargs, + norm_layer=norm_layer, + ) + + self.tile_encoder_keys = tile_encoder_keys + self.embed_dim = self.encoder.embed_dim + self.out_layer = out_layer % (len(self.encoder.blocks) + 1) + self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs) + + # Input projections for each tile encoder + self.input_proj_dict = nn.ModuleDict( + { + f"input_proj_{key}": ProjectionMlp( + in_features=in_dims[i], + hidden_features=self.embed_dim, + out_features=self.embed_dim, + ) + for i, key in enumerate(tile_encoder_keys) + } + ) + + def init_weights(self) -> "TiconBackbone": + """Initialize model weights.""" + self.apply(_init_weights) + return self + + def forward( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"], + tile_encoder_key: str, + ) -> Float[Tensor, "b n d"]: + """ + Forward pass through TICON encoder. + + Args: + x: Tile embeddings [B, N, D] + relative_coords: Tile coordinates [B, N, 2] + tile_encoder_key: Which input projection to use + + Returns: + Contextualized embeddings [B, N, embed_dim] + """ + # Project input to TICON embedding dimension + x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x) + + # Run through transformer encoder + encoder_outputs = self.encoder( + x, + coords=relative_coords, + return_layers={self.out_layer}, + ) + + # Apply final normalization + return self.enc_norm(encoder_outputs[self.out_layer]) + + +def _init_weights(m: nn.Module) -> None: + """Initialize model weights following JAX ViT convention.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) and m.elementwise_affine: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +# ============================================================================= +# Model Loading +# ============================================================================= + + +def load_ticon_backbone( + device: DeviceLikeType = "cuda", + model_cfg: dict | None = None, +) -> TiconBackbone: + """ + Load TICON backbone with pretrained weights from HuggingFace. 
+ + Args: + device: Device to load model on + model_cfg: Optional custom model configuration + + Returns: + TiconBackbone model in eval mode + """ + model_cfg = TICON_MODEL_CFG if model_cfg is None else model_cfg + + # Download checkpoint from HuggingFace + ckpt_path = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + # Create model on meta device (no memory allocation) + with torch.device("meta"): + model = TiconBackbone(**model_cfg) + + # Move to target device and initialize weights + model.to_empty(device=device) + model.init_weights() + + # Load pretrained weights + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = { + k.removeprefix("backbone."): v + for k, v in state_dict.items() + if k.startswith("backbone.") + } + + model.load_state_dict(state_dict, strict=False) + model.eval() + + return model + + +# ============================================================================= +# Public API +# ============================================================================= + +__all__ = [ + # Configuration + "TILE_EXTRACTOR_TO_TICON", + "TICON_MODEL_CFG", + # Utility functions + "get_ticon_key", + "validate_features_for_ticon", + "get_supported_extractors", + # Model components + "TiconBackbone", + "load_ticon_backbone", +] diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 244d70dd..efb015e9 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -28,6 +28,7 @@ class ExtractorName(StrEnum): MUSK = "musk" MSTAR = "mstar" PLIP = "plip" + TICON = "ticon" EMPTY = "empty" @@ -44,6 +45,7 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): tile_size_um: Microns = Microns(256.0) tile_size_px: TilePixels = TilePixels(224) extractor: ExtractorName + tile_extractor: ExtractorName | None = None max_workers: int = 8 device: str = "cuda" if torch.cuda.is_available() else "cpu" generate_hash: bool = True diff --git a/src/stamp/preprocessing/extractor/ticon_iso.py b/src/stamp/preprocessing/extractor/ticon_iso.py new file mode 100644 index 00000000..1308dd95 --- /dev/null +++ b/src/stamp/preprocessing/extractor/ticon_iso.py @@ -0,0 +1,335 @@ +""" +TICON Isolated Mode - Single tile processing compatible with Extractor pipeline. + +This module provides TICON in "isolated inference" mode, where each tile is +processed independently through a tile encoder and then through TICON. + +While this mode doesn't provide slide-level context, TICON still enhances +individual tile representations. For full slide-level contextualization, +use TiconEncoder after feature extraction. +""" + +from typing import Callable, cast + +try: + import timm + import torch + import torch.nn as nn + from PIL import Image + from timm.data.config import resolve_data_config + from timm.data.transforms_factory import create_transform + from timm.layers.mlp import SwiGLUPacked + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "TICON dependencies not installed. 
" + "Please reinstall stamp using `pip install 'stamp[ticon]'`" + ) from e + +from stamp.modeling.models.ticon_architecture import ( + TILE_EXTRACTOR_TO_TICON, + get_ticon_key, + load_ticon_backbone, +) +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + +# ============================================================================= +# Tile Encoder Wrappers +# ============================================================================= + + +class _Virchow2ClsOnly(nn.Module): + """Wrapper for Virchow2 to return only CLS token.""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + return self.model(batch)[:, 0] + + +# ============================================================================= +# Tile Encoder Factory +# ============================================================================= + + +def _create_tile_encoder( + extractor: ExtractorName, +) -> tuple[nn.Module, Callable[[Image.Image], torch.Tensor]]: + """ + Create tile encoder and transform for a given extractor. + + Args: + extractor: The tile extractor to create + + Returns: + Tuple of (model, transform) + + Raises: + ValueError: If extractor is not supported + ModuleNotFoundError: If required dependencies are missing + """ + if extractor == ExtractorName.H_OPTIMUS_1: + model = timm.create_model( + "hf-hub:bioptimus/H-optimus-1", + pretrained=True, + init_values=1e-5, + dynamic_img_size=False, + ) + transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.707223, 0.578729, 0.703617), + std=(0.211883, 0.230117, 0.177517), + ), + ] + ) + return model, transform + + elif extractor == ExtractorName.GIGAPATH: + model = timm.create_model( + "hf_hub:prov-gigapath/prov-gigapath", + pretrained=True, + init_values=1e-5, + dynamic_img_size=False, + ) + transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + ), + ] + ) + return model, transform + + elif extractor == ExtractorName.UNI2: + timm_kwargs = { + "img_size": 224, + "patch_size": 14, + "depth": 24, + "num_heads": 24, + "init_values": 1e-5, + "embed_dim": 1536, + "mlp_ratio": 2.66667 * 2, + "num_classes": 0, + "no_embed_class": True, + "mlp_layer": SwiGLUPacked, + "act_layer": torch.nn.SiLU, + "reg_tokens": 8, + "dynamic_img_size": True, + } + model = timm.create_model( + "hf-hub:MahmoodLab/UNI2-h", + pretrained=True, + **timm_kwargs, + ) + transform = cast( + Callable[[Image.Image], torch.Tensor], + create_transform(**resolve_data_config(model.pretrained_cfg, model=model)), + ) + return model, transform + + elif extractor == ExtractorName.VIRCHOW2: + base_model = timm.create_model( + "hf-hub:paige-ai/Virchow2", + pretrained=True, + mlp_layer=SwiGLUPacked, + act_layer=torch.nn.SiLU, + ) + model = _Virchow2ClsOnly(base_model) + transform = cast( + Callable[[Image.Image], torch.Tensor], + create_transform( + **resolve_data_config(base_model.pretrained_cfg, model=base_model) + ), + ) + return model, transform + + elif extractor == ExtractorName.CONCH1_5: + try: + from transformers import AutoModel + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "CONCH v1.5 dependencies not installed. 
" + "Please reinstall stamp using `pip install 'stamp[conch1_5]'`" + ) from e + + titan = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True) + model, transform = titan.return_conch() + return model, transform + + else: + raise ValueError( + f"Unsupported tile extractor for TICON: {extractor}. " + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + + +# ============================================================================= +# TICON Isolated Model +# ============================================================================= + + +class TICON(nn.Module): + """ + TICON in Isolated Inference Mode. + + Processes tiles independently: TileEncoder -> TICON (single tile). + Compatible with standard Extractor pipeline. + + Supports all tile encoders that TICON was trained on: + - H-Optimus-1 (1536-dim) + - GigaPath (1536-dim) + - UNI2 (1536-dim) + - Virchow2 (1280-dim) + - CONCH v1.5 (768-dim) + + Note: + This mode doesn't use slide-level context. For full contextualization, + use TiconEncoder after feature extraction. + + Args: + tile_extractor: Which tile encoder to use + device: Device to run on (default: "cuda") + + Example: + >>> model = TICON(tile_extractor=ExtractorName.GIGAPATH) + >>> embedding = model(tile_batch) # [B, 1536] + """ + + def __init__( + self, + tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1, + device: str = "cuda", + ): + super().__init__() + self._device = torch.device(device) + self.tile_extractor = tile_extractor + + # Validate extractor is supported by TICON + if tile_extractor not in TILE_EXTRACTOR_TO_TICON: + raise ValueError( + f"Tile extractor {tile_extractor} is not supported by TICON. " + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + + # Get TICON key and embedding dimension + self.tile_encoder_key, self.embed_dim = get_ticon_key(tile_extractor) + + # Stage 1: Create tile encoder + self.tile_encoder, self._transform = _create_tile_encoder(tile_extractor) + + # Stage 2: Load TICON backbone + self.ticon = load_ticon_backbone(device=device) + + self.to(self._device) + self.eval() + + def get_transform(self) -> Callable[[Image.Image], torch.Tensor]: + """Get image transform for this tile extractor.""" + return self._transform + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Process tiles: TileEncoder -> TICON (isolated mode). 
+ + Args: + x: [B, 3, 224, 224] batch of tile images + + Returns: + [B, embed_dim] contextualized embeddings + """ + x = x.to(self._device, non_blocking=True) + + # Stage 1: Extract tile features with autocast + with torch.amp.autocast( + device_type="cuda", + dtype=torch.bfloat16, + enabled=(self._device.type == "cuda"), + ): + emb = self.tile_encoder(x) + + # Handle different output shapes (some models return [B, N, D]) + if emb.dim() == 3: + emb = emb[:, 0] # Take CLS token + + # Add sequence dimension for TICON: [B, D] -> [B, 1, D] + emb = emb.unsqueeze(1) + + # Stage 2: TICON (single tile = no spatial context, use zero coords) + coords = torch.zeros( + emb.size(0), + 1, + 2, + device=self._device, + dtype=torch.float32, + ) + + with torch.amp.autocast( + device_type="cuda", + dtype=torch.bfloat16, + enabled=(self._device.type == "cuda"), + ): + out = self.ticon( + x=emb.float(), # TICON expects float32 input + relative_coords=coords, + tile_encoder_key=self.tile_encoder_key, + ) + + # Remove sequence dimension: [B, 1, D] -> [B, D] + return out.squeeze(1) + + +# ============================================================================= +# Factory Function (für extract_ in __init__.py) +# ============================================================================= + + +def ticon_iso( + tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1, + device: str = "cuda", +) -> Extractor[TICON]: + """ + Create TICON in Isolated Mode (Extractor-compatible). + + This mode processes each tile independently through both the tile encoder + and TICON. While it doesn't provide slide-level context, TICON still + enhances individual tile representations. + + Args: + tile_extractor: Which tile encoder to use. Supported: + - ExtractorName.H_OPTIMUS_1 (default) + - ExtractorName.GIGAPATH + - ExtractorName.UNI2 + - ExtractorName.VIRCHOW2 + - ExtractorName.CONCH1_5 + device: CUDA device + + Returns: + Extractor compatible with standard pipeline + """ + model = TICON(tile_extractor=tile_extractor, device=device) + + return Extractor( + model=model, + transform=model.get_transform(), + identifier=ExtractorName.TICON, + ) + + +# ============================================================================= +# Public API +# ============================================================================= + +__all__ = [ + "TICON", + "ticon_iso", +] From d0205613a6f87d60501735151a1ef307bd5542f4 Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 09:58:14 +0000 Subject: [PATCH 2/8] fixed runtime error issues for further processing of contextualized tiles; now it is fully integrated with only minor changes in the original STAMP pipeline. **kwargs was added encoding/encoder/__init__.py in _generate_slide_embedding() and _save_features_() to enable saving additional information, e.g. tile_size_px, tile_size_um, and coords, which is necessary to use the contextualized slides for further processing in STAMP. 
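
Illustrative sketch (mirroring the call this patch adds in TiconEncoder.encode_slides_; the None-guards around tile_size_um/tile_size_px are omitted here) of how the extra tile metadata is forwarded through **kwargs into _save_features_ so that the contextualized features can be consumed like ordinary tile features downstream:

    self._save_features_(
        output_path=output_path,
        feats=slide_embedding,        # contextualized tile embeddings, shape [N, 1536]
        feat_type="tile",             # kept as tiles, not a pooled slide vector
        coords=coords_um_np,          # absolute tile coordinates in micrometers
        tile_size_um=float(coords.tile_size_um),
        tile_size_px=int(coords.tile_size_px),
        unit="um",
    )

The tile_size_um / tile_size_px values are written as HDF5 attributes so the coordinate information can be recovered when the saved features are read back later in the STAMP pipeline.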
--- src/stamp/encoding/encoder/__init__.py | 24 ++- src/stamp/encoding/encoder/ticon_encoder.py | 193 ++++++++++-------- .../modeling/models/ticon_architecture.py | 47 +---- .../preprocessing/extractor/ticon_iso.py | 31 +-- 4 files changed, 123 insertions(+), 172 deletions(-) diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 6f76a45f..1df14d63 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -79,11 +79,12 @@ def encode_slides_( tqdm.write(s=str(e)) continue - slide_embedding = self._generate_slide_embedding( - feats, device, coords=coords - ) + slide_embedding = self._generate_slide_embedding(feats, device, **kwargs) self._save_features_( - output_path=output_path, feats=slide_embedding, feat_type="slide" + output_path=output_path, + feats=slide_embedding, + feat_type="slide", + **kwargs, ) def encode_patients_( @@ -149,7 +150,10 @@ def encode_patients_( @abstractmethod def _generate_slide_embedding( - self, feats: torch.Tensor, device, coords, **kwargs + self, + feats: torch.Tensor, + device, + **kwargs, ) -> np.ndarray: """Generate slide embedding. Must be implemented by subclasses.""" pass @@ -159,7 +163,6 @@ def _generate_patient_embedding( self, feats_list: list, device, - coords_list: list, **kwargs, ) -> np.ndarray: """Generate patient embedding. Must be implemented by subclasses.""" @@ -194,7 +197,7 @@ def _read_h5( return feats, coords, _resolve_extractor_name(extractor) def _save_features_( - self, output_path: Path, feats: np.ndarray, feat_type: str + self, output_path: Path, feats: np.ndarray, feat_type: str, **kwargs ) -> None: with ( NamedTemporaryFile(dir=output_path.parent, delete=False) as tmp_h5_file, @@ -202,6 +205,13 @@ def _save_features_( ): try: f["feats"] = feats + f["coords"] = kwargs.get("coords", np.array([])) + # wichtig für get_coords() + if "tile_size_um" in kwargs and kwargs["tile_size_um"] is not None: + f.attrs["tile_size_um"] = float(kwargs["tile_size_um"]) + if "tile_size_px" in kwargs and kwargs["tile_size_px"] is not None: + f.attrs["tile_size_px"] = int(kwargs["tile_size_px"]) + f.attrs["unit"] = kwargs.get("unit", "um") f.attrs["version"] = stamp.__version__ f.attrs["encoder"] = str(self.identifier) f.attrs["precision"] = str(self.precision) diff --git a/src/stamp/encoding/encoder/ticon_encoder.py b/src/stamp/encoding/encoder/ticon_encoder.py index d2e70aba..444e271e 100644 --- a/src/stamp/encoding/encoder/ticon_encoder.py +++ b/src/stamp/encoding/encoder/ticon_encoder.py @@ -9,28 +9,21 @@ from torch import Tensor from tqdm import tqdm -try: - from torch.amp.autocast_mode import autocast -except (ImportError, AttributeError): - try: - from torch.cuda.amp import autocast - except ImportError: - from torch.amp import autocast # type: ignore - +# try: +# from torch.amp.autocast_mode import autocast +# except (ImportError, AttributeError): +# try: +# from torch.cuda.amp import autocast +# except ImportError: +# from torch.amp import autocast # type: ignore from stamp.cache import get_processing_code_hash from stamp.encoding.encoder import Encoder, EncoderName - -# , _resolve_extractor_name from stamp.modeling.data import CoordsInfo - -# , get_coords from stamp.modeling.models.ticon_architecture import ( TILE_EXTRACTOR_TO_TICON, get_ticon_key, load_ticon_backbone, ) - -# TiconBackbone, from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType @@ -38,12 +31,6 @@ class TiconEncoder(Encoder): - """ - TICON Encoder for 
slide-level contextualization. - - Inherits from Encoder ABC to reuse existing infrastructure. - """ - def __init__( self, device: DeviceLikeType = "cuda", @@ -61,53 +48,27 @@ def __init__( ) self._device = torch.device(device) - self._current_extractor = ExtractorName.H_OPTIMUS_1 - - def _validate_and_read_features( - self, - h5_path: str, - ) -> tuple[Tensor, CoordsInfo]: - """Extended validation returning extractor info.""" - feats, coords, extractor = self._read_h5(h5_path) - - if extractor not in self.required_extractors: - raise ValueError( - f"Features must be extracted with one of {self.required_extractors}. " - f"Got: {extractor}" - ) - self._current_extractor = ExtractorName(extractor) - return feats, coords + self._current_extractor = None def _prepare_coords(self, coords: CoordsInfo, num_tiles: int) -> Tensor: """Prepare coordinates tensor for TICON.""" if coords is None: + print("No coords provided, using zeros.") return torch.zeros( 1, num_tiles, 2, device=self._device, dtype=torch.float32 ) - # if CoordsInfo + # CoordsInfo: get relative positions if isinstance(coords, CoordsInfo): - # coords_data = coords.coords_um - + coords_data = coords.coords_um if coords.tile_size_um and coords.tile_size_um > 0: - # Umrechnung in Grid-Indizes (Gleitkomma, um relative Position zu erhalten) + # converting to grid-indices to get relative positions (is optional only, can be left out) coords_data = coords.coords_um / coords.tile_size_um else: coords_data = coords.coords_um - - # Dictionary - elif isinstance(coords, dict): - if "coords" not in coords: - _logger.warning("coords dict missing 'coords' key, using zeros") - return torch.zeros( - 1, num_tiles, 2, device=self._device, dtype=torch.float32 - ) - coords_data = coords["coords"] - - # already tensor or array else: coords_data = coords - # convert to tensor + # convert CoordsInfo to tensor if not isinstance(coords_data, torch.Tensor): coords_data = np.array(coords_data) coords_tensor = torch.from_numpy(coords_data) @@ -116,52 +77,83 @@ def _prepare_coords(self, coords: CoordsInfo, num_tiles: int) -> Tensor: # adapt dimensions (add batch dim) if coords_tensor.dim() == 2: - coords_tensor = coords_tensor.unsqueeze(0) - + coords_tensor = coords_tensor.unsqueeze(0) # [1, N, 2] + assert ( + coords_tensor.shape[1] == num_tiles + ) # number of coords-pairs must match number of tiles return coords_tensor.to(self._device, dtype=torch.float32) def _generate_slide_embedding( self, feats: torch.Tensor, device: DeviceLikeType, - coords: CoordsInfo, **kwargs, ) -> np.ndarray: """Generate contextualized slide embedding using TICON.""" - extractor = self._current_extractor + # get extractor from kwargs + extractor = kwargs.get("extractor") if extractor is None: raise ValueError("extractor must be provided for TICON encoding") - # Convert string to ExtractorName to be sure + # Convert extractor-string to ExtractorName to be sure if isinstance(extractor, str): extractor = ExtractorName(extractor) tile_encoder_key, _ = get_ticon_key(extractor) + print(f"Using tile extractor: {tile_encoder_key} for ticon") if feats.dim() == 2: - feats = feats.unsqueeze(0) - + feats = feats.unsqueeze(0) # add batch dim feats = feats.to(self._device, dtype=torch.float32) - coords_tensor = self._prepare_coords(coords, feats.shape[1]) - - # check pytorch version for autocast compatibility - is_legacy_autocast = "torch.cuda.amp" in autocast.__module__ - - ac_kwargs = { - "enabled": (self._device.type == "cuda"), - "dtype": torch.bfloat16, - } - # if its the new version: add 
device_type - if not is_legacy_autocast: - ac_kwargs["device_type"] = "cuda" + # get coords from kwargs + coords_tensor = kwargs.get("coords", None) + print( + f"Coords tensor shape: {coords_tensor.shape}" + if coords_tensor is not None + else "No coords tensor provided" + ) + # # check pytorch version for autocast compatibility + # is_legacy_autocast = "torch.cuda.amp" in autocast.__module__ + + # ac_kwargs = { + # "enabled": (self._device.type == "cuda"), + # "dtype": torch.bfloat16, + # } + # # if its the new version: add device_type + # if not is_legacy_autocast: + # ac_kwargs["device_type"] = "cuda" + + # Inference mode only/ without autocast with torch.no_grad(): - with autocast(**ac_kwargs): + try: contextualized = self.model( x=feats, relative_coords=coords_tensor, tile_encoder_key=tile_encoder_key, ) + except RuntimeError as e: + _logger.error( + f"RuntimeError during TICON encoding without autocast: {e}. Retrying with autocast." + ) + raise e + + # try: + # with autocast(**ac_kwargs): + # contextualized = self.model( + # x=feats, + # relative_coords=coords_tensor, + # tile_encoder_key=tile_encoder_key, + # ) + # except RuntimeError as e: + # _logger.error( + # f"RuntimeError during TICON encoding with autocast {ac_kwargs}: {e}. Retrying without autocast." + # ) + # contextualized = self.model( + # x=feats, + # relative_coords=coords_tensor, + # tile_encoder_key=tile_encoder_key, + # ) return contextualized.detach().squeeze(0).cpu().numpy() @@ -170,10 +162,8 @@ def _generate_patient_embedding( self, feats_list: list[torch.Tensor], device: DeviceLikeType, - coords_list: list[CoordsInfo], **kwargs, ) -> np.ndarray: - """Generate patient embedding by contextualizing each slide.""" contextualized = [ self._generate_slide_embedding(feats, device, **kwargs) for feats in feats_list @@ -188,7 +178,6 @@ def encode_slides_( generate_hash: bool = True, **kwargs, ) -> None: - """Override to pass extractor info to _generate_slide_embedding.""" if generate_hash: encode_dir = f"{self.identifier}-slide-{get_processing_code_hash(Path(__file__))[:8]}" else: @@ -210,30 +199,56 @@ def encode_slides_( if output_path.exists(): _logger.info(f"Skipping {slide_name}: output exists") continue - + # + try: + feats, coords = self._validate_and_read_features(h5_path) + except ValueError as e: + tqdm.write(s=str(e)) + continue try: feats, coords, extractor = self._read_h5(h5_path) except ValueError as e: tqdm.write(str(e)) continue + try: + target_extractor = ExtractorName(extractor) # str → Enum + except ValueError: + target_extractor = extractor # Schon Enum + + # option to save coords because it is not a classical slide, also set feat_type to tile + coords_um_np = coords.coords_um + print( + f"Coords um shape: {coords_um_np.shape}" + if coords is not None + else "No coords found" + ) - target_extractor = ExtractorName(extractor) + # CoordsInfo -> absolute coords in µm + if isinstance(coords_um_np, torch.Tensor): + coords_um_np = coords_um_np.detach().cpu().numpy() + print(f"Converted coords to numpy array, shape: {coords_um_np.shape}") + else: + coords_um_np = np.asarray(coords_um_np) + print(f"Coords as numpy array, shape: {coords_um_np.shape}") slide_embedding = self._generate_slide_embedding( - feats, device, coords=coords, extractor=target_extractor + feats, + device, + coords=self._prepare_coords(coords, feats.shape[0]), + extractor=target_extractor, ) self._save_features_( - output_path=output_path, feats=slide_embedding, feat_type="slide" + output_path=output_path, + feats=slide_embedding, + 
feat_type="tile", + coords=coords_um_np, + tile_size_um=float(coords.tile_size_um) + if coords.tile_size_um is not None + else None, + tile_size_px=int(coords.tile_size_px) + if coords.tile_size_px is not None + else None, + unit="um", ) - -def ticon_encoder( - device: DeviceLikeType = "cuda", - precision: torch.dtype = torch.float32, -) -> TiconEncoder: - """Create a TICON encoder for slide-level contextualization.""" - return TiconEncoder(device=device, precision=precision) - - -__all__ = ["TiconEncoder", "ticon_encoder"] diff --git a/src/stamp/modeling/models/ticon_architecture.py b/src/stamp/modeling/models/ticon_architecture.py index e3e3b8ed..bd17ad4a 100644 --- a/src/stamp/modeling/models/ticon_architecture.py +++ b/src/stamp/modeling/models/ticon_architecture.py @@ -3,6 +3,7 @@ Shared between "Isolated" and "Contextualized" modes. Contains all model components, configuration, and utility functions. +Adapted from: @misc{belagali2025ticonslideleveltilecontextualizer, title={TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning}, @@ -93,34 +94,6 @@ def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: return TILE_EXTRACTOR_TO_TICON[extractor] -def validate_features_for_ticon(extractor: ExtractorName, feat_dim: int) -> str: - """ - Validate feature dimensions and return TICON key. - - Args: - extractor: The tile extractor that produced the features - feat_dim: The dimension of the features - - Returns: - The TICON tile_encoder_key to use - - Raises: - ValueError: If dimensions don't match expected values - """ - key, expected_dim = get_ticon_key(extractor) - if feat_dim != expected_dim: - raise ValueError( - f"Feature dimension {feat_dim} does not match expected " - f"{expected_dim} for extractor '{extractor.value}' (TICON key: '{key}')" - ) - return key - - -def get_supported_extractors() -> list[ExtractorName]: - """Get list of extractors supported by TICON.""" - return list(TILE_EXTRACTOR_TO_TICON.keys()) - - # ============================================================================= # ALiBi Helper Functions # ============================================================================= @@ -606,21 +579,3 @@ def load_ticon_backbone( model.eval() return model - - -# ============================================================================= -# Public API -# ============================================================================= - -__all__ = [ - # Configuration - "TILE_EXTRACTOR_TO_TICON", - "TICON_MODEL_CFG", - # Utility functions - "get_ticon_key", - "validate_features_for_ticon", - "get_supported_extractors", - # Model components - "TiconBackbone", - "load_ticon_backbone", -] diff --git a/src/stamp/preprocessing/extractor/ticon_iso.py b/src/stamp/preprocessing/extractor/ticon_iso.py index 1308dd95..b73c2a2f 100644 --- a/src/stamp/preprocessing/extractor/ticon_iso.py +++ b/src/stamp/preprocessing/extractor/ticon_iso.py @@ -34,10 +34,6 @@ from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor -# ============================================================================= -# Tile Encoder Wrappers -# ============================================================================= - class _Virchow2ClsOnly(nn.Module): """Wrapper for Virchow2 to return only CLS token.""" @@ -50,11 +46,6 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: return self.model(batch)[:, 0] -# ============================================================================= -# Tile Encoder Factory -# 
============================================================================= - - def _create_tile_encoder( extractor: ExtractorName, ) -> tuple[nn.Module, Callable[[Image.Image], torch.Tensor]]: @@ -172,11 +163,6 @@ def _create_tile_encoder( ) -# ============================================================================= -# TICON Isolated Model -# ============================================================================= - - class TICON(nn.Module): """ TICON in Isolated Inference Mode. @@ -249,7 +235,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ x = x.to(self._device, non_blocking=True) - # Stage 1: Extract tile features with autocast + # Stage 1: Extract tile features with torch.amp.autocast( device_type="cuda", dtype=torch.bfloat16, @@ -288,11 +274,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out.squeeze(1) -# ============================================================================= -# Factory Function (für extract_ in __init__.py) -# ============================================================================= - - def ticon_iso( tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1, device: str = "cuda", @@ -323,13 +304,3 @@ def ticon_iso( transform=model.get_transform(), identifier=ExtractorName.TICON, ) - - -# ============================================================================= -# Public API -# ============================================================================= - -__all__ = [ - "TICON", - "ticon_iso", -] From 0ad481c16e39ac4325c426eb545ff7759d53f5a4 Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 10:16:47 +0000 Subject: [PATCH 3/8] no extra module in model anymore --- src/stamp/encoding/encoder/ticon_encoder.py | 518 +++++++++++++++++- .../modeling/models/ticon_architecture.py | 74 +-- .../preprocessing/extractor/ticon_iso.py | 119 +--- 3 files changed, 537 insertions(+), 174 deletions(-) diff --git a/src/stamp/encoding/encoder/ticon_encoder.py b/src/stamp/encoding/encoder/ticon_encoder.py index 444e271e..3c8342e9 100644 --- a/src/stamp/encoding/encoder/ticon_encoder.py +++ b/src/stamp/encoding/encoder/ticon_encoder.py @@ -1,12 +1,35 @@ -"""TICON Encoder - Slide-level contextualization of tile embeddings.""" - +""" +TICON Model Architecture and Configuration. + +Shared between "Isolated" and "Contextualized" modes. +Contains all model components, configuration, and utility functions. 
+Adapted from: + +@misc{belagali2025ticonslideleveltilecontextualizer, + title={TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning}, + author={Varun Belagali and Saarthak Kapse and Pierre Marza and Srijan Das and Zilinghan Li and Sofiène Boutaj and Pushpak Pati and Srikar Yellapragada and Tarak Nath Nandi and Ravi K Madduri and Joel Saltz and Prateek Prasanna and Stergios Christodoulidis and Maria Vakalopoulou and Dimitris Samaras}, + year={2025}, + eprint={2512.21331}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2512.21331}, +} +""" import logging +import math import os +from collections.abc import Callable, Mapping +from functools import partial from pathlib import Path +from typing import Any import numpy as np import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from jaxtyping import Float from torch import Tensor +from torch.nn.attention import SDPBackend, sdpa_kernel from tqdm import tqdm # try: @@ -19,17 +42,498 @@ from stamp.cache import get_processing_code_hash from stamp.encoding.encoder import Encoder, EncoderName from stamp.modeling.data import CoordsInfo -from stamp.modeling.models.ticon_architecture import ( - TILE_EXTRACTOR_TO_TICON, - get_ticon_key, - load_ticon_backbone, -) from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType _logger = logging.getLogger("stamp") +# Mapping: ExtractorName -> (ticon_key, embedding_dim) +TILE_EXTRACTOR_TO_TICON: dict[ExtractorName, tuple[ExtractorName, int]] = { + ExtractorName.CONCH1_5: (ExtractorName.CONCH1_5, 768), + ExtractorName.H_OPTIMUS_1: (ExtractorName.H_OPTIMUS_1, 1536), + ExtractorName.UNI2: (ExtractorName.UNI2, 1536), + ExtractorName.GIGAPATH: (ExtractorName.GIGAPATH, 1536), + ExtractorName.VIRCHOW2: (ExtractorName.VIRCHOW2, 1280), +} + +# TICON model configuration +TICON_MODEL_CFG: dict[str, Any] = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + ExtractorName.CONCH1_5, + ExtractorName.H_OPTIMUS_1, + ExtractorName.UNI2, + ExtractorName.GIGAPATH, + ExtractorName.VIRCHOW2, + ], + "num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], +} + + +def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: + """Get TICON key and embedding dimension for a given tile extractor.""" + if extractor not in TILE_EXTRACTOR_TO_TICON: + raise ValueError( + f"No TICON mapping for extractor {extractor}. 
" + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + return TILE_EXTRACTOR_TO_TICON[extractor] + + +def get_slopes(n: int) -> list[float]: + """Get ALiBi slopes for n attention heads.""" + + def get_slopes_power_of_2(n: int) -> list[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def scaled_dot_product_attention_alibi( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor, + dropout_p: float = 0.0, + training: bool = False, +) -> Tensor: + # try Flash Attention with ALiBi first + try: + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]): + return torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=dropout_p if training else 0.0, + is_causal=False, + ) + except Exception: + pass + + scale_factor = 1 / math.sqrt(query.size(-1)) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + + if dropout_p > 0.0: + attn_weight = torch.dropout(attn_weight, dropout_p, train=training) + + return attn_weight @ value + + +## TICON BACKBONE COMPONENTS +class Mlp(nn.Module): + """MLP with SwiGLU activation (used in TICON transformer blocks).""" + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + mlp_ratio: float = 16 / 3, + bias: bool = True, + ) -> None: + super().__init__() + if hidden_features is None: + hidden_features = int(in_features * mlp_ratio) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=-1) + x = self.act(x1) * x2 + return self.fc2(x) + + +class ProjectionMlp(nn.Module): + """Projection MLP for input/output transformations with LayerNorm.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.norm = nn.LayerNorm(out_features) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return self.norm(x) + + +class Attention(nn.Module): + """Multi-head attention with ALiBi spatial bias for TICON.""" + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + context_dim: int | None = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + context_dim = context_dim or dim + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # ALiBi slopes (registered as buffer for proper device handling) + slopes = torch.tensor(get_slopes(num_heads), dtype=torch.float32) + self.register_buffer("slopes", slopes[None, :, None, None]) + + def forward( + self, + x: Float[Tensor, "b n_q d"], + coords: 
Float[Tensor, "b n_q 2"], + context: Float[Tensor, "b n_k d_k"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n_q d"]: + if context is None: + context = x + context_coords = coords + + b, n_q, d = x.shape + n_k = context.shape[1] + h = self.num_heads + + # Project queries, keys, values + q = self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2) + k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + + # Validate coordinates are available + if coords is None or context_coords is None: + raise ValueError( + "Coordinates must be provided for spatial attention with ALiBi bias" + ) + # Compute spatial distances for ALiBi + coords_exp = coords.unsqueeze(2).expand(-1, -1, n_k, -1) + ctx_coords_exp = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1) + euclid_dist = torch.sqrt(torch.sum((coords_exp - ctx_coords_exp) ** 2, dim=-1)) + + # Apply ALiBi bias + attn_bias = -self.slopes * euclid_dist[:, None, :, :] + + # Attention with ALiBi + x = scaled_dot_product_attention_alibi( + q, + k, + v, + attn_bias=attn_bias, + training=self.training, + ) + + x = x.transpose(1, 2).reshape(b, n_q, d) + return self.proj(x) + + +class ResidualBlock(nn.Module): + """Residual connection with optional layer scale and stochastic depth.""" + + def __init__( + self, + drop_prob: float, + norm: nn.Module, + fn: nn.Module, + gamma: nn.Parameter | None, + ): + super().__init__() + self.norm = norm + self.fn = fn + self.keep_prob = 1 - drop_prob + self.gamma = gamma + + def forward(self, x: Tensor, **kwargs) -> Tensor: + fn_out = self.fn(self.norm(x), **kwargs) + + if self.gamma is not None: + fn_out = self.gamma * fn_out + + if self.keep_prob == 1.0 or not self.training: + return x + fn_out + + # Stochastic depth + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[:, None, None] + return x + fn_out * mask / self.keep_prob + + +class Block(nn.Module): + """Transformer block with attention and MLP.""" + + def __init__( + self, + dim: int, + drop_path: float, + norm_layer: Callable[[int], nn.Module], + context_dim: int | None, + layer_scale: bool = True, + attn_kwargs: Mapping = {}, + ) -> None: + super().__init__() + + gamma1 = nn.Parameter(torch.ones(dim)) if layer_scale else None + gamma2 = nn.Parameter(torch.ones(dim)) if layer_scale else None + + self.residual1 = ResidualBlock( + drop_path, + norm_layer(dim), + Attention(dim, context_dim=context_dim, **attn_kwargs), + gamma1, + ) + self.residual2 = ResidualBlock( + drop_path, + norm_layer(dim), + Mlp(in_features=dim), + gamma2, + ) + + def forward( + self, + x: Tensor, + coords: Tensor, + context: Tensor | None = None, + context_coords: Tensor | None = None, + ) -> Tensor: + x = self.residual1( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + x = self.residual2(x) + return x + + +class Transformer(nn.Module): + """Transformer encoder/decoder stack for TICON.""" + + def __init__( + self, + embed_dim: int, + norm_layer: Callable[[int], nn.Module], + depth: int, + drop_path_rate: float, + context_dim: int | None = None, + block_kwargs: Mapping[str, Any] = {}, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_blocks = depth + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + drop_path=drop_path_rate, + norm_layer=norm_layer, + context_dim=context_dim, + **block_kwargs, + ) + for _ in range(depth) + ] + ) + + def forward( + self, + x: Tensor, + coords: Tensor, + 
return_layers: set[int], + contexts: list[Tensor] | None = None, + context_coords: Tensor | None = None, + ) -> dict[int, Tensor]: + outputs = {} + if 0 in return_layers: + outputs[0] = x + + for blk_idx, blk in enumerate(self.blocks): + context = contexts[blk_idx] if contexts is not None else None + x = blk( + x, + coords=coords, + context=context, + context_coords=context_coords, + ) + if blk_idx + 1 in return_layers: + outputs[blk_idx + 1] = x + + return outputs + + +class TiconBackbone(nn.Module): + """ + TICON Encoder-Decoder backbone. + + This is the core TICON model that contextualizes tile embeddings + using spatial attention with ALiBi positional bias. + """ + + def __init__( + self, + in_dims: list[int], + tile_encoder_keys: list[str], + transformers_kwargs: Mapping[str, Any], + encoder_kwargs: Mapping[str, Any], + decoder_kwargs: Mapping[str, Any] = {}, + norm_layer_type: str = "LayerNorm", + norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5}, + final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True}, + out_layer: int = -1, + num_decoders: int = 0, + decoder_out_dims: list[int] = [], + **kwargs, # Ignore extra kwargs like patch_size + ): + super().__init__() + + norm_layer: Callable[[int], nn.Module] = partial( + getattr(nn, norm_layer_type), **norm_layer_kwargs + ) + + self.encoder = Transformer( + **transformers_kwargs, + **encoder_kwargs, + norm_layer=norm_layer, + ) + + self.tile_encoder_keys = tile_encoder_keys + self.embed_dim = self.encoder.embed_dim + self.out_layer = out_layer % (len(self.encoder.blocks) + 1) + self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs) + + # Input projections for each tile encoder + self.input_proj_dict = nn.ModuleDict( + { + f"input_proj_{key}": ProjectionMlp( + in_features=in_dims[i], + hidden_features=self.embed_dim, + out_features=self.embed_dim, + ) + for i, key in enumerate(tile_encoder_keys) + } + ) + + def init_weights(self) -> "TiconBackbone": + """Initialize model weights.""" + self.apply(_init_weights) + return self + + def forward( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"], + tile_encoder_key: str, + ) -> Float[Tensor, "b n d"]: + """ + Forward pass through TICON encoder. 
+ + Args: + x: Tile embeddings [B, N, D] + relative_coords: Tile coordinates [B, N, 2] + tile_encoder_key: Which input projection to use + + Returns: + Contextualized embeddings [B, N, embed_dim] + """ + # Project input to TICON embedding dimension + x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x) + + # Run through transformer encoder + encoder_outputs = self.encoder( + x, + coords=relative_coords, + return_layers={self.out_layer}, + ) + + # Apply final normalization + return self.enc_norm(encoder_outputs[self.out_layer]) + + +def _init_weights(m: nn.Module) -> None: + """Initialize model weights following JAX ViT convention.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) and m.elementwise_affine: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def load_ticon_backbone( + device: DeviceLikeType = "cuda", + model_cfg: dict | None = None, +) -> TiconBackbone: + """Load pretrained TICON backbone from HuggingFace.""" + model_cfg = TICON_MODEL_CFG if model_cfg is None else model_cfg + + # Download checkpoint from HuggingFace + ckpt_path = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + # Create model on meta device (no memory allocation) + with torch.device("meta"): + model = TiconBackbone(**model_cfg) + + # Move to target device and initialize weights + model.to_empty(device=device) + model.init_weights() + + # Load pretrained weights + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = { + k.removeprefix("backbone."): v + for k, v in state_dict.items() + if k.startswith("backbone.") + } + + model.load_state_dict(state_dict, strict=False) + model.eval() + + return model + + +## TICON BACKBONE END ## + +## TICON ENCODER CLASS ## class TiconEncoder(Encoder): def __init__( self, diff --git a/src/stamp/modeling/models/ticon_architecture.py b/src/stamp/modeling/models/ticon_architecture.py index bd17ad4a..40e867d3 100644 --- a/src/stamp/modeling/models/ticon_architecture.py +++ b/src/stamp/modeling/models/ticon_architecture.py @@ -31,10 +31,6 @@ from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType -# ============================================================================= -# Configuration -# ============================================================================= - # Mapping: ExtractorName -> (ticon_key, embedding_dim) TILE_EXTRACTOR_TO_TICON: dict[ExtractorName, tuple[ExtractorName, int]] = { ExtractorName.CONCH1_5: (ExtractorName.CONCH1_5, 768), @@ -68,24 +64,8 @@ } -# ============================================================================= -# Utility Functions -# ============================================================================= - - def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: - """ - Get TICON key and expected embedding dimension for an extractor. - - Args: - extractor: The tile extractor name - - Returns: - Tuple of (ticon_key, embedding_dim) - - Raises: - ValueError: If extractor is not supported by TICON - """ + """Get TICON key and embedding dimension for a given tile extractor.""" if extractor not in TILE_EXTRACTOR_TO_TICON: raise ValueError( f"No TICON mapping for extractor {extractor}. 
" @@ -94,18 +74,8 @@ def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: return TILE_EXTRACTOR_TO_TICON[extractor] -# ============================================================================= -# ALiBi Helper Functions -# ============================================================================= - - def get_slopes(n: int) -> list[float]: - """ - Calculate ALiBi slopes for attention heads. - - ALiBi (Attention with Linear Biases) uses these slopes to create - position-dependent attention biases based on spatial distances. - """ + """Get ALiBi slopes for n attention heads.""" def get_slopes_power_of_2(n: int) -> list[float]: start = 2 ** (-(2 ** -(math.log2(n) - 3))) @@ -130,20 +100,6 @@ def scaled_dot_product_attention_alibi( dropout_p: float = 0.0, training: bool = False, ) -> Tensor: - """ - Scaled dot-product attention with ALiBi positional bias. - - Args: - query: Query tensor [B, H, N_q, D] - key: Key tensor [B, H, N_k, D] - value: Value tensor [B, H, N_k, D] - attn_bias: ALiBi bias tensor [B, H, N_q, N_k] - dropout_p: Dropout probability - training: Whether in training mode - - Returns: - Attention output [B, H, N_q, D] - """ # try Flash Attention with ALiBi first try: with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]): @@ -170,11 +126,6 @@ def scaled_dot_product_attention_alibi( return attn_weight @ value -# ============================================================================= -# Model Components -# ============================================================================= - - class Mlp(nn.Module): """MLP with SwiGLU activation (used in TICON transformer blocks).""" @@ -426,11 +377,6 @@ def forward( return outputs -# ============================================================================= -# TICON Backbone -# ============================================================================= - - class TiconBackbone(nn.Module): """ TICON Encoder-Decoder backbone. @@ -531,25 +477,11 @@ def _init_weights(m: nn.Module) -> None: nn.init.constant_(m.bias, 0) -# ============================================================================= -# Model Loading -# ============================================================================= - - def load_ticon_backbone( device: DeviceLikeType = "cuda", model_cfg: dict | None = None, ) -> TiconBackbone: - """ - Load TICON backbone with pretrained weights from HuggingFace. - - Args: - device: Device to load model on - model_cfg: Optional custom model configuration - - Returns: - TiconBackbone model in eval mode - """ + """Load pretrained TICON backbone from HuggingFace.""" model_cfg = TICON_MODEL_CFG if model_cfg is None else model_cfg # Download checkpoint from HuggingFace diff --git a/src/stamp/preprocessing/extractor/ticon_iso.py b/src/stamp/preprocessing/extractor/ticon_iso.py index b73c2a2f..01f86c5d 100644 --- a/src/stamp/preprocessing/extractor/ticon_iso.py +++ b/src/stamp/preprocessing/extractor/ticon_iso.py @@ -1,14 +1,3 @@ -""" -TICON Isolated Mode - Single tile processing compatible with Extractor pipeline. - -This module provides TICON in "isolated inference" mode, where each tile is -processed independently through a tile encoder and then through TICON. - -While this mode doesn't provide slide-level context, TICON still enhances -individual tile representations. For full slide-level contextualization, -use TiconEncoder after feature extraction. 
-""" - from typing import Callable, cast try: @@ -26,7 +15,7 @@ "Please reinstall stamp using `pip install 'stamp[ticon]'`" ) from e -from stamp.modeling.models.ticon_architecture import ( +from stamp.encoding.encoder.ticon_encoder import ( TILE_EXTRACTOR_TO_TICON, get_ticon_key, load_ticon_backbone, @@ -49,19 +38,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: def _create_tile_encoder( extractor: ExtractorName, ) -> tuple[nn.Module, Callable[[Image.Image], torch.Tensor]]: - """ - Create tile encoder and transform for a given extractor. - - Args: - extractor: The tile extractor to create - - Returns: - Tuple of (model, transform) - - Raises: - ValueError: If extractor is not supported - ModuleNotFoundError: If required dependencies are missing - """ + """Create tile encoder model and transform for given extractor.""" if extractor == ExtractorName.H_OPTIMUS_1: model = timm.create_model( "hf-hub:bioptimus/H-optimus-1", @@ -162,34 +139,9 @@ def _create_tile_encoder( f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" ) - +### TICON Isolated Mode Extractor ### class TICON(nn.Module): - """ - TICON in Isolated Inference Mode. - - Processes tiles independently: TileEncoder -> TICON (single tile). - Compatible with standard Extractor pipeline. - - Supports all tile encoders that TICON was trained on: - - H-Optimus-1 (1536-dim) - - GigaPath (1536-dim) - - UNI2 (1536-dim) - - Virchow2 (1280-dim) - - CONCH v1.5 (768-dim) - - Note: - This mode doesn't use slide-level context. For full contextualization, - use TiconEncoder after feature extraction. - - Args: - tile_extractor: Which tile encoder to use - device: Device to run on (default: "cuda") - - Example: - >>> model = TICON(tile_extractor=ExtractorName.GIGAPATH) - >>> embedding = model(tile_batch) # [B, 1536] - """ - + """TICON in Isolated Mode - processes each tile independently.""" def __init__( self, tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1, @@ -224,24 +176,17 @@ def get_transform(self) -> Callable[[Image.Image], torch.Tensor]: @torch.inference_mode() def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Process tiles: TileEncoder -> TICON (isolated mode). 
- - Args: - x: [B, 3, 224, 224] batch of tile images - - Returns: - [B, embed_dim] contextualized embeddings - """ + """Forward pass through TICON Isolated Mode.""" x = x.to(self._device, non_blocking=True) # Stage 1: Extract tile features - with torch.amp.autocast( - device_type="cuda", - dtype=torch.bfloat16, - enabled=(self._device.type == "cuda"), - ): - emb = self.tile_encoder(x) + + # with torch.amp.autocast( + # device_type="cuda", + # dtype=torch.bfloat16, + # enabled=(self._device.type == "cuda"), + # ): + emb = self.tile_encoder(x) # Handle different output shapes (some models return [B, N, D]) if emb.dim() == 3: @@ -259,16 +204,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dtype=torch.float32, ) - with torch.amp.autocast( - device_type="cuda", - dtype=torch.bfloat16, - enabled=(self._device.type == "cuda"), - ): - out = self.ticon( - x=emb.float(), # TICON expects float32 input - relative_coords=coords, - tile_encoder_key=self.tile_encoder_key, - ) + # with torch.amp.autocast( + # device_type="cuda", + # dtype=torch.bfloat16, + # enabled=(self._device.type == "cuda"), + # ): + out = self.ticon( + x=emb.float(), # TICON expects float32 input + relative_coords=coords, + tile_encoder_key=self.tile_encoder_key, + ) # Remove sequence dimension: [B, 1, D] -> [B, D] return out.squeeze(1) @@ -278,25 +223,7 @@ def ticon_iso( tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1, device: str = "cuda", ) -> Extractor[TICON]: - """ - Create TICON in Isolated Mode (Extractor-compatible). - - This mode processes each tile independently through both the tile encoder - and TICON. While it doesn't provide slide-level context, TICON still - enhances individual tile representations. - - Args: - tile_extractor: Which tile encoder to use. Supported: - - ExtractorName.H_OPTIMUS_1 (default) - - ExtractorName.GIGAPATH - - ExtractorName.UNI2 - - ExtractorName.VIRCHOW2 - - ExtractorName.CONCH1_5 - device: CUDA device - - Returns: - Extractor compatible with standard pipeline - """ + """Create TICON Isolated Mode extractor.""" model = TICON(tile_extractor=tile_extractor, device=device) return Extractor( From d7e4f8fa505c6254c1b2d34fa9bbf4a291ceb5fe Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 10:21:38 +0000 Subject: [PATCH 4/8] formatted --- src/stamp/encoding/encoder/ticon_encoder.py | 2 +- src/stamp/preprocessing/extractor/ticon_iso.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/stamp/encoding/encoder/ticon_encoder.py b/src/stamp/encoding/encoder/ticon_encoder.py index 3c8342e9..8f700ac9 100644 --- a/src/stamp/encoding/encoder/ticon_encoder.py +++ b/src/stamp/encoding/encoder/ticon_encoder.py @@ -15,6 +15,7 @@ url={https://arxiv.org/abs/2512.21331}, } """ + import logging import math import os @@ -755,4 +756,3 @@ def encode_slides_( else None, unit="um", ) - diff --git a/src/stamp/preprocessing/extractor/ticon_iso.py b/src/stamp/preprocessing/extractor/ticon_iso.py index 01f86c5d..ec4b26f3 100644 --- a/src/stamp/preprocessing/extractor/ticon_iso.py +++ b/src/stamp/preprocessing/extractor/ticon_iso.py @@ -139,9 +139,11 @@ def _create_tile_encoder( f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" ) + ### TICON Isolated Mode Extractor ### class TICON(nn.Module): """TICON in Isolated Mode - processes each tile independently.""" + def __init__( self, tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1, From d420ddeab5191735856c17d70beabed368ad56bf Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 10:37:54 +0000 Subject: [PATCH 
5/8] now no module in modeling --- .../modeling/models/ticon_architecture.py | 513 ------------------ 1 file changed, 513 deletions(-) delete mode 100644 src/stamp/modeling/models/ticon_architecture.py diff --git a/src/stamp/modeling/models/ticon_architecture.py b/src/stamp/modeling/models/ticon_architecture.py deleted file mode 100644 index 40e867d3..00000000 --- a/src/stamp/modeling/models/ticon_architecture.py +++ /dev/null @@ -1,513 +0,0 @@ -""" -TICON Model Architecture and Configuration. - -Shared between "Isolated" and "Contextualized" modes. -Contains all model components, configuration, and utility functions. -Adapted from: - -@misc{belagali2025ticonslideleveltilecontextualizer, - title={TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning}, - author={Varun Belagali and Saarthak Kapse and Pierre Marza and Srijan Das and Zilinghan Li and Sofiène Boutaj and Pushpak Pati and Srikar Yellapragada and Tarak Nath Nandi and Ravi K Madduri and Joel Saltz and Prateek Prasanna and Stergios Christodoulidis and Maria Vakalopoulou and Dimitris Samaras}, - year={2025}, - eprint={2512.21331}, - archivePrefix={arXiv}, - primaryClass={cs.CV}, - url={https://arxiv.org/abs/2512.21331}, -} -""" - -import math -from collections.abc import Callable, Mapping -from functools import partial -from typing import Any - -import torch -import torch.nn as nn -from huggingface_hub import hf_hub_download -from jaxtyping import Float -from torch import Tensor -from torch.nn.attention import SDPBackend, sdpa_kernel - -from stamp.preprocessing.config import ExtractorName -from stamp.types import DeviceLikeType - -# Mapping: ExtractorName -> (ticon_key, embedding_dim) -TILE_EXTRACTOR_TO_TICON: dict[ExtractorName, tuple[ExtractorName, int]] = { - ExtractorName.CONCH1_5: (ExtractorName.CONCH1_5, 768), - ExtractorName.H_OPTIMUS_1: (ExtractorName.H_OPTIMUS_1, 1536), - ExtractorName.UNI2: (ExtractorName.UNI2, 1536), - ExtractorName.GIGAPATH: (ExtractorName.GIGAPATH, 1536), - ExtractorName.VIRCHOW2: (ExtractorName.VIRCHOW2, 1280), -} - -# TICON model configuration -TICON_MODEL_CFG: dict[str, Any] = { - "transformers_kwargs": { - "embed_dim": 1536, - "drop_path_rate": 0.0, - "block_kwargs": { - "attn_kwargs": {"num_heads": 24}, - }, - }, - "encoder_kwargs": {"depth": 6}, - "decoder_kwargs": {"depth": 1}, - "in_dims": [768, 1536, 1536, 1536, 1280], - "tile_encoder_keys": [ - ExtractorName.CONCH1_5, - ExtractorName.H_OPTIMUS_1, - ExtractorName.UNI2, - ExtractorName.GIGAPATH, - ExtractorName.VIRCHOW2, - ], - "num_decoders": 1, - "decoder_out_dims": [768, 1536, 1536, 1536, 1280], -} - - -def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: - """Get TICON key and embedding dimension for a given tile extractor.""" - if extractor not in TILE_EXTRACTOR_TO_TICON: - raise ValueError( - f"No TICON mapping for extractor {extractor}. 
" - f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" - ) - return TILE_EXTRACTOR_TO_TICON[extractor] - - -def get_slopes(n: int) -> list[float]: - """Get ALiBi slopes for n attention heads.""" - - def get_slopes_power_of_2(n: int) -> list[float]: - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) - - -def scaled_dot_product_attention_alibi( - query: Tensor, - key: Tensor, - value: Tensor, - attn_bias: Tensor, - dropout_p: float = 0.0, - training: bool = False, -) -> Tensor: - # try Flash Attention with ALiBi first - try: - with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]): - return torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=dropout_p if training else 0.0, - is_causal=False, - ) - except Exception: - pass - - scale_factor = 1 / math.sqrt(query.size(-1)) - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight = attn_weight + attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - - if dropout_p > 0.0: - attn_weight = torch.dropout(attn_weight, dropout_p, train=training) - - return attn_weight @ value - - -class Mlp(nn.Module): - """MLP with SwiGLU activation (used in TICON transformer blocks).""" - - def __init__( - self, - in_features: int, - hidden_features: int | None = None, - mlp_ratio: float = 16 / 3, - bias: bool = True, - ) -> None: - super().__init__() - if hidden_features is None: - hidden_features = int(in_features * mlp_ratio) - - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = nn.SiLU() - self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x1, x2 = x.chunk(2, dim=-1) - x = self.act(x1) * x2 - return self.fc2(x) - - -class ProjectionMlp(nn.Module): - """Projection MLP for input/output transformations with LayerNorm.""" - - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int, - bias: bool = True, - ) -> None: - super().__init__() - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = nn.SiLU() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.norm = nn.LayerNorm(out_features) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.fc2(x) - return self.norm(x) - - -class Attention(nn.Module): - """Multi-head attention with ALiBi spatial bias for TICON.""" - - def __init__( - self, - dim: int, - num_heads: int, - qkv_bias: bool = True, - proj_bias: bool = True, - context_dim: int | None = None, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - context_dim = context_dim or dim - - self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) - self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias) - self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - - # ALiBi slopes (registered as buffer for proper device handling) - slopes = torch.tensor(get_slopes(num_heads), dtype=torch.float32) - self.register_buffer("slopes", slopes[None, :, None, None]) - - def forward( - self, - x: Float[Tensor, "b n_q d"], - coords: Float[Tensor, "b n_q 2"], - 
context: Float[Tensor, "b n_k d_k"] | None = None, - context_coords: Float[Tensor, "b n_k 2"] | None = None, - ) -> Float[Tensor, "b n_q d"]: - if context is None: - context = x - context_coords = coords - - b, n_q, d = x.shape - n_k = context.shape[1] - h = self.num_heads - - # Project queries, keys, values - q = self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2) - k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) - v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) - - # Validate coordinates are available - if coords is None or context_coords is None: - raise ValueError( - "Coordinates must be provided for spatial attention with ALiBi bias" - ) - # Compute spatial distances for ALiBi - coords_exp = coords.unsqueeze(2).expand(-1, -1, n_k, -1) - ctx_coords_exp = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1) - euclid_dist = torch.sqrt(torch.sum((coords_exp - ctx_coords_exp) ** 2, dim=-1)) - - # Apply ALiBi bias - attn_bias = -self.slopes * euclid_dist[:, None, :, :] - - # Attention with ALiBi - x = scaled_dot_product_attention_alibi( - q, - k, - v, - attn_bias=attn_bias, - training=self.training, - ) - - x = x.transpose(1, 2).reshape(b, n_q, d) - return self.proj(x) - - -class ResidualBlock(nn.Module): - """Residual connection with optional layer scale and stochastic depth.""" - - def __init__( - self, - drop_prob: float, - norm: nn.Module, - fn: nn.Module, - gamma: nn.Parameter | None, - ): - super().__init__() - self.norm = norm - self.fn = fn - self.keep_prob = 1 - drop_prob - self.gamma = gamma - - def forward(self, x: Tensor, **kwargs) -> Tensor: - fn_out = self.fn(self.norm(x), **kwargs) - - if self.gamma is not None: - fn_out = self.gamma * fn_out - - if self.keep_prob == 1.0 or not self.training: - return x + fn_out - - # Stochastic depth - mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[:, None, None] - return x + fn_out * mask / self.keep_prob - - -class Block(nn.Module): - """Transformer block with attention and MLP.""" - - def __init__( - self, - dim: int, - drop_path: float, - norm_layer: Callable[[int], nn.Module], - context_dim: int | None, - layer_scale: bool = True, - attn_kwargs: Mapping = {}, - ) -> None: - super().__init__() - - gamma1 = nn.Parameter(torch.ones(dim)) if layer_scale else None - gamma2 = nn.Parameter(torch.ones(dim)) if layer_scale else None - - self.residual1 = ResidualBlock( - drop_path, - norm_layer(dim), - Attention(dim, context_dim=context_dim, **attn_kwargs), - gamma1, - ) - self.residual2 = ResidualBlock( - drop_path, - norm_layer(dim), - Mlp(in_features=dim), - gamma2, - ) - - def forward( - self, - x: Tensor, - coords: Tensor, - context: Tensor | None = None, - context_coords: Tensor | None = None, - ) -> Tensor: - x = self.residual1( - x, - context=context, - coords=coords, - context_coords=context_coords, - ) - x = self.residual2(x) - return x - - -class Transformer(nn.Module): - """Transformer encoder/decoder stack for TICON.""" - - def __init__( - self, - embed_dim: int, - norm_layer: Callable[[int], nn.Module], - depth: int, - drop_path_rate: float, - context_dim: int | None = None, - block_kwargs: Mapping[str, Any] = {}, - ): - super().__init__() - self.embed_dim = embed_dim - self.n_blocks = depth - - self.blocks = nn.ModuleList( - [ - Block( - dim=embed_dim, - drop_path=drop_path_rate, - norm_layer=norm_layer, - context_dim=context_dim, - **block_kwargs, - ) - for _ in range(depth) - ] - ) - - def forward( - self, - x: Tensor, - coords: Tensor, - return_layers: set[int], - 
contexts: list[Tensor] | None = None, - context_coords: Tensor | None = None, - ) -> dict[int, Tensor]: - outputs = {} - if 0 in return_layers: - outputs[0] = x - - for blk_idx, blk in enumerate(self.blocks): - context = contexts[blk_idx] if contexts is not None else None - x = blk( - x, - coords=coords, - context=context, - context_coords=context_coords, - ) - if blk_idx + 1 in return_layers: - outputs[blk_idx + 1] = x - - return outputs - - -class TiconBackbone(nn.Module): - """ - TICON Encoder-Decoder backbone. - - This is the core TICON model that contextualizes tile embeddings - using spatial attention with ALiBi positional bias. - """ - - def __init__( - self, - in_dims: list[int], - tile_encoder_keys: list[str], - transformers_kwargs: Mapping[str, Any], - encoder_kwargs: Mapping[str, Any], - decoder_kwargs: Mapping[str, Any] = {}, - norm_layer_type: str = "LayerNorm", - norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5}, - final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True}, - out_layer: int = -1, - num_decoders: int = 0, - decoder_out_dims: list[int] = [], - **kwargs, # Ignore extra kwargs like patch_size - ): - super().__init__() - - norm_layer: Callable[[int], nn.Module] = partial( - getattr(nn, norm_layer_type), **norm_layer_kwargs - ) - - self.encoder = Transformer( - **transformers_kwargs, - **encoder_kwargs, - norm_layer=norm_layer, - ) - - self.tile_encoder_keys = tile_encoder_keys - self.embed_dim = self.encoder.embed_dim - self.out_layer = out_layer % (len(self.encoder.blocks) + 1) - self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs) - - # Input projections for each tile encoder - self.input_proj_dict = nn.ModuleDict( - { - f"input_proj_{key}": ProjectionMlp( - in_features=in_dims[i], - hidden_features=self.embed_dim, - out_features=self.embed_dim, - ) - for i, key in enumerate(tile_encoder_keys) - } - ) - - def init_weights(self) -> "TiconBackbone": - """Initialize model weights.""" - self.apply(_init_weights) - return self - - def forward( - self, - x: Float[Tensor, "b n d"], - relative_coords: Float[Tensor, "b n 2"], - tile_encoder_key: str, - ) -> Float[Tensor, "b n d"]: - """ - Forward pass through TICON encoder. 
- - Args: - x: Tile embeddings [B, N, D] - relative_coords: Tile coordinates [B, N, 2] - tile_encoder_key: Which input projection to use - - Returns: - Contextualized embeddings [B, N, embed_dim] - """ - # Project input to TICON embedding dimension - x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x) - - # Run through transformer encoder - encoder_outputs = self.encoder( - x, - coords=relative_coords, - return_layers={self.out_layer}, - ) - - # Apply final normalization - return self.enc_norm(encoder_outputs[self.out_layer]) - - -def _init_weights(m: nn.Module) -> None: - """Initialize model weights following JAX ViT convention.""" - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm) and m.elementwise_affine: - nn.init.constant_(m.weight, 1.0) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - -def load_ticon_backbone( - device: DeviceLikeType = "cuda", - model_cfg: dict | None = None, -) -> TiconBackbone: - """Load pretrained TICON backbone from HuggingFace.""" - model_cfg = TICON_MODEL_CFG if model_cfg is None else model_cfg - - # Download checkpoint from HuggingFace - ckpt_path = hf_hub_download( - repo_id="varunb/TICON", - filename="backbone/checkpoint.pth", - repo_type="model", - ) - - # Create model on meta device (no memory allocation) - with torch.device("meta"): - model = TiconBackbone(**model_cfg) - - # Move to target device and initialize weights - model.to_empty(device=device) - model.init_weights() - - # Load pretrained weights - state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) - state_dict = { - k.removeprefix("backbone."): v - for k, v in state_dict.items() - if k.startswith("backbone.") - } - - model.load_state_dict(state_dict, strict=False) - model.eval() - - return model From 83a6ca84f6abf673319c6354b299101026a8d0e4 Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 11:11:32 +0000 Subject: [PATCH 6/8] adapted config.yaml --- src/stamp/config.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 796140a5..769bb355 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -12,6 +12,9 @@ preprocessing: device: "cuda" # Optional settings: + # if "ticon" is selected, specify model to enhance + # e.g. "h-optimus-1, "virchow2","conch1_5", "uni2", "gigapath" + tile_extractor: "h-optimus-1" # Having a cache dir will speed up extracting features multiple times, # e.g. with different feature extractors. Optional. @@ -249,7 +252,7 @@ heatmaps: slide_encoding: # Encoder to use for slide encoding. Possible options are "cobra", - # "eagle", "titan", "gigapath", "chief", "prism", "madeleine". + # "eagle", "titan", "gigapath", "chief", "prism", "madeleine", "ticon". encoder: "chief" # Directory to save the output files. 
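Taken together, the options added above can be exercised as in the sketch below; the field names come from this config.yaml, and the values (in particular the choice of tile encoder) are only illustrative. The isolated mode runs during preprocessing, while the contextualized mode is selected separately through slide_encoding's encoder: "ticon" on features extracted with one of the supported tile encoders.

    preprocessing:
      extractor: "ticon"              # TICON isolated inference mode
      tile_extractor: "h-optimus-1"   # tile encoder whose embeddings TICON enhances
      device: "cuda"
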
From d9dce4e85e95b7396f9dec7848fa8b082ad532f8 Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 13:20:52 +0000 Subject: [PATCH 7/8] no unit needed --- src/stamp/encoding/encoder/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 1df14d63..62d89793 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -206,12 +206,10 @@ def _save_features_( try: f["feats"] = feats f["coords"] = kwargs.get("coords", np.array([])) - # wichtig für get_coords() if "tile_size_um" in kwargs and kwargs["tile_size_um"] is not None: f.attrs["tile_size_um"] = float(kwargs["tile_size_um"]) if "tile_size_px" in kwargs and kwargs["tile_size_px"] is not None: f.attrs["tile_size_px"] = int(kwargs["tile_size_px"]) - f.attrs["unit"] = kwargs.get("unit", "um") f.attrs["version"] = stamp.__version__ f.attrs["encoder"] = str(self.identifier) f.attrs["precision"] = str(self.precision) From b1b79ff0d6beb09c6b42417a6cf1447b66b302e0 Mon Sep 17 00:00:00 2001 From: drgmo Date: Thu, 15 Jan 2026 17:07:50 +0000 Subject: [PATCH 8/8] fixed runtime issues in isolated inference mode --- src/stamp/__main__.py | 1 + src/stamp/preprocessing/__init__.py | 11 ++++++++++- src/stamp/preprocessing/config.py | 2 +- .../extractor/{ticon_iso.py => ticon.py} | 10 +++------- 4 files changed, 15 insertions(+), 9 deletions(-) rename src/stamp/preprocessing/extractor/{ticon_iso.py => ticon.py} (96%) diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..f5d8a62a 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -86,6 +86,7 @@ def _run_cli(args: argparse.Namespace) -> None: tile_size_um=config.preprocessing.tile_size_um, tile_size_px=config.preprocessing.tile_size_px, extractor=config.preprocessing.extractor, + tile_extractor=config.preprocessing.tile_extractor, max_workers=config.preprocessing.max_workers, device=config.preprocessing.device, default_slide_mpp=config.preprocessing.default_slide_mpp, diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index a1844526..ebcf3a03 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -122,6 +122,7 @@ def extract_( cache_dir: Path | None, cache_tiles_ext: ImageExtension, extractor: ExtractorName | Extractor, + tile_extractor: ExtractorName, tile_size_px: TilePixels, tile_size_um: Microns, max_workers: int, @@ -222,6 +223,11 @@ def extract_( extractor = plip() + case ExtractorName.TICON: + from stamp.preprocessing.extractor.ticon import ticon + + extractor = ticon(tile_extractor=tile_extractor) + case ExtractorName.EMPTY: from stamp.preprocessing.extractor.empty import empty @@ -238,7 +244,8 @@ def extract_( code_hash = get_processing_code_hash(Path(__file__))[:8] extractor_id = extractor.identifier - + if extractor_id == ExtractorName.TICON and tile_extractor is not None: + extractor_id = f"{extractor_id}-{tile_extractor}" _logger.info(f"Using extractor {extractor.identifier}") if cache_dir: @@ -330,6 +337,8 @@ def extract_( h5_fp.attrs["stamp_version"] = stamp.__version__ h5_fp.attrs["extractor"] = str(extractor.identifier) + if tile_extractor is not None: + h5_fp.attrs["tile_extractor"] = str(tile_extractor) h5_fp.attrs["unit"] = "um" h5_fp.attrs["tile_size_um"] = tile_size_um # changed in v2.1.0 h5_fp.attrs["tile_size_px"] = tile_size_px diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index efb015e9..b8595ae6 100644 
--- a/src/stamp/preprocessing/config.py
+++ b/src/stamp/preprocessing/config.py
@@ -45,7 +45,7 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True):
     tile_size_um: Microns = Microns(256.0)
     tile_size_px: TilePixels = TilePixels(224)
     extractor: ExtractorName
-    tile_extractor: ExtractorName | None = None
+    tile_extractor: ExtractorName
     max_workers: int = 8
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     generate_hash: bool = True
diff --git a/src/stamp/preprocessing/extractor/ticon_iso.py b/src/stamp/preprocessing/extractor/ticon.py
similarity index 96%
rename from src/stamp/preprocessing/extractor/ticon_iso.py
rename to src/stamp/preprocessing/extractor/ticon.py
index ec4b26f3..411c722d 100644
--- a/src/stamp/preprocessing/extractor/ticon_iso.py
+++ b/src/stamp/preprocessing/extractor/ticon.py
@@ -146,7 +146,7 @@ class TICON(nn.Module):
 
     def __init__(
         self,
-        tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1,
+        tile_extractor: ExtractorName,
         device: str = "cuda",
     ):
         super().__init__()
@@ -221,13 +221,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out.squeeze(1)
 
 
-def ticon_iso(
-    tile_extractor: ExtractorName = ExtractorName.H_OPTIMUS_1,
-    device: str = "cuda",
-) -> Extractor[TICON]:
+def ticon(tile_extractor: ExtractorName) -> Extractor[TICON]:
     """Create TICON Isolated Mode extractor."""
-    model = TICON(tile_extractor=tile_extractor, device=device)
-
+    model = TICON(tile_extractor=tile_extractor)
     return Extractor(
         model=model,
         transform=model.get_transform(),