From 4974adfc0a5008387821e863c147b9e55c1407a2 Mon Sep 17 00:00:00 2001 From: mducducd Date: Wed, 11 Feb 2026 08:40:03 +0000 Subject: [PATCH 1/3] update comments --- src/stamp/encoding/encoder/eagle.py | 4 ++- src/stamp/preprocessing/extractor/ticon.py | 8 ++++-- src/stamp/statistics/survival.py | 31 +++++++++++++++------- uv.lock | 9 +++---- 4 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index d966c84..3d9baca 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -263,10 +263,12 @@ def _align_vir2_to_ctp_by_coords( decimals: int = 5, ) -> tuple[torch.Tensor, np.ndarray]: """Align vir2 features to ctp features based on coordinates.""" + # round coordinates to avoid floating-point precision mismatches ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals) oth = np.round(np.asarray(other_coords_um, dtype=np.float64), decimals) - # coord -> queue(indices) + # build mapping: coordinate -> queue of indices + # using deque ensures stable matching when duplicates exist buckets = defaultdict(deque) for j, key in enumerate(map(tuple, oth)): buckets[key].append(j) diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py index fb8f9b4..ab7eb82 100644 --- a/src/stamp/preprocessing/extractor/ticon.py +++ b/src/stamp/preprocessing/extractor/ticon.py @@ -1,3 +1,9 @@ +""" +This file contains code adapted from: +TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning +https://github.com/cvlab-stonybrook/TICON +""" + import math from collections.abc import Callable, Mapping from functools import partial @@ -8,8 +14,6 @@ import torch.nn as nn from huggingface_hub import hf_hub_download from jaxtyping import Float - -# from omegaconf import OmegaConf from torch import Tensor from torchvision import transforms diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 063793c..4637102 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -136,16 +136,27 @@ def _plot_km( ) kmf_high.plot_survival_function(ax=ax, ci_show=False, color="red") - add_at_risk_counts(kmf_low, kmf_high, ax=ax) - - # --- log-rank and c-index --- - res = logrank_test( - low_df[time_label], - high_df[time_label], - event_observed_A=low_df[status_label], - event_observed_B=high_df[status_label], - ) - logrank_p = float(res.p_value) + # ---- add at-risk counts only for fitted curves ---- + fitters = [] + if len(low_df) > 0: + fitters.append(kmf_low) + if len(high_df) > 0: + fitters.append(kmf_high) + + if len(fitters) > 0: + add_at_risk_counts(*fitters, ax=ax) + + # ---- log-rank only if both groups exist ---- + if len(low_df) > 0 and len(high_df) > 0: + res = logrank_test( + low_df[time_label], + high_df[time_label], + event_observed_A=low_df[status_label], + event_observed_B=high_df[status_label], + ) + logrank_p = float(res.p_value) + else: + logrank_p = float("nan") c_used, used, *_ = _cindex(time, event, risk) ax.text( diff --git a/uv.lock b/uv.lock index c4015d9..96b4b73 100644 --- a/uv.lock +++ b/uv.lock @@ -3699,13 +3699,14 @@ wheels = [ [[package]] name = "stamp" -version = "2.3.0" +version = "2.4.0" source = { editable = "." } dependencies = [ { name = "beartype" }, { name = "einops" }, { name = "h5py" }, { name = "jaxtyping" }, + { name = "lifelines" }, { name = "lightning" }, { name = "matplotlib" }, { name = "numpy" }, @@ -3807,7 +3808,6 @@ gigapath = [ { name = "fvcore" }, { name = "gigapath" }, { name = "iopath" }, - { name = "lifelines" }, { name = "monai" }, { name = "scikit-image" }, { name = "scikit-survival" }, @@ -3828,7 +3828,6 @@ gpu = [ { name = "huggingface-hub" }, { name = "iopath" }, { name = "jinja2" }, - { name = "lifelines" }, { name = "madeleine" }, { name = "mamba-ssm" }, { name = "monai" }, @@ -3920,7 +3919,7 @@ requires-dist = [ { name = "iopath", marker = "extra == 'gigapath'" }, { name = "jaxtyping", specifier = ">=0.3.2" }, { name = "jinja2", marker = "extra == 'cobra'", specifier = ">=3.1.4" }, - { name = "lifelines", marker = "extra == 'gigapath'" }, + { name = "lifelines", specifier = ">=0.28.0" }, { name = "lightning", specifier = ">=2.5.2" }, { name = "madeleine", marker = "extra == 'madeleine'", git = "https://github.com/mahmoodlab/MADELEINE.git?rev=de7c85acc2bdad352e6df8eee5694f8b6f288012" }, { name = "mamba-ssm", marker = "extra == 'cobra'", specifier = ">=2.2.6.post3" }, @@ -4747,4 +4746,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/24/fd/725b8e73ac2a50e78a4534ac43c6addf5c1c2d65380dd48a9169cc6739a9/yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d", size = 86591, upload-time = "2025-06-10T00:45:25.793Z" }, { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, -] \ No newline at end of file +] From f6a4279cc3775db01431c19b37495413a8931963 Mon Sep 17 00:00:00 2001 From: mducducd Date: Wed, 11 Feb 2026 08:42:48 +0000 Subject: [PATCH 2/3] update comments --- src/stamp/statistics/survival.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 4637102..d8f5ebf 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -136,7 +136,7 @@ def _plot_km( ) kmf_high.plot_survival_function(ax=ax, ci_show=False, color="red") - # ---- add at-risk counts only for fitted curves ---- + # add at-risk counts only for fitted curves fitters = [] if len(low_df) > 0: fitters.append(kmf_low) @@ -146,7 +146,7 @@ def _plot_km( if len(fitters) > 0: add_at_risk_counts(*fitters, ax=ax) - # ---- log-rank only if both groups exist ---- + # log-rank only if both groups exist if len(low_df) > 0 and len(high_df) > 0: res = logrank_test( low_df[time_label], From 2aa37a5fbdbb4d4ad61ce0e46eeab3d2ced2e431 Mon Sep 17 00:00:00 2001 From: mducducd Date: Wed, 11 Feb 2026 08:45:54 +0000 Subject: [PATCH 3/3] update comments --- src/stamp/modeling/data.py | 14 +++----------- src/stamp/statistics/survival.py | 12 ++++++------ 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 2f121f2..3eb53b3 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -341,9 +341,9 @@ def load_patient_level_data( clini_table: Path, feature_dir: Path, patient_label: PandasLabel, - ground_truth_label: PandasLabel | None = None, # <- now optional - time_label: PandasLabel | None = None, # <- for survival - status_label: PandasLabel | None = None, # <- for survival + ground_truth_label: PandasLabel | None = None, + time_label: PandasLabel | None = None, + status_label: PandasLabel | None = None, feature_ext: str = ".h5", ) -> dict[PatientId, PatientData]: """ @@ -902,15 +902,7 @@ def _parse_survival_status(value) -> int | None: None, NaN, '' -> None """ - # Handle missing inputs gracefully - # if value is None: - # return 0 # treat empty/missing as censored - # if isinstance(value, float) and math.isnan(value): - # return 0 # treat empty/missing as censored - s = str(value).strip().lower() - # if s in {"", "nan", "none"}: - # return 0 # treat empty/missing as censored # Known mappings positives = {"1", "event", "dead", "deceased", "yes", "y", "True", "true"} diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index d8f5ebf..78fb51c 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -24,7 +24,7 @@ def _comparable_pairs_count(times: np.ndarray, events: np.ndarray) -> int: def _cindex( time: np.ndarray, event: np.ndarray, - risk: np.ndarray, # will be flipped in function + risk: np.ndarray, ) -> tuple[float, int]: """Compute C-index using Lifelines convention: higher risk → shorter survival (worse outcome). @@ -40,13 +40,13 @@ def _survival_stats_for_csv( time_label: str, status_label: str, risk_label: str | None = None, - cut_off: float | None = None, # will be flipped in function + cut_off: float | None = None, ) -> pd.Series: """Compute C-index and log-rank p for one CSV.""" if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid events before computing stats --- + # Clean NaNs and invalid events before computing stats df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] if len(df) == 0: @@ -56,10 +56,10 @@ def _survival_stats_for_csv( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- Concordance index --- + # Concordance index c_index, n_pairs = _cindex(time, event, risk) - # --- Log-rank test (median split) --- + # Log-rank test (median split) median_risk = float(cut_off) if cut_off is not None else float(np.nanmedian(risk)) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -113,7 +113,7 @@ def _plot_km( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- split groups --- + # split groups median_risk = float(cut_off) if cut_off is not None else np.nanmedian(risk) low_mask = risk <= median_risk high_mask = risk > median_risk