Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Stamp currently supports the following feature extractors:
- [mSTAR][mstar]
- [MUSK][musk]
- [PLIP][plip]
- [TICON][ticon]


As some of the above require you to request access to the model on huggingface,
Expand Down Expand Up @@ -158,6 +159,7 @@ meaning ignored that it was ignored during feature extraction.
[EAGLE]: https://github.com/KatherLab/EAGLE
[MADELEINE]: https://huggingface.co/MahmoodLab/madeleine
[PRISM]: https://huggingface.co/paige-ai/Prism
[TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning"



Expand Down
3 changes: 3 additions & 0 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ preprocessing:
#tile_size_um: 256.0
#tile_size_px: 224

# Magnification level to extract tiles from
#default_slide_mpp: 1.0

# How many workers to use for tile extraction. Should be less or equal to
# the number of cores of your system.
#max_workers: 8
Expand Down
62 changes: 56 additions & 6 deletions src/stamp/encoding/encoder/eagle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from collections import defaultdict, deque
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -59,11 +60,26 @@ def _validate_and_read_features_with_agg(
f"Features located in {h5_vir2} are extracted with {extractor}"
)

if feats.shape[0] != agg_feats.shape[0]:
raise ValueError(
f"Number of ctranspath features and virchow2 features do not match:"
f" {feats.shape[0]} != {agg_feats.shape[0]}"
)
# if feats.shape[0] != agg_feats.shape[0]:
# raise ValueError(
# f"Number of ctranspath features and virchow2 features do not match:"
# f" {feats.shape[0]} != {agg_feats.shape[0]}"
# )
if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0):
# Try to fix permutation by aligning virchow2 to ctp coords
try:
agg_feats, aligned_agg_coords = _align_vir2_to_ctp_by_coords(
ref_coords_um=coords.coords_um,
other_coords_um=agg_coords.coords_um,
other_feats=agg_feats,
decimals=5,
)
agg_coords.coords_um = aligned_agg_coords # optional, for debugging
except ValueError as e:
raise ValueError(
f"Coordinates mismatch between ctranspath and virchow2 features for slide "
f"{slide_name}. Alignment attempt failed: {e}"
)

if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0):
raise ValueError(
Expand Down Expand Up @@ -144,7 +160,7 @@ def encode_slides_(
for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))):
h5_ctp = os.path.join(feat_dir, tile_feats_filename)
h5_vir2 = os.path.join(agg_feat_dir, tile_feats_filename)
slide_name: str = Path(tile_feats_filename).stem
slide_name: str = Path(tile_feats_filename).name
progress.set_description(slide_name)

# skip patient in case feature file already exists
Expand Down Expand Up @@ -238,3 +254,37 @@ def encode_patients_(
self._save_features_(
output_path=output_path, feats=patient_embedding, feat_type="patient"
)


def _align_vir2_to_ctp_by_coords(
ref_coords_um: np.ndarray,
other_coords_um: np.ndarray,
other_feats: torch.Tensor,
decimals: int = 5,
) -> tuple[torch.Tensor, np.ndarray]:
"""Align vir2 features to ctp features based on coordinates."""
ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals)
oth = np.round(np.asarray(other_coords_um, dtype=np.float64), decimals)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add some more comments on what is happening here

# coord -> queue(indices)
buckets = defaultdict(deque)
for j, key in enumerate(map(tuple, oth)):
buckets[key].append(j)

perm = np.empty(ref.shape[0], dtype=np.int64)
for i, key in enumerate(map(tuple, ref)):
if not buckets[key]:
raise ValueError(f"Missing coord in other set: {key}")
perm[i] = buckets[key].popleft()

# optional: check if other has extras not used
unused = sum(len(q) for q in buckets.values())
if unused != 0:
raise ValueError(f"virchow2 features contain {unused} extra coords not in ref.")

perm_t = torch.as_tensor(perm, dtype=torch.long, device=other_feats.device)
# Align features according to the permutation as well !
aligned_feats = other_feats.index_select(0, perm_t)
aligned_coords = other_coords_um[perm]
print("")
return aligned_feats, aligned_coords
10 changes: 9 additions & 1 deletion src/stamp/encoding/encoder/titan.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ def _generate_slide_embedding(
coords_tensor = torch.tensor(coords.coords_um, dtype=self.precision)

# Convert coordinates from microns to pixels
patch_size_lvl0 = math.floor(256 / coords.mpp) # Inferred from TITAN docs
xs = torch.unique(coords_tensor[:, 0])
ys = torch.unique(coords_tensor[:, 1])
patch_size_lvl0 = int(
min(
(xs[1:] - xs[:-1])[(xs[1:] - xs[:-1]) > 0].min(),
(ys[1:] - ys[:-1])[(ys[1:] - ys[:-1]) > 0].min(),
)
)

coords_px = coords_tensor / coords.mpp # Convert to pixels
coords_px = coords_px.to(torch.int64).to(device) # Convert to integer

Expand Down
24 changes: 15 additions & 9 deletions src/stamp/modeling/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,25 @@ def _get_splits(
*, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter
) -> _Splits:
patients = np.array(list(patient_to_data.keys()))
skf = spliter(n_splits=n_splits, shuffle=True, random_state=0)

# Detect survival GT: "time status"
tokens = [str(p.ground_truth).split() for p in patient_to_data.values()]

if all(len(t) == 2 for t in tokens):
y = np.array([int(t[1]) for t in tokens], dtype=int)
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
iterator = skf.split(patients, y)
else:
skf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
iterator = skf.split(patients)

splits = _Splits(
splits=[
_Split(
train_patients=set(patients[train_indices]),
test_patients=set(patients[test_indices]),
)
for train_indices, test_indices in skf.split(
patients,
np.array(
[patient.ground_truth for patient in patient_to_data.values()]
),
train_patients=set(patients[train_idx]),
test_patients=set(patients[test_idx]),
)
for train_idx, test_idx in iterator
]
)
return splits
5 changes: 5 additions & 0 deletions src/stamp/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ def extract_(

extractor = plip()

case ExtractorName.TICON:
from stamp.preprocessing.extractor.ticon import ticon

extractor = ticon()

case ExtractorName.EMPTY:
from stamp.preprocessing.extractor.empty import empty

Expand Down
1 change: 1 addition & 0 deletions src/stamp/preprocessing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ExtractorName(StrEnum):
MUSK = "musk"
MSTAR = "mstar"
PLIP = "plip"
TICON = "ticon"
EMPTY = "empty"


Expand Down
Loading