Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/stamp/encoding/encoder/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ def _align_vir2_to_ctp_by_coords(
decimals: int = 5,
) -> tuple[torch.Tensor, np.ndarray]:
"""Align vir2 features to ctp features based on coordinates."""
# round coordinates to avoid floating-point precision mismatches
ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals)
oth = np.round(np.asarray(other_coords_um, dtype=np.float64), decimals)

# coord -> queue(indices)
# build mapping: coordinate -> queue of indices
# using deque ensures stable matching when duplicates exist
buckets = defaultdict(deque)
for j, key in enumerate(map(tuple, oth)):
buckets[key].append(j)
Expand Down
14 changes: 3 additions & 11 deletions src/stamp/modeling/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,9 @@ def load_patient_level_data(
clini_table: Path,
feature_dir: Path,
patient_label: PandasLabel,
ground_truth_label: PandasLabel | None = None, # <- now optional
time_label: PandasLabel | None = None, # <- for survival
status_label: PandasLabel | None = None, # <- for survival
ground_truth_label: PandasLabel | None = None,
time_label: PandasLabel | None = None,
status_label: PandasLabel | None = None,
feature_ext: str = ".h5",
) -> dict[PatientId, PatientData]:
"""
Expand Down Expand Up @@ -902,15 +902,7 @@ def _parse_survival_status(value) -> int | None:
None, NaN, '' -> None
"""

# Handle missing inputs gracefully
# if value is None:
# return 0 # treat empty/missing as censored
# if isinstance(value, float) and math.isnan(value):
# return 0 # treat empty/missing as censored

s = str(value).strip().lower()
# if s in {"", "nan", "none"}:
# return 0 # treat empty/missing as censored

# Known mappings
positives = {"1", "event", "dead", "deceased", "yes", "y", "True", "true"}
Expand Down
8 changes: 6 additions & 2 deletions src/stamp/preprocessing/extractor/ticon.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
This file contains code adapted from:
TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning
https://github.com/cvlab-stonybrook/TICON
"""

import math
from collections.abc import Callable, Mapping
from functools import partial
Expand All @@ -8,8 +14,6 @@
import torch.nn as nn
from huggingface_hub import hf_hub_download
from jaxtyping import Float

# from omegaconf import OmegaConf
from torch import Tensor
from torchvision import transforms

Expand Down
41 changes: 26 additions & 15 deletions src/stamp/statistics/survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _comparable_pairs_count(times: np.ndarray, events: np.ndarray) -> int:
def _cindex(
time: np.ndarray,
event: np.ndarray,
risk: np.ndarray, # will be flipped in function
risk: np.ndarray,
) -> tuple[float, int]:
"""Compute C-index using Lifelines convention:
higher risk → shorter survival (worse outcome).
Expand All @@ -40,13 +40,13 @@ def _survival_stats_for_csv(
time_label: str,
status_label: str,
risk_label: str | None = None,
cut_off: float | None = None, # will be flipped in function
cut_off: float | None = None,
) -> pd.Series:
"""Compute C-index and log-rank p for one CSV."""
if risk_label is None:
risk_label = "pred_score"

# --- Clean NaNs and invalid events before computing stats ---
# Clean NaNs and invalid events before computing stats
df = df.dropna(subset=[time_label, status_label, risk_label]).copy()
df = df[df[status_label].isin([0, 1])]
if len(df) == 0:
Expand All @@ -56,10 +56,10 @@ def _survival_stats_for_csv(
event = np.asarray(df[status_label], dtype=int)
risk = np.asarray(df[risk_label], dtype=float)

# --- Concordance index ---
# Concordance index
c_index, n_pairs = _cindex(time, event, risk)

# --- Log-rank test (median split) ---
# Log-rank test (median split)
median_risk = float(cut_off) if cut_off is not None else float(np.nanmedian(risk))
low_mask = risk <= median_risk
high_mask = risk > median_risk
Expand Down Expand Up @@ -113,7 +113,7 @@ def _plot_km(
event = np.asarray(df[status_label], dtype=int)
risk = np.asarray(df[risk_label], dtype=float)

# --- split groups ---
# split groups
median_risk = float(cut_off) if cut_off is not None else np.nanmedian(risk)
low_mask = risk <= median_risk
high_mask = risk > median_risk
Expand All @@ -136,16 +136,27 @@ def _plot_km(
)
kmf_high.plot_survival_function(ax=ax, ci_show=False, color="red")

add_at_risk_counts(kmf_low, kmf_high, ax=ax)
# add at-risk counts only for fitted curves
fitters = []
if len(low_df) > 0:
fitters.append(kmf_low)
if len(high_df) > 0:
fitters.append(kmf_high)

# --- log-rank and c-index ---
res = logrank_test(
low_df[time_label],
high_df[time_label],
event_observed_A=low_df[status_label],
event_observed_B=high_df[status_label],
)
logrank_p = float(res.p_value)
if len(fitters) > 0:
add_at_risk_counts(*fitters, ax=ax)

# log-rank only if both groups exist
if len(low_df) > 0 and len(high_df) > 0:
res = logrank_test(
low_df[time_label],
high_df[time_label],
event_observed_A=low_df[status_label],
event_observed_B=high_df[status_label],
)
logrank_p = float(res.p_value)
else:
logrank_p = float("nan")
c_used, used, *_ = _cindex(time, event, risk)

ax.text(
Expand Down
9 changes: 4 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.