From 0e802a556904488f9724e1dddb1b07d4abf17d5f Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 8 Sep 2025 10:13:20 +0100 Subject: [PATCH 01/19] add table analysis tools --- mcp/server.py | 230 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 223 insertions(+), 7 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index a874e871..1039d3b4 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -11,6 +11,7 @@ import yaml from fastmcp import Context, FastMCP from pydantic import Field +import pandas as pd # Initialize the FastMCP server mcp = FastMCP("STAMP MCP Server") @@ -720,31 +721,246 @@ def read_file(path: str) -> str: @mcp.tool -def list_files(subdir: str = "") -> list: +def list_files(subdir: str = "") -> str: """ List all files and directories under the given subdirectory (default is root), recursively, - returning paths relative to the base directory. + returning paths relative to the base directory. If the list is too long, shows only directories + with file type summaries. If still too long, shows a truncated message. Args: subdir (str): Relative subdirectory path to list files from. Returns: - list: List of relative file paths found. + str: Formatted list of files/directories or summary information. """ + max_items = 50 safe = _resolve_path(subdir) if not safe.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") - results = [] + + # Collect all files and directories + all_items = [] + directories = {} base_len = len(str(base)) + 1 # To slice off base path + separator + for root, dirs, files in os.walk(safe): rel_root = str(root)[base_len:] # relative path under base_dir + + # Track file types in each directory + if rel_root not in directories: + directories[rel_root] = {"subdirs": [], "file_types": {}, "file_count": 0} + + # Add subdirectories for d in dirs: path = os.path.join(rel_root, d) - results.append(path + "/") + all_items.append(path + "/") + directories[rel_root]["subdirs"].append(d) + + # Add files and track their extensions for f in files: path = os.path.join(rel_root, f) - results.append(path) - return sorted(results) + all_items.append(path) + + # Track file extension + ext = Path(f).suffix.lower() or "no extension" + directories[rel_root]["file_types"][ext] = directories[rel_root]["file_types"].get(ext, 0) + 1 + directories[rel_root]["file_count"] += 1 + + # If the list is manageable, return the full list + if len(all_items) <= max_items: + return "\n".join(sorted(all_items)) + + # Try directory summary instead + dir_summary = [] + for dir_path, info in sorted(directories.items()): + if not dir_path: # Root directory + dir_display = f"/ (root)" + else: + dir_display = f"{dir_path}/" + + # File type summary + if info["file_count"] > 0: + file_types = [] + for ext, count in sorted(info["file_types"].items()): + file_types.append(f"{count} {ext}") + file_summary = f" [{', '.join(file_types)}]" + else: + file_summary = " [empty]" + + # Subdirectory info + if info["subdirs"]: + subdir_info = f" (contains {len(info['subdirs'])} subdirs)" + else: + subdir_info = "" + + dir_summary.append(f"{dir_display}{file_summary}{subdir_info}") + + # If directory summary is still too long, truncate + if len(dir_summary) > max_items: + total_dirs = len(directories) + total_files = sum(info["file_count"] for info in directories.values()) + + # Show first few directories and a summary + shown_dirs = dir_summary[:max_items//2] + summary_text = ( + f"\n... 
(showing first {len(shown_dirs)} of {total_dirs} directories)\n\n" + f"SUMMARY:\n" + f"- Total directories: {total_dirs}\n" + f"- Total files: {total_files}\n" + f"- Directory '{subdir or '/'}' contains too many items to display completely.\n" + f"- Use a more specific subdirectory path to see detailed listings." + ) + + # Get overall file type statistics + all_extensions = {} + for info in directories.values(): + for ext, count in info["file_types"].items(): + all_extensions[ext] = all_extensions.get(ext, 0) + count + + if all_extensions: + ext_summary = [] + for ext, count in sorted(all_extensions.items(), key=lambda x: x[1], reverse=True)[:10]: + ext_summary.append(f" {ext}: {count} files") + summary_text += f"\n\nTop file types:\n" + "\n".join(ext_summary) + if len(all_extensions) > 10: + summary_text += f"\n ... and {len(all_extensions) - 10} more file types" + + return "\n".join(shown_dirs) + summary_text + + # Return directory summary + header = f"Directory listing for '{subdir or '/'}' (showing directories with file type summaries):\n" + return header + "\n".join(dir_summary) + + +@mcp.tool +def analyze_csv(path: str) -> str: + """ + Analyze a CSV file and provide detailed information about its structure and contents. + + Args: + path (str): Relative path to the CSV file. + + Returns: + str: Detailed information about the CSV including dimensions, columns, and sample data. + """ + safe_path = _resolve_path(path) + + if not safe_path.exists(): + raise FileNotFoundError(f"CSV file does not exist: {path}") + + if not safe_path.suffix.lower() in ['.csv', '.tsv']: + raise ValueError(f"File is not a CSV file: {path}") + + try: + # Read the CSV file + df = pd.read_csv(safe_path) + + # Get basic information + num_rows, num_columns = df.shape + column_names = df.columns.tolist() + + # Get first 3 rows as examples + sample_rows = df.head(3).to_string(index=True, max_cols=None) + + # Format the output + result = f"""CSV File Analysis: {path} + +Dimensions: +- Number of rows: {num_rows:,} +- Number of columns: {num_columns} + +Column Names: +{', '.join([f'"{col}"' for col in column_names])} + +First 3 rows (sample data): +{sample_rows} + +Data Types: +{df.dtypes.to_string()} + """ + + return result.strip() + + except pd.errors.EmptyDataError: + return f"CSV file is empty: {path}" + except pd.errors.ParserError as e: + return f"Error parsing CSV file {path}: {str(e)}" + except Exception as e: + return f"Error analyzing CSV file {path}: {str(e)}" + + +@mcp.tool +def list_column_values(path: str, column_name: str) -> str: + """ + List all unique values in a specific column of a CSV file. + + Args: + path (str): Relative path to the CSV file. + column_name (str): Name of the column to analyze. + + Returns: + str: Information about the unique values in the specified column. 
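+
+    Example (illustrative only; the CSV path and column name below are hypothetical):
+        >>> list_column_values(path="tables/clinical.csv", column_name="OUTCOME")
+        "Column Analysis for 'OUTCOME' in tables/clinical.csv ..."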
+ """ + safe_path = _resolve_path(path) + + if not safe_path.exists(): + raise FileNotFoundError(f"CSV file does not exist: {path}") + + if not safe_path.suffix.lower() in ['.csv', '.tsv']: + raise ValueError(f"File is not a CSV file: {path}") + + try: + # Read the CSV file + df = pd.read_csv(safe_path) + + # Check if column exists + if column_name not in df.columns: + available_columns = ', '.join([f'"{col}"' for col in df.columns]) + return f"Column '{column_name}' not found in CSV file: {path}\nAvailable columns: {available_columns}" + + # Get unique values + unique_values = df[column_name].unique() + + # Count occurrences of each value + value_counts = df[column_name].value_counts().sort_index() + + # Handle missing values + null_count = df[column_name].isnull().sum() + + # Format the output + result = f"""Column Analysis for '{column_name}' in {path} + +Total rows: {len(df):,} +Unique values: {len(unique_values):,} +Missing/null values: {null_count:,} + +Value distribution: +{value_counts.to_string()} + """ + + # If there are many unique values, show a sample + if len(unique_values) > 20: + result += f""" + +First 20 unique values: +{', '.join([str(val) for val in unique_values[:20]])} +... and {len(unique_values) - 20} more values + """ + else: + result += f""" + +All unique values: +{', '.join([str(val) for val in unique_values if pd.notna(val)])} + """ + + return result.strip() + + except pd.errors.EmptyDataError: + return f"CSV file is empty: {path}" + except pd.errors.ParserError as e: + return f"Error parsing CSV file {path}: {str(e)}" + except Exception as e: + return f"Error analyzing column in CSV file {path}: {str(e)}" @mcp.tool From b614399703db0967e895b65b46a84a3e14b00515 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 8 Sep 2025 14:27:52 +0100 Subject: [PATCH 02/19] fix crossval and train config --- mcp/server.py | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 1039d3b4..f7980e44 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -52,7 +52,7 @@ async def _run_stamp(mode, config, ctx): tmp_config_path = tmp_config.name handler = MCPLogHandler(ctx) - handler.setLevel(logging.DEBUG) + handler.setLevel(logging.INFO) STAMP_LOGGER.addHandler(handler) print("Running command...") @@ -207,22 +207,11 @@ async def train_stamp( "in the slide table containing the feature file path relative to `feature_dir`" ), ] = "FILENAME", - bag_size: Annotated[ - int, - Field( - description="Amount of tiles to sample when training. " - "Reducing this value reduces memory usage, but it is not recommended as the model can miss" - "relevant regions of the slide. Default value works well on H&E tissue images." - ), - ] = 512, - batch_size: Annotated[ - int, Field(description="Amount of bags processed together.") - ] = 64, ) -> str: """ Train a model using clinical data and WSI-derived features via STAMP. Takes in a clinical table, slide associations, and extracted features - to train a model on a specified label. + to train a model on a specified label. Best option when an external cohort is available. 
Returns: str: message indicating the success or failure of the training operation, @@ -251,8 +240,6 @@ async def train_stamp( "categories": categories, "patient_label": patient_label, "filename_label": filename_label, - "bag_size": bag_size, - "batch_size": batch_size, } } return await _run_stamp(mode="train", config=config, ctx=ctx) @@ -307,22 +294,12 @@ async def crossval_stamp( description="Number of folds to split the data into for cross-validation" ), ] = 5, - bag_size: Annotated[ - int, - Field( - description="Amount of tiles to sample when training. " - "Reducing this value reduces memory usage, but it is not recommended as the model can miss" - "relevant regions of the slide. Default value works well on H&E tissue images." - ), - ] = 512, - batch_size: Annotated[ - int, Field(description="Amount of bags processed together.") - ] = 64, ) -> str: """ Perform cross-validation for model training using STAMP. Splits the data into folds and trains a model on each to assess generalization. Uses clinical data, features, and slide mappings. + Best option when only one cohort is available. Returns: str: A message indicating the success or failure of the cross-validation operation, along with @@ -354,10 +331,6 @@ async def crossval_stamp( "filename_label": filename_label, "n_splits": n_splits, }, - "advanced_config": { # Add advanced config for bag_size and batch_size - "bag_size": bag_size, - "batch_size": batch_size, - }, } return await _run_stamp(mode="crossval", config=config, ctx=ctx) From 0faff53cf31e47f73e963173601f7ee6ba9e514d Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 8 Sep 2025 14:28:22 +0100 Subject: [PATCH 03/19] make log debug to avoid flooding --- src/stamp/modeling/data.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 33a8c4c7..b803b927 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -35,7 +35,6 @@ ) _logger = logging.getLogger("stamp") -_logged_stamp_v1_warning = False __author__ = "Marko van Treeck" @@ -370,13 +369,9 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: == 224 ): # Historic STAMP format - # TODO: find a better way to get this warning just once - global _logged_stamp_v1_warning - if not _logged_stamp_v1_warning: - _logger.info( - f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" - ) - _logged_stamp_v1_warning = True + _logger.debug( + f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" + ) tile_size_um = Microns(256.0) tile_size_px = TilePixels(224) coords_um = coords / 224 * 256 From 2682cf90cd2dbc63280992795f6a6f72123a58ea Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Tue, 9 Sep 2025 12:18:34 +0100 Subject: [PATCH 04/19] run stamp in same process --- mcp/server.py | 38 +++++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index f7980e44..a4e59cd9 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -6,12 +6,15 @@ import subprocess import tempfile from typing import Annotated +import argparse import torch import yaml from fastmcp import Context, FastMCP from pydantic import Field import pandas as pd +from stamp.__main__ import _run_cli + # Initialize the FastMCP server mcp = FastMCP("STAMP MCP Server") @@ -27,16 +30,19 @@ class MCPLogHandler(logging.Handler): def __init__(self, ctx): super().__init__() self.ctx = 
ctx + self.captured_logs = [] # Store captured logs def emit(self, record): msg = self.format(record) + # Store the log message + self.captured_logs.append(msg) # Fire-and-forget the coroutine asyncio.create_task(self.ctx.log(msg)) async def _run_stamp(mode, config, ctx): """ - Run the STAMP command as a subprocess and capture its console output. + Run the STAMP command directly by calling _run_cli() instead of subprocess. Args: mode (str): The mode to run the STAMP command in (e.g., "preprocess", "train"). @@ -51,20 +57,30 @@ async def _run_stamp(mode, config, ctx): yaml.dump(config, tmp_config) tmp_config_path = tmp_config.name + # Set up logging handler to capture STAMP logs handler = MCPLogHandler(ctx) handler.setLevel(logging.INFO) STAMP_LOGGER.addHandler(handler) - print("Running command...") - try: - cmd = ["stamp", "--config", tmp_config_path, mode] - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - print("Result returned...") - print(f"Command completed successfully:\n{result.stdout}\n{result.stderr}") - return f"Command completed successfully:\n{result.stdout}\n{result.stderr}" - except subprocess.CalledProcessError as e: - return f"Command failed with error:\n{e.stdout}\n{e.stderr}" + # Create argparse Namespace object to mimic command line arguments + args = argparse.Namespace( + command=mode, + config_file_path=Path(tmp_config_path) + ) + + # Call the STAMP CLI function directly + _run_cli(args) + + # Get captured logs + captured_logs_text = "\n".join(handler.captured_logs) if handler.captured_logs else "Command completed successfully (no logs captured)" + return f"Command completed successfully:\n{captured_logs_text}" + + except Exception as e: + captured_logs_text = "\n".join(handler.captured_logs) if handler.captured_logs else "" + error_msg = f"Command failed with error: {str(e)}\n{captured_logs_text}" + return error_msg + finally: os.remove(tmp_config_path) STAMP_LOGGER.removeHandler(handler) @@ -460,7 +476,7 @@ async def statistics_stamp( output_dir="output/statistics", ground_truth_label="OUTCOME", true_class="Positive", - pred_csvs=["predictions/fold1.csv", "predictions/fold2.csv"] + pred_csvs=["/pathto/split-0/patient-preds.csv", "/pathto/split-1/patient-preds.csv"] ) "Command completed successfully: ..." """ From 3d5e6977dc7e9f7f953fce18ace606dbece500e8 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 11 Sep 2025 10:42:57 +0100 Subject: [PATCH 05/19] update heatmaps config validations --- src/stamp/heatmaps/config.py | 37 ++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/stamp/heatmaps/config.py b/src/stamp/heatmaps/config.py index 98bc1744..09ff709e 100644 --- a/src/stamp/heatmaps/config.py +++ b/src/stamp/heatmaps/config.py @@ -9,15 +9,21 @@ class HeatmapConfig(BaseModel): model_config = ConfigDict(extra="forbid") - output_dir: Path + output_dir: Path = Field(description="Directory to save heatmap outputs") - feature_dir: Path - wsi_dir: Path - checkpoint_path: Path + feature_dir: Path = Field(description="Directory containing extracted features") + wsi_dir: Path = Field(description="Directory containing whole slide images") + checkpoint_path: Path = Field(description="Path to model checkpoint file") - slide_paths: list[Path] | None = None + slide_paths: list[Path] | None = Field( + default=None, + description="Specific slide paths to process. 
If None, processes all slides in wsi_dir" + ) - device: str = "cuda" if torch.cuda.is_available() else "cpu" + device: str = Field( + default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu", + description="Device to use for computation" + ) opacity: float = Field( default=0.6, @@ -26,8 +32,19 @@ class HeatmapConfig(BaseModel): le=1, ) - topk: int = 0 - bottomk: int = 0 + topk: int = Field( + default=0, + description="Number of top patches to highlight. 0 means no highlighting.", + ge=0 + ) + + bottomk: int = Field( + default=0, + description="Number of bottom patches to highlight. 0 means no highlighting.", + ge=0 + ) - default_slide_mpp: SlideMPP | None = None - """MPP of the slide to use if none can be inferred from the WSI""" + default_slide_mpp: SlideMPP | None = Field( + default=None, + description="MPP of the slide to use if none can be inferred from the WSI" + ) From caf66876031dd8c60649563c49dd73031b3e4d2d Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 11 Sep 2025 10:43:35 +0100 Subject: [PATCH 06/19] improve list files and heatmaps tools --- mcp/server.py | 57 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index a4e59cd9..24f0710b 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -22,8 +22,9 @@ STAMP_LOGGER = logging.getLogger("stamp") # TODO: add proper filesystem management -base_dir = "./" -base = Path(base_dir).resolve() +WORKSPACE_FOLDER = "./" # Folder where the agent can work on. +WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve() +LIST_OUTSIDE = True # Let the agent list files from folders outside the working directory class MCPLogHandler(logging.Handler): @@ -507,24 +508,40 @@ async def heatmaps_stamp( str, Field(description="Path of the model to generate the heatmaps with.") ], slide_paths: Annotated[ - list[str] | None, + list[str], Field( - description="List of slide paths relative " - "to `wsi_dir` to generate heatmaps for. If not specified, heatmaps will be generated " - "for all slides in `wsi_dir`." + description="List of slide paths relative to `wsi_dir` to " \ + "generate heatmaps for. The slide paths HAVE to be specified relative to `wsi_dir`.", + min_length=1, ), - ] = None, + ], topk: Annotated[ int | None, Field(description="Number of top-scoring tiles to extract") ] = None, bottomk: Annotated[ int | None, Field(description="Number of bottom-scoring tiles to extract") ] = None, + device: Annotated[ + str | None, + Field( + description="The device to use for computation. " + "Possible options are 'cuda' for NVIDIA GPUs, 'cpu' for general-purpose " + "processors, and 'mps' for Apple Silicon GPUs. Default is detected automatically" + ), + ] = None, ) -> str: """ Generate heatmaps and tile scorings from WSIs using a trained model. - Produces visual explanations and optionally extracts top/bottom - scoring tiles. + + Creates visual attention maps showing which regions the model focuses on for predictions. + Works only with tile-level features. For each slide, generates: + - Overview plots with complete heatmaps and class overlays + - Raw data including thumbnails, class maps, and per-class heatmaps + - Individual tile extractions (top/bottom scoring if specified) + + Output structure: Each slide gets its own folder + (slide name without file extension)containing plots/, raw/, and tiles/ subdirectories. 
+ Returns: str: A message indicating the success or failure of the heatmap generation operation, @@ -537,8 +554,8 @@ async def heatmaps_stamp( wsi_dir="input/slides", checkpoint_path="models/checkpoint.pth", slide_paths=["slide1.svs", "slide2.svs"], - topk=10, - bottomk=5 + topk=3, + bottomk=3 ) "Command completed successfully: ..." """ @@ -551,6 +568,7 @@ async def heatmaps_stamp( "slide_paths": slide_paths, "topk": topk, "bottomk": bottomk, + "device": device, } } return await _run_stamp(mode="heatmaps", config=config, ctx=ctx) @@ -687,8 +705,8 @@ async def encode_patients_stamp( def _resolve_path(subpath: str) -> Path: - requested = (base / subpath).resolve() - if base not in requested.parents and requested != base: + requested = (WORKSPACE_PATH / subpath).resolve() + if WORKSPACE_PATH not in requested.parents and requested != WORKSPACE_PATH: raise PermissionError(f"Access denied: {subpath}") return requested @@ -722,17 +740,20 @@ def list_files(subdir: str = "") -> str: Returns: str: Formatted list of files/directories or summary information. """ - max_items = 50 - safe = _resolve_path(subdir) - if not safe.is_dir(): + max_items = 100 + if LIST_OUTSIDE: + subdir_path = Path(subdir) + else: + subdir_path = _resolve_path(subdir) + if not subdir_path.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") # Collect all files and directories all_items = [] directories = {} - base_len = len(str(base)) + 1 # To slice off base path + separator + base_len = len(str(WORKSPACE_PATH)) + 1 # To slice off base path + separator - for root, dirs, files in os.walk(safe): + for root, dirs, files in os.walk(subdir_path): rel_root = str(root)[base_len:] # relative path under base_dir # Track file types in each directory From c45efcceb326e777d106f6d029adb58343112ce1 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 11 Sep 2025 10:55:48 +0100 Subject: [PATCH 07/19] make max_items a global variable --- mcp/server.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 24f0710b..d24d9464 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -25,6 +25,8 @@ WORKSPACE_FOLDER = "./" # Folder where the agent can work on. WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve() LIST_OUTSIDE = True # Let the agent list files from folders outside the working directory +MAX_ITEMS = 100 # Max amount of files listed with list_files tool. +# Big values could exceed LLM's context length. When it exceeds, values are summarized. class MCPLogHandler(logging.Handler): @@ -740,7 +742,6 @@ def list_files(subdir: str = "") -> str: Returns: str: Formatted list of files/directories or summary information. """ - max_items = 100 if LIST_OUTSIDE: subdir_path = Path(subdir) else: @@ -777,7 +778,7 @@ def list_files(subdir: str = "") -> str: directories[rel_root]["file_count"] += 1 # If the list is manageable, return the full list - if len(all_items) <= max_items: + if len(all_items) <= MAX_ITEMS: return "\n".join(sorted(all_items)) # Try directory summary instead @@ -806,12 +807,12 @@ def list_files(subdir: str = "") -> str: dir_summary.append(f"{dir_display}{file_summary}{subdir_info}") # If directory summary is still too long, truncate - if len(dir_summary) > max_items: + if len(dir_summary) > MAX_ITEMS: total_dirs = len(directories) total_files = sum(info["file_count"] for info in directories.values()) # Show first few directories and a summary - shown_dirs = dir_summary[:max_items//2] + shown_dirs = dir_summary[:MAX_ITEMS//2] summary_text = ( f"\n... 
(showing first {len(shown_dirs)} of {total_dirs} directories)\n\n" f"SUMMARY:\n" From ca2723dd48457b5e96985ae6e3044e4906ee9787 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 11 Sep 2025 11:09:21 +0100 Subject: [PATCH 08/19] update README.md --- mcp/README.md | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/mcp/README.md b/mcp/README.md index 7821a368..0f62c1b8 100644 --- a/mcp/README.md +++ b/mcp/README.md @@ -1,23 +1,24 @@ # STAMP MCP Server -A FastMCP-based Model Context Protocol server wrapping [STAMP](https://github.com/KatherLab/STAMP)’s CLI, enabling seamless integration of STAMP preprocessing, training, encoding, evaluation, and inference into LLM-based pipelines. +A FastMCP-based Model Context Protocol server wrapping [STAMP](https://github.com/KatherLab/STAMP)'s tools, enabling seamless integration of STAMP preprocessing, training, encoding, evaluation, and inference into LLM-based pipelines. ## Overview This server lets LLM agents invoke STAMP tools via structured calls. It exposes the following tools: -- `preprocess_stamp(...)`: tile & extract WSI features -- `train_stamp(...)`: train weakly supervised models -- `crossval_stamp(...)`: k-fold cross‑validation -- `deploy_stamp(...)`: inference on held‑out data -- `encode_slides_stamp(...)`: slide-level feature encoding -- `encode_patients_stamp(...)`: patient-level feature encoding -- `heatmaps_stamp(...)`: model-based heatmap visualization -- `statistics_stamp(...)`: compute classification metrics -- `read_file(...)` & `list_files(...)`: safe disk access -- `check_available_devices()`: query Torch/Platform device availability - -Each tool serializes config into YAML, launches `stamp `, streams logs back, and returns stdout/stderr. +- `preprocess_stamp()`: tile & extract WSI features +- `train_stamp()`: train weakly supervised models +- `crossval_stamp()`: k-fold cross‑validation +- `deploy_stamp()`: inference on held‑out data +- `encode_slides_stamp()`: slide-level feature encoding +- `encode_patients_stamp()`: patient-level feature encoding +- `heatmaps_stamp()`: model-based heatmap visualization +- `statistics_stamp()`: compute classification metrics +- `read_file()` & `list_files()`: safe disk access +- `check_available_devices()`: query Torch/Platform device availability +- `analyze_csv()` & `list_column_values`: useful for clinical and slide tables + +Each tool serializes config into YAML and directly calls STAMP's internal `_run_cli()` function, streaming logs back in real-time and returning execution results. ## Installation To run the MCP server is as simple as intalling STAMP as it is explained in the main README.md file, but adding `--extra mcp` to the command. For a GPU repository installation it would be like this: From 637f3b97777016c9a49f66cf78fc408e2d942512 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 11 Sep 2025 12:21:24 +0100 Subject: [PATCH 09/19] reformat --- mcp/server.py | 156 +++++++++++++++++++---------------- src/stamp/heatmaps/config.py | 10 +-- 2 files changed, 88 insertions(+), 78 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index d24d9464..03d913ef 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -3,7 +3,6 @@ import os from pathlib import Path import platform -import subprocess import tempfile from typing import Annotated import argparse @@ -22,10 +21,12 @@ STAMP_LOGGER = logging.getLogger("stamp") # TODO: add proper filesystem management -WORKSPACE_FOLDER = "./" # Folder where the agent can work on. 
+WORKSPACE_FOLDER = "./" # Folder where the agent can work on. WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve() -LIST_OUTSIDE = True # Let the agent list files from folders outside the working directory -MAX_ITEMS = 100 # Max amount of files listed with list_files tool. +LIST_OUTSIDE = ( + True # Let the agent list files from folders outside the working directory +) +MAX_ITEMS = 100 # Max amount of files listed with list_files tool. # Big values could exceed LLM's context length. When it exceeds, values are summarized. @@ -67,23 +68,26 @@ async def _run_stamp(mode, config, ctx): try: # Create argparse Namespace object to mimic command line arguments - args = argparse.Namespace( - command=mode, - config_file_path=Path(tmp_config_path) - ) - + args = argparse.Namespace(command=mode, config_file_path=Path(tmp_config_path)) + # Call the STAMP CLI function directly _run_cli(args) - + # Get captured logs - captured_logs_text = "\n".join(handler.captured_logs) if handler.captured_logs else "Command completed successfully (no logs captured)" + captured_logs_text = ( + "\n".join(handler.captured_logs) + if handler.captured_logs + else "Command completed successfully (no logs captured)" + ) return f"Command completed successfully:\n{captured_logs_text}" - + except Exception as e: - captured_logs_text = "\n".join(handler.captured_logs) if handler.captured_logs else "" + captured_logs_text = ( + "\n".join(handler.captured_logs) if handler.captured_logs else "" + ) error_msg = f"Command failed with error: {str(e)}\n{captured_logs_text}" return error_msg - + finally: os.remove(tmp_config_path) STAMP_LOGGER.removeHandler(handler) @@ -512,7 +516,7 @@ async def heatmaps_stamp( slide_paths: Annotated[ list[str], Field( - description="List of slide paths relative to `wsi_dir` to " \ + description="List of slide paths relative to `wsi_dir` to " "generate heatmaps for. The slide paths HAVE to be specified relative to `wsi_dir`.", min_length=1, ), @@ -534,16 +538,16 @@ async def heatmaps_stamp( ) -> str: """ Generate heatmaps and tile scorings from WSIs using a trained model. - + Creates visual attention maps showing which regions the model focuses on for predictions. Works only with tile-level features. For each slide, generates: - Overview plots with complete heatmaps and class overlays - - Raw data including thumbnails, class maps, and per-class heatmaps + - Raw data including thumbnails, class maps, and per-class heatmaps - Individual tile extractions (top/bottom scoring if specified) - + Output structure: Each slide gets its own folder (slide name without file extension)containing plots/, raw/, and tiles/ subdirectories. 
- + Returns: str: A message indicating the success or failure of the heatmap generation operation, @@ -744,51 +748,53 @@ def list_files(subdir: str = "") -> str: """ if LIST_OUTSIDE: subdir_path = Path(subdir) - else: + else: subdir_path = _resolve_path(subdir) if not subdir_path.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") - + # Collect all files and directories all_items = [] directories = {} base_len = len(str(WORKSPACE_PATH)) + 1 # To slice off base path + separator - + for root, dirs, files in os.walk(subdir_path): rel_root = str(root)[base_len:] # relative path under base_dir - + # Track file types in each directory if rel_root not in directories: directories[rel_root] = {"subdirs": [], "file_types": {}, "file_count": 0} - + # Add subdirectories for d in dirs: path = os.path.join(rel_root, d) all_items.append(path + "/") directories[rel_root]["subdirs"].append(d) - + # Add files and track their extensions for f in files: path = os.path.join(rel_root, f) all_items.append(path) - + # Track file extension ext = Path(f).suffix.lower() or "no extension" - directories[rel_root]["file_types"][ext] = directories[rel_root]["file_types"].get(ext, 0) + 1 + directories[rel_root]["file_types"][ext] = ( + directories[rel_root]["file_types"].get(ext, 0) + 1 + ) directories[rel_root]["file_count"] += 1 - + # If the list is manageable, return the full list if len(all_items) <= MAX_ITEMS: return "\n".join(sorted(all_items)) - + # Try directory summary instead dir_summary = [] for dir_path, info in sorted(directories.items()): if not dir_path: # Root directory - dir_display = f"/ (root)" + dir_display = "/ (root)" else: dir_display = f"{dir_path}/" - + # File type summary if info["file_count"] > 0: file_types = [] @@ -797,22 +803,22 @@ def list_files(subdir: str = "") -> str: file_summary = f" [{', '.join(file_types)}]" else: file_summary = " [empty]" - + # Subdirectory info if info["subdirs"]: subdir_info = f" (contains {len(info['subdirs'])} subdirs)" else: subdir_info = "" - + dir_summary.append(f"{dir_display}{file_summary}{subdir_info}") - + # If directory summary is still too long, truncate if len(dir_summary) > MAX_ITEMS: total_dirs = len(directories) total_files = sum(info["file_count"] for info in directories.values()) - + # Show first few directories and a summary - shown_dirs = dir_summary[:MAX_ITEMS//2] + shown_dirs = dir_summary[: MAX_ITEMS // 2] summary_text = ( f"\n... (showing first {len(shown_dirs)} of {total_dirs} directories)\n\n" f"SUMMARY:\n" @@ -821,23 +827,27 @@ def list_files(subdir: str = "") -> str: f"- Directory '{subdir or '/'}' contains too many items to display completely.\n" f"- Use a more specific subdirectory path to see detailed listings." ) - + # Get overall file type statistics all_extensions = {} for info in directories.values(): for ext, count in info["file_types"].items(): all_extensions[ext] = all_extensions.get(ext, 0) + count - + if all_extensions: ext_summary = [] - for ext, count in sorted(all_extensions.items(), key=lambda x: x[1], reverse=True)[:10]: + for ext, count in sorted( + all_extensions.items(), key=lambda x: x[1], reverse=True + )[:10]: ext_summary.append(f" {ext}: {count} files") - summary_text += f"\n\nTop file types:\n" + "\n".join(ext_summary) + summary_text += "\n\nTop file types:\n" + "\n".join(ext_summary) if len(all_extensions) > 10: - summary_text += f"\n ... and {len(all_extensions) - 10} more file types" - + summary_text += ( + f"\n ... 
and {len(all_extensions) - 10} more file types" + ) + return "\n".join(shown_dirs) + summary_text - + # Return directory summary header = f"Directory listing for '{subdir or '/'}' (showing directories with file type summaries):\n" return header + "\n".join(dir_summary) @@ -847,32 +857,32 @@ def list_files(subdir: str = "") -> str: def analyze_csv(path: str) -> str: """ Analyze a CSV file and provide detailed information about its structure and contents. - + Args: path (str): Relative path to the CSV file. - + Returns: str: Detailed information about the CSV including dimensions, columns, and sample data. """ safe_path = _resolve_path(path) - + if not safe_path.exists(): raise FileNotFoundError(f"CSV file does not exist: {path}") - - if not safe_path.suffix.lower() in ['.csv', '.tsv']: + + if safe_path.suffix.lower() not in [".csv", ".tsv"]: raise ValueError(f"File is not a CSV file: {path}") - + try: # Read the CSV file df = pd.read_csv(safe_path) - + # Get basic information num_rows, num_columns = df.shape column_names = df.columns.tolist() - + # Get first 3 rows as examples sample_rows = df.head(3).to_string(index=True, max_cols=None) - + # Format the output result = f"""CSV File Analysis: {path} @@ -881,7 +891,7 @@ def analyze_csv(path: str) -> str: - Number of columns: {num_columns} Column Names: -{', '.join([f'"{col}"' for col in column_names])} +{", ".join([f'"{col}"' for col in column_names])} First 3 rows (sample data): {sample_rows} @@ -889,55 +899,55 @@ def analyze_csv(path: str) -> str: Data Types: {df.dtypes.to_string()} """ - + return result.strip() - + except pd.errors.EmptyDataError: return f"CSV file is empty: {path}" except pd.errors.ParserError as e: return f"Error parsing CSV file {path}: {str(e)}" except Exception as e: return f"Error analyzing CSV file {path}: {str(e)}" - + @mcp.tool def list_column_values(path: str, column_name: str) -> str: """ List all unique values in a specific column of a CSV file. - + Args: path (str): Relative path to the CSV file. column_name (str): Name of the column to analyze. - + Returns: str: Information about the unique values in the specified column. """ safe_path = _resolve_path(path) - + if not safe_path.exists(): raise FileNotFoundError(f"CSV file does not exist: {path}") - - if not safe_path.suffix.lower() in ['.csv', '.tsv']: + + if safe_path.suffix.lower() not in [".csv", ".tsv"]: raise ValueError(f"File is not a CSV file: {path}") - + try: # Read the CSV file df = pd.read_csv(safe_path) - + # Check if column exists if column_name not in df.columns: - available_columns = ', '.join([f'"{col}"' for col in df.columns]) + available_columns = ", ".join([f'"{col}"' for col in df.columns]) return f"Column '{column_name}' not found in CSV file: {path}\nAvailable columns: {available_columns}" - + # Get unique values unique_values = df[column_name].unique() - + # Count occurrences of each value value_counts = df[column_name].value_counts().sort_index() - + # Handle missing values null_count = df[column_name].isnull().sum() - + # Format the output result = f"""Column Analysis for '{column_name}' in {path} @@ -948,24 +958,24 @@ def list_column_values(path: str, column_name: str) -> str: Value distribution: {value_counts.to_string()} """ - + # If there are many unique values, show a sample if len(unique_values) > 20: result += f""" First 20 unique values: -{', '.join([str(val) for val in unique_values[:20]])} +{", ".join([str(val) for val in unique_values[:20]])} ... 
and {len(unique_values) - 20} more values """ else: result += f""" All unique values: -{', '.join([str(val) for val in unique_values if pd.notna(val)])} +{", ".join([str(val) for val in unique_values if pd.notna(val)])} """ - + return result.strip() - + except pd.errors.EmptyDataError: return f"CSV file is empty: {path}" except pd.errors.ParserError as e: diff --git a/src/stamp/heatmaps/config.py b/src/stamp/heatmaps/config.py index 09ff709e..1b8b199c 100644 --- a/src/stamp/heatmaps/config.py +++ b/src/stamp/heatmaps/config.py @@ -17,12 +17,12 @@ class HeatmapConfig(BaseModel): slide_paths: list[Path] | None = Field( default=None, - description="Specific slide paths to process. If None, processes all slides in wsi_dir" + description="Specific slide paths to process. If None, processes all slides in wsi_dir", ) device: str = Field( default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu", - description="Device to use for computation" + description="Device to use for computation", ) opacity: float = Field( @@ -35,16 +35,16 @@ class HeatmapConfig(BaseModel): topk: int = Field( default=0, description="Number of top patches to highlight. 0 means no highlighting.", - ge=0 + ge=0, ) bottomk: int = Field( default=0, description="Number of bottom patches to highlight. 0 means no highlighting.", - ge=0 + ge=0, ) default_slide_mpp: SlideMPP | None = Field( default=None, - description="MPP of the slide to use if none can be inferred from the WSI" + description="MPP of the slide to use if none can be inferred from the WSI", ) From 60813282f7337262f985ec15a0658cb447cc2c0f Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Thu, 18 Sep 2025 11:06:10 +0100 Subject: [PATCH 10/19] add external path lists --- mcp/server.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 03d913ef..2ffdb705 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -23,9 +23,12 @@ # TODO: add proper filesystem management WORKSPACE_FOLDER = "./" # Folder where the agent can work on. WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve() -LIST_OUTSIDE = ( - True # Let the agent list files from folders outside the working directory -) +# List of additional allowed paths outside workspace +ALLOWED_EXTERNAL_PATHS = [ + "/mnt/bulk-curie/peter/fmbenchmark/images/tcga_crc", + "/mnt/bulk-curie/peter/fmbenchmark/20mag_experiments/features/tcga_crc/ctranspath/STAMP_raw_xiyuewang-ctranspath-7c998680", + # Add other specific paths you want to allow +] MAX_ITEMS = 100 # Max amount of files listed with list_files tool. # Big values could exceed LLM's context length. When it exceeds, values are summarized. 
@@ -711,10 +714,21 @@ async def encode_patients_stamp( def _resolve_path(subpath: str) -> Path: - requested = (WORKSPACE_PATH / subpath).resolve() - if WORKSPACE_PATH not in requested.parents and requested != WORKSPACE_PATH: - raise PermissionError(f"Access denied: {subpath}") - return requested + requested = Path(subpath).resolve() + + # Check if it's within workspace + if WORKSPACE_PATH in requested.parents or requested == WORKSPACE_PATH: + return requested + + # Check if it's in allowed external paths + for allowed_path in ALLOWED_EXTERNAL_PATHS: + allowed_path = Path(allowed_path).resolve() + # Check both: exact match OR if allowed_path is a parent of requested + if requested == allowed_path or allowed_path in requested.parents: + return requested + + # If not allowed, raise error + raise PermissionError(f"Access denied: {subpath}") @mcp.tool @@ -746,10 +760,7 @@ def list_files(subdir: str = "") -> str: Returns: str: Formatted list of files/directories or summary information. """ - if LIST_OUTSIDE: - subdir_path = Path(subdir) - else: - subdir_path = _resolve_path(subdir) + subdir_path = _resolve_path(subdir) if subdir else WORKSPACE_PATH if not subdir_path.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") From 7f800f37894c9c6304fbbbbb3b4cca1a57fbd03b Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Tue, 23 Sep 2025 16:29:54 +0100 Subject: [PATCH 11/19] add stamp logs after execution --- mcp/server.py | 51 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 2ffdb705..1084b6e3 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -1,5 +1,8 @@ +"""STAMP MCP Server""" + import asyncio import logging +import logging.handlers import os from pathlib import Path import platform @@ -37,14 +40,39 @@ class MCPLogHandler(logging.Handler): def __init__(self, ctx): super().__init__() self.ctx = ctx - self.captured_logs = [] # Store captured logs + self.captured_logs = [] + # Store reference to the event loop where the context is valid + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + # No running loop, we'll handle this in emit() + self.loop = None def emit(self, record): - msg = self.format(record) - # Store the log message - self.captured_logs.append(msg) - # Fire-and-forget the coroutine - asyncio.create_task(self.ctx.log(msg)) + try: + msg = self.format(record) + self.captured_logs.append(msg) + + # Try to schedule the coroutine safely + if self.loop and not self.loop.is_closed(): + # Use call_soon_threadsafe for thread-safe scheduling + self.loop.call_soon_threadsafe( + lambda: asyncio.create_task(self._safe_log(msg)) + ) + # If no loop or loop is closed, just store the message + # The captured_logs will still be available for the return value + + except Exception as e: + # Fallback: just capture the log without async context logging + pass # The message is already in captured_logs + + async def _safe_log(self, msg): + """Safely log to context, handling any exceptions""" + try: + await self.ctx.info(msg) + except Exception: + # If context logging fails, the message is still in captured_logs + pass async def _run_stamp(mode, config, ctx): @@ -70,6 +98,7 @@ async def _run_stamp(mode, config, ctx): STAMP_LOGGER.addHandler(handler) try: + await ctx.info(f"Starting STAMP {mode} command...") # Create argparse Namespace object to mimic command line arguments args = argparse.Namespace(command=mode, config_file_path=Path(tmp_config_path)) @@ -82,6 +111,7 @@ async 
def _run_stamp(mode, config, ctx): if handler.captured_logs else "Command completed successfully (no logs captured)" ) + await ctx.info(f"STAMP {mode} completed successfully") return f"Command completed successfully:\n{captured_logs_text}" except Exception as e: @@ -89,6 +119,7 @@ async def _run_stamp(mode, config, ctx): "\n".join(handler.captured_logs) if handler.captured_logs else "" ) error_msg = f"Command failed with error: {str(e)}\n{captured_logs_text}" + await ctx.error(f"STAMP {mode} failed: {str(e)}") return error_msg finally: @@ -1022,6 +1053,14 @@ def check_available_devices() -> str: return f"Available devices: {', '.join(devices)}" else: return "No computation devices are available." + + +@mcp.tool +async def ping_logs(ctx: Context) -> str: + await ctx.info("ping_logs: starting") + await ctx.warning("ping_logs: still working…") + await ctx.info("ping_logs: done") + return "pong" if __name__ == "__main__": From ac6350e10bd221fcb800f6c6476c952b895e9501 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Tue, 23 Sep 2025 17:29:44 +0100 Subject: [PATCH 12/19] add realtime STAMP logs --- mcp/server.py | 49 +++++++++++++++---------------------------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 1084b6e3..b8d1b61b 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -2,7 +2,6 @@ import asyncio import logging -import logging.handlers import os from pathlib import Path import platform @@ -30,6 +29,7 @@ ALLOWED_EXTERNAL_PATHS = [ "/mnt/bulk-curie/peter/fmbenchmark/images/tcga_crc", "/mnt/bulk-curie/peter/fmbenchmark/20mag_experiments/features/tcga_crc/ctranspath/STAMP_raw_xiyuewang-ctranspath-7c998680", + "/mnt/bulk-sirius/juan/pap_screening/datasets/example/wsi_small" # Add other specific paths you want to allow ] MAX_ITEMS = 100 # Max amount of files listed with list_files tool. @@ -37,42 +37,23 @@ class MCPLogHandler(logging.Handler): - def __init__(self, ctx): + def __init__(self, ctx, loop: asyncio.AbstractEventLoop): super().__init__() self.ctx = ctx - self.captured_logs = [] - # Store reference to the event loop where the context is valid - try: - self.loop = asyncio.get_running_loop() - except RuntimeError: - # No running loop, we'll handle this in emit() - self.loop = None + self.loop = loop + self.captured_logs = [] # FIXME: Implement so the agent can see the logs when finished. Logging is viewed by the user only. 
- def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: + msg = self.format(record) try: - msg = self.format(record) self.captured_logs.append(msg) - - # Try to schedule the coroutine safely - if self.loop and not self.loop.is_closed(): - # Use call_soon_threadsafe for thread-safe scheduling - self.loop.call_soon_threadsafe( - lambda: asyncio.create_task(self._safe_log(msg)) - ) - # If no loop or loop is closed, just store the message - # The captured_logs will still be available for the return value - - except Exception as e: - # Fallback: just capture the log without async context logging - pass # The message is already in captured_logs - - async def _safe_log(self, msg): - """Safely log to context, handling any exceptions""" - try: - await self.ctx.info(msg) + # Thread-safe: schedule on the captured event loop + asyncio.run_coroutine_threadsafe(self.ctx.log(msg), self.loop) + # Alternatively: + # self.loop.call_soon_threadsafe(self.loop.create_task, self.ctx.log(msg)) except Exception: - # If context logging fails, the message is still in captured_logs - pass + self.handleError(record) + async def _run_stamp(mode, config, ctx): @@ -93,8 +74,8 @@ async def _run_stamp(mode, config, ctx): tmp_config_path = tmp_config.name # Set up logging handler to capture STAMP logs - handler = MCPLogHandler(ctx) - handler.setLevel(logging.INFO) + loop = asyncio.get_running_loop() + handler = MCPLogHandler(ctx, loop) STAMP_LOGGER.addHandler(handler) try: @@ -103,7 +84,7 @@ async def _run_stamp(mode, config, ctx): args = argparse.Namespace(command=mode, config_file_path=Path(tmp_config_path)) # Call the STAMP CLI function directly - _run_cli(args) + await asyncio.to_thread(_run_cli, args) # Get captured logs captured_logs_text = ( From 581c47761d204198d4468145e8aac6a08d0bfe86 Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Wed, 24 Sep 2025 17:32:32 +0100 Subject: [PATCH 13/19] add logging to remaining tools --- mcp/server.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index b8d1b61b..65dd2fe1 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -79,7 +79,7 @@ async def _run_stamp(mode, config, ctx): STAMP_LOGGER.addHandler(handler) try: - await ctx.info(f"Starting STAMP {mode} command...") + await ctx.info(f"Starting STAMP {mode} tool...") # Create argparse Namespace object to mimic command line arguments args = argparse.Namespace(command=mode, config_file_path=Path(tmp_config_path)) @@ -90,16 +90,16 @@ async def _run_stamp(mode, config, ctx): captured_logs_text = ( "\n".join(handler.captured_logs) if handler.captured_logs - else "Command completed successfully (no logs captured)" + else "Tool completed successfully (no logs captured)" ) await ctx.info(f"STAMP {mode} completed successfully") - return f"Command completed successfully:\n{captured_logs_text}" + return f"Tool completed successfully:\n{captured_logs_text}" except Exception as e: captured_logs_text = ( "\n".join(handler.captured_logs) if handler.captured_logs else "" ) - error_msg = f"Command failed with error: {str(e)}\n{captured_logs_text}" + error_msg = f"Tool failed with error: {str(e)}\n{captured_logs_text}" await ctx.error(f"STAMP {mode} failed: {str(e)}") return error_msg @@ -744,7 +744,7 @@ def _resolve_path(subpath: str) -> Path: @mcp.tool -def read_file(path: str) -> str: +async def read_file(ctx: Context, path: str) -> str: """ Read the contents of a file inside the allowed folder. 
@@ -754,13 +754,14 @@ def read_file(path: str) -> str: Returns: str: Content of the file. """ + await ctx.info(f"Starting read_file tool...") safe_path = _resolve_path(path) with open(safe_path, "r", encoding="utf-8") as f: return f.read() @mcp.tool -def list_files(subdir: str = "") -> str: +async def list_files(ctx: Context, subdir: str = "") -> str: """ List all files and directories under the given subdirectory (default is root), recursively, returning paths relative to the base directory. If the list is too long, shows only directories @@ -772,6 +773,7 @@ def list_files(subdir: str = "") -> str: Returns: str: Formatted list of files/directories or summary information. """ + await ctx.info(f"Starting list_files tool...") subdir_path = _resolve_path(subdir) if subdir else WORKSPACE_PATH if not subdir_path.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") @@ -877,7 +879,7 @@ def list_files(subdir: str = "") -> str: @mcp.tool -def analyze_csv(path: str) -> str: +async def analyze_csv(ctx: Context, path: str) -> str: """ Analyze a CSV file and provide detailed information about its structure and contents. @@ -887,6 +889,7 @@ def analyze_csv(path: str) -> str: Returns: str: Detailed information about the CSV including dimensions, columns, and sample data. """ + await ctx.info(f"Starting analyze_csv tool...") safe_path = _resolve_path(path) if not safe_path.exists(): @@ -934,7 +937,7 @@ def analyze_csv(path: str) -> str: @mcp.tool -def list_column_values(path: str, column_name: str) -> str: +async def list_column_values(ctx: Context, path: str, column_name: str) -> str: """ List all unique values in a specific column of a CSV file. @@ -945,6 +948,7 @@ def list_column_values(path: str, column_name: str) -> str: Returns: str: Information about the unique values in the specified column. """ + await ctx.info(f"Starting list_column_values tool...") safe_path = _resolve_path(path) if not safe_path.exists(): @@ -1008,7 +1012,7 @@ def list_column_values(path: str, column_name: str) -> str: @mcp.tool -def check_available_devices() -> str: +async def check_available_devices(ctx: Context) -> str: """ Check which computation devices are available on the system. This includes checking for cuda (NVIDIA GPUs) and mps (Apple Silicon GPUs). @@ -1016,6 +1020,7 @@ def check_available_devices() -> str: Returns: A string describing the available devices. """ + await ctx.info(f"Starting check_available_devices tool...") devices = [] # Check for CUDA availability From f849bd4e799ac7d74e2974351262ac929d83c72e Mon Sep 17 00:00:00 2001 From: Juan Pablo Date: Mon, 29 Sep 2025 16:12:38 +0100 Subject: [PATCH 14/19] improve path solver --- mcp/server.py | 129 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 97 insertions(+), 32 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 65dd2fe1..7c1a4c7c 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -23,13 +23,19 @@ STAMP_LOGGER = logging.getLogger("stamp") # TODO: add proper filesystem management +# The idea would be to send thw safe workspace via HTTP Headers or roots +# if OpenAI Agents SDK already implemented it. +# Check docs for more info. WORKSPACE_FOLDER = "./" # Folder where the agent can work on. 
WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve() # List of additional allowed paths outside workspace ALLOWED_EXTERNAL_PATHS = [ "/mnt/bulk-curie/peter/fmbenchmark/images/tcga_crc", "/mnt/bulk-curie/peter/fmbenchmark/20mag_experiments/features/tcga_crc/ctranspath/STAMP_raw_xiyuewang-ctranspath-7c998680", - "/mnt/bulk-sirius/juan/pap_screening/datasets/example/wsi_small" + "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/features/features-20x/virchow2/CPTAC-CCRCC/virchow2-stamp-maru-21-12-24", + "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-CCRCC/data", + "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-BRCA/features-STAMP/conch1_5-778e1572", + "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-BRCA/data", # Add other specific paths you want to allow ] MAX_ITEMS = 100 # Max amount of files listed with list_files tool. @@ -41,7 +47,7 @@ def __init__(self, ctx, loop: asyncio.AbstractEventLoop): super().__init__() self.ctx = ctx self.loop = loop - self.captured_logs = [] # FIXME: Implement so the agent can see the logs when finished. Logging is viewed by the user only. + self.captured_logs = [] # FIXME: Implement so the agent can see the logs when finished. Logging is viewed by the user only. def emit(self, record: logging.LogRecord) -> None: msg = self.format(record) @@ -55,7 +61,6 @@ def emit(self, record: logging.LogRecord) -> None: self.handleError(record) - async def _run_stamp(mode, config, ctx): """ Run the STAMP command directly by calling _run_cli() instead of subprocess. @@ -726,21 +731,64 @@ async def encode_patients_stamp( def _resolve_path(subpath: str) -> Path: - requested = Path(subpath).resolve() - - # Check if it's within workspace - if WORKSPACE_PATH in requested.parents or requested == WORKSPACE_PATH: - return requested - - # Check if it's in allowed external paths - for allowed_path in ALLOWED_EXTERNAL_PATHS: - allowed_path = Path(allowed_path).resolve() - # Check both: exact match OR if allowed_path is a parent of requested - if requested == allowed_path or allowed_path in requested.parents: - return requested - - # If not allowed, raise error - raise PermissionError(f"Access denied: {subpath}") + """ + Resolve path with security checks: + - Paths starting with /mnt/, /tmp/, /home/, etc. are treated as external absolute paths + - All other paths (including /tables, /data, etc.) are treated as workspace-relative + """ + requested = Path(subpath) + + # Check if it's a true external absolute path (starting with known system roots) + external_roots = [ + "/mnt/", + "/tmp/", + "/home/", + "/usr/", + "/var/", + "/opt/", + "/etc/", + "/root/", + "/boot/", + "/sys/", + "/proc/", + "/dev/", + ] + is_external_absolute = any(subpath.startswith(root) for root in external_roots) + + if is_external_absolute: + # This is a true external absolute path - check against allowed external paths + requested_resolved = requested.resolve() + + # Check if it's in allowed external paths + for allowed_path in ALLOWED_EXTERNAL_PATHS: + allowed_path = Path(allowed_path).resolve() + # Check both: exact match OR if allowed_path is a parent of requested + if ( + requested_resolved == allowed_path + or allowed_path in requested_resolved.parents + ): + return requested_resolved + + # If not in allowed external paths, raise error + raise PermissionError(f"Access denied to external absolute path: {subpath}") + + else: + # Treat as workspace-relative (including paths like /tables, /data, etc.) 
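+        # For example, a hypothetical input of "tables/clini.csv" or "/tables/clini.csv"
+        # both end up as (WORKSPACE_PATH / "tables/clini.csv").resolve(), because the
+        # leading slash is stripped below before joining onto the workspace root.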
+ # Remove leading slash if present to make it clearly relative + clean_path = subpath.lstrip("/") + requested_resolved = (WORKSPACE_PATH / clean_path).resolve() + + # Check if resolved path is within workspace + if ( + WORKSPACE_PATH in requested_resolved.parents + or requested_resolved == WORKSPACE_PATH + ): + return requested_resolved + + # If not within workspace, raise error + raise PermissionError( + f"Access denied: path {subpath} resolves outside workspace" + ) @mcp.tool @@ -754,7 +802,7 @@ async def read_file(ctx: Context, path: str) -> str: Returns: str: Content of the file. """ - await ctx.info(f"Starting read_file tool...") + await ctx.info("Starting read_file tool...") safe_path = _resolve_path(path) with open(safe_path, "r", encoding="utf-8") as f: return f.read() @@ -773,7 +821,7 @@ async def list_files(ctx: Context, subdir: str = "") -> str: Returns: str: Formatted list of files/directories or summary information. """ - await ctx.info(f"Starting list_files tool...") + await ctx.info("Starting list_files tool...") subdir_path = _resolve_path(subdir) if subdir else WORKSPACE_PATH if not subdir_path.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") @@ -812,13 +860,17 @@ async def list_files(ctx: Context, subdir: str = "") -> str: if len(all_items) <= MAX_ITEMS: return "\n".join(sorted(all_items)) - # Try directory summary instead + # Try directory summary instead with sample files dir_summary = [] + sample_files_per_dir = 5 # Show up to 5 sample files per directory + for dir_path, info in sorted(directories.items()): if not dir_path: # Root directory dir_display = "/ (root)" + current_dir = WORKSPACE_PATH else: dir_display = f"{dir_path}/" + current_dir = WORKSPACE_PATH / dir_path # File type summary if info["file_count"] > 0: @@ -837,6 +889,27 @@ async def list_files(ctx: Context, subdir: str = "") -> str: dir_summary.append(f"{dir_display}{file_summary}{subdir_info}") + # Add sample files from this directory + if info["file_count"] > 0: + try: + # Get sample files from this specific directory (not recursive) + sample_files = [] + if current_dir.exists() and current_dir.is_dir(): + for item in sorted(current_dir.iterdir()): + if item.is_file() and len(sample_files) < sample_files_per_dir: + rel_path = str(item.relative_to(WORKSPACE_PATH)) + sample_files.append(f" • {rel_path}") + + if sample_files: + dir_summary.extend(sample_files) + if info["file_count"] > sample_files_per_dir: + dir_summary.append( + f" ... and {info['file_count'] - len(sample_files)} more files" + ) + except Exception: + # If we can't read the directory, just skip the sample files + pass + # If directory summary is still too long, truncate if len(dir_summary) > MAX_ITEMS: total_dirs = len(directories) @@ -889,7 +962,7 @@ async def analyze_csv(ctx: Context, path: str) -> str: Returns: str: Detailed information about the CSV including dimensions, columns, and sample data. """ - await ctx.info(f"Starting analyze_csv tool...") + await ctx.info("Starting analyze_csv tool...") safe_path = _resolve_path(path) if not safe_path.exists(): @@ -948,7 +1021,7 @@ async def list_column_values(ctx: Context, path: str, column_name: str) -> str: Returns: str: Information about the unique values in the specified column. 
""" - await ctx.info(f"Starting list_column_values tool...") + await ctx.info("Starting list_column_values tool...") safe_path = _resolve_path(path) if not safe_path.exists(): @@ -1020,7 +1093,7 @@ async def check_available_devices(ctx: Context) -> str: Returns: A string describing the available devices. """ - await ctx.info(f"Starting check_available_devices tool...") + await ctx.info("Starting check_available_devices tool...") devices = [] # Check for CUDA availability @@ -1039,14 +1112,6 @@ async def check_available_devices(ctx: Context) -> str: return f"Available devices: {', '.join(devices)}" else: return "No computation devices are available." - - -@mcp.tool -async def ping_logs(ctx: Context) -> str: - await ctx.info("ping_logs: starting") - await ctx.warning("ping_logs: still working…") - await ctx.info("ping_logs: done") - return "pong" if __name__ == "__main__": From a6d8dd1a13104946626731b7e854e379878f5933 Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 12 Jan 2026 15:17:55 +0000 Subject: [PATCH 15/19] add ticon --- src/stamp/preprocessing/__init__.py | 5 + src/stamp/preprocessing/config.py | 1 + src/stamp/preprocessing/extractor/ticon.py | 730 +++++++++++++++++++++ 3 files changed, 736 insertions(+) create mode 100644 src/stamp/preprocessing/extractor/ticon.py diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index a1844526..ab3ff0d2 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -222,6 +222,11 @@ def extract_( extractor = plip() + case ExtractorName.TICON: + from stamp.preprocessing.extractor.ticon import ticon + + extractor = ticon() + case ExtractorName.EMPTY: from stamp.preprocessing.extractor.empty import empty diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 244d70dd..5eca41dd 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -28,6 +28,7 @@ class ExtractorName(StrEnum): MUSK = "musk" MSTAR = "mstar" PLIP = "plip" + TICON = "ticon" EMPTY = "empty" diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py new file mode 100644 index 00000000..fb8f9b43 --- /dev/null +++ b/src/stamp/preprocessing/extractor/ticon.py @@ -0,0 +1,730 @@ +import math +from collections.abc import Callable, Mapping +from functools import partial +from typing import Any + +import timm +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from jaxtyping import Float + +# from omegaconf import OmegaConf +from torch import Tensor +from torchvision import transforms + +from stamp.preprocessing.extractor import Extractor + +try: + import timm + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "h_optimus_1 dependencies not installed." 
+ " Please reinstall stamp using `pip install 'stamp[h_optimus_1]'`" + ) from e + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + mlp_ratio: int | float | None = (16 / 3), + bias: bool = True, + ) -> None: + super().__init__() + if hidden_features is None: + assert mlp_ratio is not None + hidden_features = int(in_features * mlp_ratio) + else: + assert mlp_ratio is None + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) + + def forward(self, x: Float[Tensor, "*b d"]) -> Float[Tensor, "*b d"]: + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=-1) + x = self.act(x1) * x2 + x = self.fc2(x) + return x + + +class ProjectionMlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.norm = nn.LayerNorm(out_features) + + def forward(self, x: Float[Tensor, "*b d"]) -> Float[Tensor, "*b d"]: + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.norm(x) + return x + + +def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2( + n + ) # In the paper, we only train models that have 2^a heads for some a. This function has + else: # some good properties that only occur when the input is a power of 2. To maintain that even + closest_power_of_2 = 2 ** math.floor( + math.log2(n) + ) # when the number of heads is not a power of 2, we use this workaround. 
+ return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def scaled_dot_product_attention_custom( + query, + key, + value, + attn_bias=None, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + # attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) # pyright: ignore[reportOptionalMemberAccess] + attn_bias.to(query.dtype) # pyright: ignore[reportOptionalMemberAccess] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) # pyright: ignore[reportOptionalMemberAccess] + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + context_dim: int | None = None, + # rope_kwargs: Mapping = {}, + ) -> None: + super().__init__() + self.num_heads = num_heads + context_dim = context_dim or dim + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + # self.rope = Rope(dim=head_dim, **rope_kwargs) + slopes = torch.Tensor(get_slopes(num_heads)) + self.slopes = slopes[ + None, :, None, None + ] # einops.rearrange(slopes, 'b -> 1 b 1 1') + + def forward( + self, + x: Float[Tensor, "b n_q d"], + coords: Float[Tensor, "b n_q 2"], + context: Float[Tensor, "b n_k d_k"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n_q d"]: + if context is None or context_coords is None: + context = x + context_coords = coords + b, n_q, d = x.shape + b, n_k, _ = context.shape + h = self.num_heads + + q = self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2) + k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + + corrds_expanded = coords.unsqueeze(2).expand( + -1, -1, n_k, -1 + ) # (b, m, d) -> (b, m, 1, d) -> (b, m, n, d) + context_coords_expanded = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1) + euclid_dist = torch.sqrt( + torch.sum((corrds_expanded - context_coords_expanded) ** 2, dim=-1) + ) + self.slopes = self.slopes.to(x.device) + attn_bias = (-1) * self.slopes * euclid_dist[:, None, :, :] + + # x = F.scaled_dot_product_attention(q, k, v) + x = scaled_dot_product_attention_custom(q, k, v, attn_bias=attn_bias) + x = x.transpose(1, 2).reshape([b, n_q, d]) + x = self.proj(x) + return x + + +class NaiveResidual(nn.Module): + def __init__( + self, + drop_prob: float | int, + norm: nn.Module, + fn: nn.Module, + gamma: nn.Parameter, + ): + super().__init__() + self.norm = norm + self.fn = fn + 
self.keep_prob = 1 - drop_prob + self.gamma = gamma + + def forward( + self, + x: Float[Tensor, "b n d"], + **kwargs: Float[Tensor, "b ..."] | None, + ) -> Float[Tensor, "b n d"]: + fn_out = self.fn(self.norm(x), **kwargs) + if self.gamma is not None: + if self.keep_prob == 1.0 or not self.training: + return x + self.gamma * fn_out + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[ + :, None, None + ] + return x + self.gamma * fn_out * mask / self.keep_prob + else: + if self.keep_prob == 1.0 or not self.training: + return x + fn_out + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[ + :, None, None + ] + return x + fn_out * mask / self.keep_prob + + +class EfficientResidual(NaiveResidual): + def forward( + self, + x: Float[Tensor, "b n d"], + **kwargs: Float[Tensor, "b ..."] | None, + ) -> Float[Tensor, "b n d"]: + if self.keep_prob == 1.0 or not self.training: + if self.gamma is not None: + return x + self.gamma * self.fn(self.norm(x), **kwargs) + else: + return x + self.fn(self.norm(x), **kwargs) + + b, _, _ = x.shape + n_keep = max(int(b * self.keep_prob), 1) + indices = torch.randperm(b, device=x.device)[:n_keep] + for k, v in kwargs.items(): + if v is not None: + kwargs[k] = v[indices] + if self.gamma is not None: + return torch.index_add( + x, + dim=0, + source=self.gamma * self.fn(self.norm(x[indices]), **kwargs), + index=indices, + alpha=b / n_keep, + ) + else: + return torch.index_add( + x, + dim=0, + source=self.fn(self.norm(x[indices]), **kwargs), + index=indices, + alpha=b / n_keep, + ) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + drop_path: float | int, + norm_layer: Callable[[int], nn.Module], + context_dim: int | None, + drop_path_type: str = "efficient", + layer_scale: int = True, + attn_kwargs: Mapping = {}, + ) -> None: + super().__init__() + residual_module = { + "naive": NaiveResidual, + "efficient": EfficientResidual, + }[drop_path_type] + + self.layer_scale = layer_scale + if layer_scale: + gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + else: + gamma1 = None + gamma2 = None + + self.residual1 = residual_module( + drop_path, + norm_layer(dim), + Attention( + dim, + context_dim=context_dim, + **attn_kwargs, + ), + gamma1, + ) + self.residual2 = residual_module( + drop_path, norm_layer(dim), Mlp(in_features=dim), gamma2 + ) + + def forward( + self, + x: Float[Tensor, "b n d"], + context: Float[Tensor, "b n_k d_k"] | None = None, + coords: Float[Tensor, "b n 2"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n d"]: + x = self.residual1( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + x = self.residual2(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, + embed_dim: int, + norm_layer: Callable[[int], nn.Module], + depth: int, + drop_path_rate: float | int, + context_dim: int | None = None, + block_kwargs: Mapping[str, Any] = {}, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_blocks = depth + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + drop_path=drop_path_rate, + norm_layer=norm_layer, + context_dim=context_dim, + **block_kwargs, + ) + for i in range(depth) + ], + ) + + def forward( + self, + x: Float[Tensor, "b n d"], + return_layers: set[int], + contexts: list[Float[Tensor, "b n_k d_k"]] | None = None, + coords: Float[Tensor, "b n 2"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, 
+ ) -> dict[int, Float[Tensor, "b n d"]]: + outputs = {} + if 0 in return_layers: + outputs[0] = x + for blk_idx, blk in enumerate(self.blocks): + context = contexts[blk_idx] if contexts is not None else None + x = blk( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + if blk_idx + 1 in return_layers: + outputs[blk_idx + 1] = x + return outputs + + +class EncoderDecoder(nn.Module): + def __init__( + self, + patch_size: int = 14, + in_dims: list = [], + tile_encoder_keys: list = [], + norm_layer_type: str = "LayerNorm", + transformers_kwargs: Mapping[str, Any] = {}, + encoder_kwargs: Mapping[str, Any] = {}, + decoder_kwargs: Mapping[str, Any] = {}, + norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5}, + final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True}, + out_layer: int = -1, + num_decoders: int = 1, + decoder_out_dims: list = [], + ): + super().__init__() + self.patch_size = patch_size + + norm_layer: Callable[[int], nn.Module] = partial( + getattr(torch.nn, norm_layer_type), **norm_layer_kwargs + ) + + self.encoder = Transformer( + **transformers_kwargs, + **encoder_kwargs, + norm_layer=norm_layer, + ) + + self.tile_encoder_keys = tile_encoder_keys + self.embed_dim = self.encoder.embed_dim + self.n_blocks = len(self.encoder.blocks) + self.out_layer = out_layer % (len(self.encoder.blocks) + 1) + self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs) + self.num_decoders = num_decoders + self.decoder_out_dims = decoder_out_dims + + self.decoder_dict = nn.ModuleDict({}) + self.mask_dict = nn.ParameterDict({}) + self.input_proj_dict = nn.ModuleDict({}) + self.output_proj_dict = nn.ModuleDict({}) + + for i in range(len(in_dims)): + self.input_proj_dict[f"input_proj_{self.tile_encoder_keys[i]}"] = ( + ProjectionMlp( + in_features=in_dims[i], + hidden_features=self.encoder.embed_dim, + out_features=self.encoder.embed_dim, + ) + ) + + for i in range(self.num_decoders): + self.decoder_dict[f"decoder_{i}"] = nn.ModuleDict({}) + self.decoder_dict[f"decoder_{i}"]["transformer"] = Transformer( # pyright: ignore[reportIndexIssue] + **transformers_kwargs, + **decoder_kwargs, + context_dim=self.encoder.embed_dim, + norm_layer=norm_layer, + ) + + self.decoder_dict[f"decoder_{i}"]["norm"] = norm_layer( # pyright: ignore[reportIndexIssue] + self.decoder_dict[f"decoder_{i}"]["transformer"].embed_dim, # pyright: ignore[reportIndexIssue] + **final_norm_kwargs, + ) + self.mask_dict[f"mask_token_{i}"] = nn.Parameter( + torch.empty( + 1, + self.decoder_dict[f"decoder_{i}"]["transformer"].embed_dim, # pyright: ignore[reportIndexIssue] + ) + ) + + for i in range(len(self.decoder_out_dims)): + self.output_proj_dict[f"output_proj_{self.tile_encoder_keys[i]}"] = ( + ProjectionMlp( + in_features=self.encoder.embed_dim, + hidden_features=self.encoder.embed_dim, + out_features=self.decoder_out_dims[i], + ) + ) + + assert self.num_decoders <= 1 + + def init_weights(self): + for mask_key in self.mask_dict.keys(): + nn.init.normal_(self.mask_dict[mask_key], std=0.02) + self.apply(_init_weights) + return self + + def forward_features( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"] | None, + predict_coords: Float[Tensor, "b n 2"] | None, + enc_layer: int, + dec_layer: int | None, + tile_encoder_key: str | None, + ) -> tuple[Float[Tensor, "b n d"], dict | None]: + b, _, _ = x.shape + + # these are the layers we need + enc_layers = {enc_layer} + if dec_layer is not None: + enc_layers.add(len(self.encoder.blocks)) + + # encoder fwd + 
coords_enc = relative_coords + coords_dec = predict_coords + x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x) + encoder_outputs = self.encoder(x, coords=coords_enc, return_layers=enc_layers) + encoder_outputs = {k: self.enc_norm(v) for k, v in encoder_outputs.items()} + + # decoder fwd + if dec_layer is not None: + dec_final_output = {} + assert self.num_decoders == 1 + for dec_index in range(self.num_decoders): + decoder_outputs = self.decoder_dict[ + f"decoder_{dec_index}" + ][ # pyright: ignore[reportIndexIssue] + "transformer" + ]( + self.mask_dict[f"mask_token_{dec_index}"][None].expand( + *coords_dec.shape[:2], # pyright: ignore[reportOptionalMemberAccess] + -1, # pyright: ignore[reportOptionalMemberAccess] + ), + contexts=[encoder_outputs[len(self.encoder.blocks)]] + * self.decoder_dict[f"decoder_{dec_index}"]["transformer"].n_blocks, # pyright: ignore[reportIndexIssue] + coords=coords_dec, + context_coords=coords_enc, + return_layers={dec_layer}, + ) + dec_output = self.decoder_dict[f"decoder_{dec_index}"]["norm"]( # pyright: ignore[reportIndexIssue] + decoder_outputs[dec_layer] + ) + + for out_index in range(len(self.decoder_out_dims)): + dec_final_output[self.tile_encoder_keys[out_index]] = ( + self.output_proj_dict[ + f"output_proj_{self.tile_encoder_keys[out_index]}" + ](dec_output) + ) + else: + dec_final_output = None + enc_output = encoder_outputs[enc_layer] + return (enc_output, dec_final_output) + + def forward( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"] | None = None, + tile_encoder_key: str | None = None, + ) -> Float[Tensor, "b n d"]: + # print("Input feature range", torch.min(x), torch.max(x)) + # print("Input coords range", torch.min(relative_coords), torch.max(relative_coords)) + enc_output, dec_output = self.forward_features( + x, + relative_coords=relative_coords, + predict_coords=None, + enc_layer=self.out_layer, + dec_layer=None, + tile_encoder_key=tile_encoder_key, + ) + + # print(torch.min(enc_output), torch.max(enc_output)) + return enc_output + + +# from https://github.com/facebookresearch/mae/blob/main/models_mae.py +def _init_weights(m: nn.Module, xavier_gain=1) -> None: + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight, gain=xavier_gain) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm | nn.RMSNorm) and m.elementwise_affine: + nn.init.constant_(m.weight, 1.0) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) # pyright: ignore[reportArgumentType] + if hasattr(m, "_device_weight_init"): + m._device_weight_init() # pyright: ignore[reportCallIssue] + + +def load_ticon(device: str = "cuda") -> nn.Module: + model_cfg = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + "conchv15", + "hoptimus1", + "uni2h", + "gigapath", + "virchow2", + ], + "num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], + } + + ckpt = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + with torch.device("meta"): + model = EncoderDecoder(**model_cfg) + + model.to_empty(device=device) + model.init_weights() + + sd = torch.load(ckpt, map_location="cpu", weights_only=True) 
+ sd = { + k.removeprefix("backbone."): v + for k, v in sd.items() + if k.startswith("backbone.") + } + + model.load_state_dict(sd, strict=False) + model.eval() + return model + + +class HOptimusTICON(nn.Module): + def __init__(self, device: torch.device): + super().__init__() + self.device = device + + # ---------------------------- + # Stage 1: H-OptimUS + # ---------------------------- + self.tile_encoder = timm.create_model( + "hf-hub:bioptimus/H-optimus-1", + pretrained=True, + init_values=1e-5, + dynamic_img_size=False, + ) + + # ---------------------------- + # Stage 2: TICON + # ---------------------------- + ticon_cfg = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + "conchv15", + "hoptimus1", + "uni2h", + "gigapath", + "virchow2", + ], + "num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], + } + + with torch.device("meta"): + self.ticon = EncoderDecoder(**ticon_cfg) + + self.ticon.to_empty(device=device) + self.ticon.init_weights() + + ckpt = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + sd = torch.load(ckpt, map_location="cpu", weights_only=True) + sd = { + k.removeprefix("backbone."): v + for k, v in sd.items() + if k.startswith("backbone.") + } + self.ticon.load_state_dict(sd, strict=False) + + self.to(device) + self.eval() + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: [B, 3, 224, 224] (CPU or CUDA) + """ + x = x.to(self.device, non_blocking=True) + + # H-Optimus_1 + emb = self.tile_encoder(x) # [B, 1536] + emb = emb.unsqueeze(1) # [B, 1, 1536] + # TICON + # single-tile → zero relative coords + coords = torch.zeros( + emb.size(0), + 1, + 2, + device=self.device, + dtype=torch.float32, + ) + + out = self.ticon( + x=emb, + relative_coords=coords, + tile_encoder_key="hoptimus1", + ) + + return out.squeeze(1) # [B, 1536] + + +def ticon(device: str = "cuda") -> Extractor[nn.Module]: + model = HOptimusTICON(torch.device(device)) + + transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.707223, 0.578729, 0.703617), + std=(0.211883, 0.230117, 0.177517), + ), + ] + ) + + return Extractor( + model=model, + transform=transform, + identifier="ticon", + ) From 9a0433877bf96707a1cc58f7ece5b8c948f15b04 Mon Sep 17 00:00:00 2001 From: mducducd Date: Tue, 13 Jan 2026 09:56:13 +0000 Subject: [PATCH 16/19] add ticon --- getting-started.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/getting-started.md b/getting-started.md index 93f1a0e7..6d5bffec 100644 --- a/getting-started.md +++ b/getting-started.md @@ -58,6 +58,7 @@ Stamp currently supports the following feature extractors: - [mSTAR][mstar] - [MUSK][musk] - [PLIP][plip] + - [TICON][ticon] As some of the above require you to request access to the model on huggingface, @@ -158,6 +159,7 @@ meaning ignored that it was ignored during feature extraction. 
[EAGLE]: https://github.com/KatherLab/EAGLE [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine [PRISM]: https://huggingface.co/paige-ai/Prism +[TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning" From a4d07bb0c942eb61fdb2b983237c933c25317d27 Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 19 Jan 2026 10:18:05 +0000 Subject: [PATCH 17/19] compute patch_size_lvl0 by coords stride --- src/stamp/config.yaml | 3 ++ src/stamp/encoding/encoder/eagle.py | 62 ++++++++++++++++++++++++++--- src/stamp/encoding/encoder/titan.py | 10 ++++- 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 796140a5..8440560b 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -26,6 +26,9 @@ preprocessing: #tile_size_um: 256.0 #tile_size_px: 224 + # Magnification level to extract tiles from + #default_slide_mpp: 1.0 + # How many workers to use for tile extraction. Should be less or equal to # the number of cores of your system. #max_workers: 8 diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index b2fb293d..d966c84e 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -1,5 +1,6 @@ import logging import os +from collections import defaultdict, deque from pathlib import Path import numpy as np @@ -59,11 +60,26 @@ def _validate_and_read_features_with_agg( f"Features located in {h5_vir2} are extracted with {extractor}" ) - if feats.shape[0] != agg_feats.shape[0]: - raise ValueError( - f"Number of ctranspath features and virchow2 features do not match:" - f" {feats.shape[0]} != {agg_feats.shape[0]}" - ) + # if feats.shape[0] != agg_feats.shape[0]: + # raise ValueError( + # f"Number of ctranspath features and virchow2 features do not match:" + # f" {feats.shape[0]} != {agg_feats.shape[0]}" + # ) + if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0): + # Try to fix permutation by aligning virchow2 to ctp coords + try: + agg_feats, aligned_agg_coords = _align_vir2_to_ctp_by_coords( + ref_coords_um=coords.coords_um, + other_coords_um=agg_coords.coords_um, + other_feats=agg_feats, + decimals=5, + ) + agg_coords.coords_um = aligned_agg_coords # optional, for debugging + except ValueError as e: + raise ValueError( + f"Coordinates mismatch between ctranspath and virchow2 features for slide " + f"{slide_name}. 
Alignment attempt failed: {e}" + ) if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0): raise ValueError( @@ -144,7 +160,7 @@ def encode_slides_( for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))): h5_ctp = os.path.join(feat_dir, tile_feats_filename) h5_vir2 = os.path.join(agg_feat_dir, tile_feats_filename) - slide_name: str = Path(tile_feats_filename).stem + slide_name: str = Path(tile_feats_filename).name progress.set_description(slide_name) # skip patient in case feature file already exists @@ -238,3 +254,37 @@ def encode_patients_( self._save_features_( output_path=output_path, feats=patient_embedding, feat_type="patient" ) + + +def _align_vir2_to_ctp_by_coords( + ref_coords_um: np.ndarray, + other_coords_um: np.ndarray, + other_feats: torch.Tensor, + decimals: int = 5, +) -> tuple[torch.Tensor, np.ndarray]: + """Align vir2 features to ctp features based on coordinates.""" + ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals) + oth = np.round(np.asarray(other_coords_um, dtype=np.float64), decimals) + + # coord -> queue(indices) + buckets = defaultdict(deque) + for j, key in enumerate(map(tuple, oth)): + buckets[key].append(j) + + perm = np.empty(ref.shape[0], dtype=np.int64) + for i, key in enumerate(map(tuple, ref)): + if not buckets[key]: + raise ValueError(f"Missing coord in other set: {key}") + perm[i] = buckets[key].popleft() + + # optional: check if other has extras not used + unused = sum(len(q) for q in buckets.values()) + if unused != 0: + raise ValueError(f"virchow2 features contain {unused} extra coords not in ref.") + + perm_t = torch.as_tensor(perm, dtype=torch.long, device=other_feats.device) + # Align features according to the permutation as well ! + aligned_feats = other_feats.index_select(0, perm_t) + aligned_coords = other_coords_um[perm] + print("") + return aligned_feats, aligned_coords diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 1012d98f..2dba6021 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -49,7 +49,15 @@ def _generate_slide_embedding( coords_tensor = torch.tensor(coords.coords_um, dtype=self.precision) # Convert coordinates from microns to pixels - patch_size_lvl0 = math.floor(256 / coords.mpp) # Inferred from TITAN docs + xs = torch.unique(coords_tensor[:, 0]) + ys = torch.unique(coords_tensor[:, 1]) + patch_size_lvl0 = int( + min( + (xs[1:] - xs[:-1])[(xs[1:] - xs[:-1]) > 0].min(), + (ys[1:] - ys[:-1])[(ys[1:] - ys[:-1]) > 0].min(), + ) + ) + coords_px = coords_tensor / coords.mpp # Convert to pixels coords_px = coords_px.to(torch.int64).to(device) # Convert to integer From 0181166dcfce8abc07302163ecd98dcca3afee1d Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 23 Jan 2026 12:51:50 +0000 Subject: [PATCH 18/19] Revert "Merge remote-tracking branch 'origin/dev/tools-logging' into dev/multitask" This reverts commit e29b2287e274edcd73de15e9a9afcd6488c8f232, reversing changes made to 9a0433877bf96707a1cc58f7ece5b8c948f15b04. 
--- mcp/README.md | 27 +- mcp/server.py | 488 ++++++----------------------------- src/stamp/heatmaps/config.py | 37 +-- src/stamp/modeling/data.py | 11 +- 4 files changed, 106 insertions(+), 457 deletions(-) diff --git a/mcp/README.md b/mcp/README.md index 0f62c1b8..7821a368 100644 --- a/mcp/README.md +++ b/mcp/README.md @@ -1,24 +1,23 @@ # STAMP MCP Server -A FastMCP-based Model Context Protocol server wrapping [STAMP](https://github.com/KatherLab/STAMP)'s tools, enabling seamless integration of STAMP preprocessing, training, encoding, evaluation, and inference into LLM-based pipelines. +A FastMCP-based Model Context Protocol server wrapping [STAMP](https://github.com/KatherLab/STAMP)’s CLI, enabling seamless integration of STAMP preprocessing, training, encoding, evaluation, and inference into LLM-based pipelines. ## Overview This server lets LLM agents invoke STAMP tools via structured calls. It exposes the following tools: -- `preprocess_stamp()`: tile & extract WSI features -- `train_stamp()`: train weakly supervised models -- `crossval_stamp()`: k-fold cross‑validation -- `deploy_stamp()`: inference on held‑out data -- `encode_slides_stamp()`: slide-level feature encoding -- `encode_patients_stamp()`: patient-level feature encoding -- `heatmaps_stamp()`: model-based heatmap visualization -- `statistics_stamp()`: compute classification metrics -- `read_file()` & `list_files()`: safe disk access -- `check_available_devices()`: query Torch/Platform device availability -- `analyze_csv()` & `list_column_values`: useful for clinical and slide tables - -Each tool serializes config into YAML and directly calls STAMP's internal `_run_cli()` function, streaming logs back in real-time and returning execution results. +- `preprocess_stamp(...)`: tile & extract WSI features +- `train_stamp(...)`: train weakly supervised models +- `crossval_stamp(...)`: k-fold cross‑validation +- `deploy_stamp(...)`: inference on held‑out data +- `encode_slides_stamp(...)`: slide-level feature encoding +- `encode_patients_stamp(...)`: patient-level feature encoding +- `heatmaps_stamp(...)`: model-based heatmap visualization +- `statistics_stamp(...)`: compute classification metrics +- `read_file(...)` & `list_files(...)`: safe disk access +- `check_available_devices()`: query Torch/Platform device availability + +Each tool serializes config into YAML, launches `stamp `, streams logs back, and returns stdout/stderr. ## Installation To run the MCP server is as simple as intalling STAMP as it is explained in the main README.md file, but adding `--extra mcp` to the command. For a GPU repository installation it would be like this: diff --git a/mcp/server.py b/mcp/server.py index c66dc4f9..28781b2a 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -1,21 +1,16 @@ -"""STAMP MCP Server""" - import asyncio import logging import os import platform +import subprocess import tempfile from pathlib import Path from typing import Annotated -import argparse import torch import yaml from fastmcp import Context, FastMCP from pydantic import Field -import pandas as pd -from stamp.__main__ import _run_cli - # Initialize the FastMCP server mcp = FastMCP("STAMP MCP Server") @@ -23,47 +18,24 @@ STAMP_LOGGER = logging.getLogger("stamp") # TODO: add proper filesystem management -# The idea would be to send thw safe workspace via HTTP Headers or roots -# if OpenAI Agents SDK already implemented it. -# Check docs for more info. -WORKSPACE_FOLDER = "./" # Folder where the agent can work on. 
-WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve() -# List of additional allowed paths outside workspace -ALLOWED_EXTERNAL_PATHS = [ - "/mnt/bulk-curie/peter/fmbenchmark/images/tcga_crc", - "/mnt/bulk-curie/peter/fmbenchmark/20mag_experiments/features/tcga_crc/ctranspath/STAMP_raw_xiyuewang-ctranspath-7c998680", - "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/features/features-20x/virchow2/CPTAC-CCRCC/virchow2-stamp-maru-21-12-24", - "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-CCRCC/data", - "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-BRCA/features-STAMP/conch1_5-778e1572", - "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-BRCA/data", - # Add other specific paths you want to allow -] -MAX_ITEMS = 100 # Max amount of files listed with list_files tool. -# Big values could exceed LLM's context length. When it exceeds, values are summarized. +base_dir = "./" +base = Path(base_dir).resolve() class MCPLogHandler(logging.Handler): - def __init__(self, ctx, loop: asyncio.AbstractEventLoop): + def __init__(self, ctx): super().__init__() self.ctx = ctx - self.loop = loop - self.captured_logs = [] # FIXME: Implement so the agent can see the logs when finished. Logging is viewed by the user only. - def emit(self, record: logging.LogRecord) -> None: + def emit(self, record): msg = self.format(record) - try: - self.captured_logs.append(msg) - # Thread-safe: schedule on the captured event loop - asyncio.run_coroutine_threadsafe(self.ctx.log(msg), self.loop) - # Alternatively: - # self.loop.call_soon_threadsafe(self.loop.create_task, self.ctx.log(msg)) - except Exception: - self.handleError(record) + # Fire-and-forget the coroutine + asyncio.create_task(self.ctx.log(msg)) async def _run_stamp(mode, config, ctx): """ - Run the STAMP command directly by calling _run_cli() instead of subprocess. + Run the STAMP command as a subprocess and capture its console output. Args: mode (str): The mode to run the STAMP command in (e.g., "preprocess", "train"). 
@@ -78,36 +50,20 @@ async def _run_stamp(mode, config, ctx): yaml.dump(config, tmp_config) tmp_config_path = tmp_config.name - # Set up logging handler to capture STAMP logs - loop = asyncio.get_running_loop() - handler = MCPLogHandler(ctx, loop) + handler = MCPLogHandler(ctx) + handler.setLevel(logging.DEBUG) STAMP_LOGGER.addHandler(handler) - try: - await ctx.info(f"Starting STAMP {mode} tool...") - # Create argparse Namespace object to mimic command line arguments - args = argparse.Namespace(command=mode, config_file_path=Path(tmp_config_path)) - - # Call the STAMP CLI function directly - await asyncio.to_thread(_run_cli, args) - - # Get captured logs - captured_logs_text = ( - "\n".join(handler.captured_logs) - if handler.captured_logs - else "Tool completed successfully (no logs captured)" - ) - await ctx.info(f"STAMP {mode} completed successfully") - return f"Tool completed successfully:\n{captured_logs_text}" - - except Exception as e: - captured_logs_text = ( - "\n".join(handler.captured_logs) if handler.captured_logs else "" - ) - error_msg = f"Tool failed with error: {str(e)}\n{captured_logs_text}" - await ctx.error(f"STAMP {mode} failed: {str(e)}") - return error_msg + print("Running command...") + try: + cmd = ["stamp", "--config", tmp_config_path, mode] + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + print("Result returned...") + print(f"Command completed successfully:\n{result.stdout}\n{result.stderr}") + return f"Command completed successfully:\n{result.stdout}\n{result.stderr}" + except subprocess.CalledProcessError as e: + return f"Command failed with error:\n{e.stdout}\n{e.stderr}" finally: os.remove(tmp_config_path) STAMP_LOGGER.removeHandler(handler) @@ -250,11 +206,22 @@ async def train_stamp( "in the slide table containing the feature file path relative to `feature_dir`" ), ] = "FILENAME", + bag_size: Annotated[ + int, + Field( + description="Amount of tiles to sample when training. " + "Reducing this value reduces memory usage, but it is not recommended as the model can miss" + "relevant regions of the slide. Default value works well on H&E tissue images." + ), + ] = 512, + batch_size: Annotated[ + int, Field(description="Amount of bags processed together.") + ] = 64, ) -> str: """ Train a model using clinical data and WSI-derived features via STAMP. Takes in a clinical table, slide associations, and extracted features - to train a model on a specified label. Best option when an external cohort is available. + to train a model on a specified label. Returns: str: message indicating the success or failure of the training operation, @@ -283,6 +250,8 @@ async def train_stamp( "categories": categories, "patient_label": patient_label, "filename_label": filename_label, + "bag_size": bag_size, + "batch_size": batch_size, } } return await _run_stamp(mode="train", config=config, ctx=ctx) @@ -337,12 +306,22 @@ async def crossval_stamp( description="Number of folds to split the data into for cross-validation" ), ] = 5, + bag_size: Annotated[ + int, + Field( + description="Amount of tiles to sample when training. " + "Reducing this value reduces memory usage, but it is not recommended as the model can miss" + "relevant regions of the slide. Default value works well on H&E tissue images." + ), + ] = 512, + batch_size: Annotated[ + int, Field(description="Amount of bags processed together.") + ] = 64, ) -> str: """ Perform cross-validation for model training using STAMP. Splits the data into folds and trains a model on each to assess generalization. 
Uses clinical data, features, and slide mappings. - Best option when only one cohort is available. Returns: str: A message indicating the success or failure of the cross-validation operation, along with @@ -374,6 +353,10 @@ async def crossval_stamp( "filename_label": filename_label, "n_splits": n_splits, }, + "advanced_config": { # Add advanced config for bag_size and batch_size + "bag_size": bag_size, + "batch_size": batch_size, + }, } return await _run_stamp(mode="crossval", config=config, ctx=ctx) @@ -503,7 +486,7 @@ async def statistics_stamp( output_dir="output/statistics", ground_truth_label="OUTCOME", true_class="Positive", - pred_csvs=["/pathto/split-0/patient-preds.csv", "/pathto/split-1/patient-preds.csv"] + pred_csvs=["predictions/fold1.csv", "predictions/fold2.csv"] ) "Command completed successfully: ..." """ @@ -534,40 +517,24 @@ async def heatmaps_stamp( str, Field(description="Path of the model to generate the heatmaps with.") ], slide_paths: Annotated[ - list[str], + list[str] | None, Field( - description="List of slide paths relative to `wsi_dir` to " - "generate heatmaps for. The slide paths HAVE to be specified relative to `wsi_dir`.", - min_length=1, + description="List of slide paths relative " + "to `wsi_dir` to generate heatmaps for. If not specified, heatmaps will be generated " + "for all slides in `wsi_dir`." ), - ], + ] = None, topk: Annotated[ int | None, Field(description="Number of top-scoring tiles to extract") ] = None, bottomk: Annotated[ int | None, Field(description="Number of bottom-scoring tiles to extract") ] = None, - device: Annotated[ - str | None, - Field( - description="The device to use for computation. " - "Possible options are 'cuda' for NVIDIA GPUs, 'cpu' for general-purpose " - "processors, and 'mps' for Apple Silicon GPUs. Default is detected automatically" - ), - ] = None, ) -> str: """ Generate heatmaps and tile scorings from WSIs using a trained model. - - Creates visual attention maps showing which regions the model focuses on for predictions. - Works only with tile-level features. For each slide, generates: - - Overview plots with complete heatmaps and class overlays - - Raw data including thumbnails, class maps, and per-class heatmaps - - Individual tile extractions (top/bottom scoring if specified) - - Output structure: Each slide gets its own folder - (slide name without file extension)containing plots/, raw/, and tiles/ subdirectories. - + Produces visual explanations and optionally extracts top/bottom + scoring tiles. Returns: str: A message indicating the success or failure of the heatmap generation operation, @@ -580,8 +547,8 @@ async def heatmaps_stamp( wsi_dir="input/slides", checkpoint_path="models/checkpoint.pth", slide_paths=["slide1.svs", "slide2.svs"], - topk=3, - bottomk=3 + topk=10, + bottomk=5 ) "Command completed successfully: ..." """ @@ -594,7 +561,6 @@ async def heatmaps_stamp( "slide_paths": slide_paths, "topk": topk, "bottomk": bottomk, - "device": device, } } return await _run_stamp(mode="heatmaps", config=config, ctx=ctx) @@ -731,68 +697,14 @@ async def encode_patients_stamp( def _resolve_path(subpath: str) -> Path: - """ - Resolve path with security checks: - - Paths starting with /mnt/, /tmp/, /home/, etc. are treated as external absolute paths - - All other paths (including /tables, /data, etc.) 
are treated as workspace-relative - """ - requested = Path(subpath) - - # Check if it's a true external absolute path (starting with known system roots) - external_roots = [ - "/mnt/", - "/tmp/", - "/home/", - "/usr/", - "/var/", - "/opt/", - "/etc/", - "/root/", - "/boot/", - "/sys/", - "/proc/", - "/dev/", - ] - is_external_absolute = any(subpath.startswith(root) for root in external_roots) - - if is_external_absolute: - # This is a true external absolute path - check against allowed external paths - requested_resolved = requested.resolve() - - # Check if it's in allowed external paths - for allowed_path in ALLOWED_EXTERNAL_PATHS: - allowed_path = Path(allowed_path).resolve() - # Check both: exact match OR if allowed_path is a parent of requested - if ( - requested_resolved == allowed_path - or allowed_path in requested_resolved.parents - ): - return requested_resolved - - # If not in allowed external paths, raise error - raise PermissionError(f"Access denied to external absolute path: {subpath}") - - else: - # Treat as workspace-relative (including paths like /tables, /data, etc.) - # Remove leading slash if present to make it clearly relative - clean_path = subpath.lstrip("/") - requested_resolved = (WORKSPACE_PATH / clean_path).resolve() - - # Check if resolved path is within workspace - if ( - WORKSPACE_PATH in requested_resolved.parents - or requested_resolved == WORKSPACE_PATH - ): - return requested_resolved - - # If not within workspace, raise error - raise PermissionError( - f"Access denied: path {subpath} resolves outside workspace" - ) + requested = (base / subpath).resolve() + if base not in requested.parents and requested != base: + raise PermissionError(f"Access denied: {subpath}") + return requested @mcp.tool -async def read_file(ctx: Context, path: str) -> str: +def read_file(path: str) -> str: """ Read the contents of a file inside the allowed folder. @@ -802,290 +714,41 @@ async def read_file(ctx: Context, path: str) -> str: Returns: str: Content of the file. """ - await ctx.info("Starting read_file tool...") safe_path = _resolve_path(path) with open(safe_path, "r", encoding="utf-8") as f: return f.read() @mcp.tool -async def list_files(ctx: Context, subdir: str = "") -> str: +def list_files(subdir: str = "") -> list: """ List all files and directories under the given subdirectory (default is root), recursively, - returning paths relative to the base directory. If the list is too long, shows only directories - with file type summaries. If still too long, shows a truncated message. + returning paths relative to the base directory. Args: subdir (str): Relative subdirectory path to list files from. Returns: - str: Formatted list of files/directories or summary information. + list: List of relative file paths found. 
""" - await ctx.info("Starting list_files tool...") - subdir_path = _resolve_path(subdir) if subdir else WORKSPACE_PATH - if not subdir_path.is_dir(): + safe = _resolve_path(subdir) + if not safe.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") - - # Collect all files and directories - all_items = [] - directories = {} - base_len = len(str(WORKSPACE_PATH)) + 1 # To slice off base path + separator - - for root, dirs, files in os.walk(subdir_path): + results = [] + base_len = len(str(base)) + 1 # To slice off base path + separator + for root, dirs, files in os.walk(safe): rel_root = str(root)[base_len:] # relative path under base_dir - - # Track file types in each directory - if rel_root not in directories: - directories[rel_root] = {"subdirs": [], "file_types": {}, "file_count": 0} - - # Add subdirectories for d in dirs: path = os.path.join(rel_root, d) - all_items.append(path + "/") - directories[rel_root]["subdirs"].append(d) - - # Add files and track their extensions + results.append(path + "/") for f in files: path = os.path.join(rel_root, f) - all_items.append(path) - - # Track file extension - ext = Path(f).suffix.lower() or "no extension" - directories[rel_root]["file_types"][ext] = ( - directories[rel_root]["file_types"].get(ext, 0) + 1 - ) - directories[rel_root]["file_count"] += 1 - - # If the list is manageable, return the full list - if len(all_items) <= MAX_ITEMS: - return "\n".join(sorted(all_items)) - - # Try directory summary instead with sample files - dir_summary = [] - sample_files_per_dir = 5 # Show up to 5 sample files per directory - - for dir_path, info in sorted(directories.items()): - if not dir_path: # Root directory - dir_display = "/ (root)" - current_dir = WORKSPACE_PATH - else: - dir_display = f"{dir_path}/" - current_dir = WORKSPACE_PATH / dir_path - - # File type summary - if info["file_count"] > 0: - file_types = [] - for ext, count in sorted(info["file_types"].items()): - file_types.append(f"{count} {ext}") - file_summary = f" [{', '.join(file_types)}]" - else: - file_summary = " [empty]" - - # Subdirectory info - if info["subdirs"]: - subdir_info = f" (contains {len(info['subdirs'])} subdirs)" - else: - subdir_info = "" - - dir_summary.append(f"{dir_display}{file_summary}{subdir_info}") - - # Add sample files from this directory - if info["file_count"] > 0: - try: - # Get sample files from this specific directory (not recursive) - sample_files = [] - if current_dir.exists() and current_dir.is_dir(): - for item in sorted(current_dir.iterdir()): - if item.is_file() and len(sample_files) < sample_files_per_dir: - rel_path = str(item.relative_to(WORKSPACE_PATH)) - sample_files.append(f" • {rel_path}") - - if sample_files: - dir_summary.extend(sample_files) - if info["file_count"] > sample_files_per_dir: - dir_summary.append( - f" ... and {info['file_count'] - len(sample_files)} more files" - ) - except Exception: - # If we can't read the directory, just skip the sample files - pass - - # If directory summary is still too long, truncate - if len(dir_summary) > MAX_ITEMS: - total_dirs = len(directories) - total_files = sum(info["file_count"] for info in directories.values()) - - # Show first few directories and a summary - shown_dirs = dir_summary[: MAX_ITEMS // 2] - summary_text = ( - f"\n... 
(showing first {len(shown_dirs)} of {total_dirs} directories)\n\n" - f"SUMMARY:\n" - f"- Total directories: {total_dirs}\n" - f"- Total files: {total_files}\n" - f"- Directory '{subdir or '/'}' contains too many items to display completely.\n" - f"- Use a more specific subdirectory path to see detailed listings." - ) - - # Get overall file type statistics - all_extensions = {} - for info in directories.values(): - for ext, count in info["file_types"].items(): - all_extensions[ext] = all_extensions.get(ext, 0) + count - - if all_extensions: - ext_summary = [] - for ext, count in sorted( - all_extensions.items(), key=lambda x: x[1], reverse=True - )[:10]: - ext_summary.append(f" {ext}: {count} files") - summary_text += "\n\nTop file types:\n" + "\n".join(ext_summary) - if len(all_extensions) > 10: - summary_text += ( - f"\n ... and {len(all_extensions) - 10} more file types" - ) - - return "\n".join(shown_dirs) + summary_text - - # Return directory summary - header = f"Directory listing for '{subdir or '/'}' (showing directories with file type summaries):\n" - return header + "\n".join(dir_summary) - - -@mcp.tool -async def analyze_csv(ctx: Context, path: str) -> str: - """ - Analyze a CSV file and provide detailed information about its structure and contents. - - Args: - path (str): Relative path to the CSV file. - - Returns: - str: Detailed information about the CSV including dimensions, columns, and sample data. - """ - await ctx.info("Starting analyze_csv tool...") - safe_path = _resolve_path(path) - - if not safe_path.exists(): - raise FileNotFoundError(f"CSV file does not exist: {path}") - - if safe_path.suffix.lower() not in [".csv", ".tsv"]: - raise ValueError(f"File is not a CSV file: {path}") - - try: - # Read the CSV file - df = pd.read_csv(safe_path) - - # Get basic information - num_rows, num_columns = df.shape - column_names = df.columns.tolist() - - # Get first 3 rows as examples - sample_rows = df.head(3).to_string(index=True, max_cols=None) - - # Format the output - result = f"""CSV File Analysis: {path} - -Dimensions: -- Number of rows: {num_rows:,} -- Number of columns: {num_columns} - -Column Names: -{", ".join([f'"{col}"' for col in column_names])} - -First 3 rows (sample data): -{sample_rows} - -Data Types: -{df.dtypes.to_string()} - """ - - return result.strip() - - except pd.errors.EmptyDataError: - return f"CSV file is empty: {path}" - except pd.errors.ParserError as e: - return f"Error parsing CSV file {path}: {str(e)}" - except Exception as e: - return f"Error analyzing CSV file {path}: {str(e)}" - - -@mcp.tool -async def list_column_values(ctx: Context, path: str, column_name: str) -> str: - """ - List all unique values in a specific column of a CSV file. - - Args: - path (str): Relative path to the CSV file. - column_name (str): Name of the column to analyze. - - Returns: - str: Information about the unique values in the specified column. 
- """ - await ctx.info("Starting list_column_values tool...") - safe_path = _resolve_path(path) - - if not safe_path.exists(): - raise FileNotFoundError(f"CSV file does not exist: {path}") - - if safe_path.suffix.lower() not in [".csv", ".tsv"]: - raise ValueError(f"File is not a CSV file: {path}") - - try: - # Read the CSV file - df = pd.read_csv(safe_path) - - # Check if column exists - if column_name not in df.columns: - available_columns = ", ".join([f'"{col}"' for col in df.columns]) - return f"Column '{column_name}' not found in CSV file: {path}\nAvailable columns: {available_columns}" - - # Get unique values - unique_values = df[column_name].unique() - - # Count occurrences of each value - value_counts = df[column_name].value_counts().sort_index() - - # Handle missing values - null_count = df[column_name].isnull().sum() - - # Format the output - result = f"""Column Analysis for '{column_name}' in {path} - -Total rows: {len(df):,} -Unique values: {len(unique_values):,} -Missing/null values: {null_count:,} - -Value distribution: -{value_counts.to_string()} - """ - - # If there are many unique values, show a sample - if len(unique_values) > 20: - result += f""" - -First 20 unique values: -{", ".join([str(val) for val in unique_values[:20]])} -... and {len(unique_values) - 20} more values - """ - else: - result += f""" - -All unique values: -{", ".join([str(val) for val in unique_values if pd.notna(val)])} - """ - - return result.strip() - - except pd.errors.EmptyDataError: - return f"CSV file is empty: {path}" - except pd.errors.ParserError as e: - return f"Error parsing CSV file {path}: {str(e)}" - except Exception as e: - return f"Error analyzing column in CSV file {path}: {str(e)}" + results.append(path) + return sorted(results) @mcp.tool -async def check_available_devices(ctx: Context) -> str: +def check_available_devices() -> str: """ Check which computation devices are available on the system. This includes checking for cuda (NVIDIA GPUs) and mps (Apple Silicon GPUs). @@ -1093,7 +756,6 @@ async def check_available_devices(ctx: Context) -> str: Returns: A string describing the available devices. """ - await ctx.info("Starting check_available_devices tool...") devices = [] # Check for CUDA availability diff --git a/src/stamp/heatmaps/config.py b/src/stamp/heatmaps/config.py index 1b8b199c..98bc1744 100644 --- a/src/stamp/heatmaps/config.py +++ b/src/stamp/heatmaps/config.py @@ -9,21 +9,15 @@ class HeatmapConfig(BaseModel): model_config = ConfigDict(extra="forbid") - output_dir: Path = Field(description="Directory to save heatmap outputs") + output_dir: Path - feature_dir: Path = Field(description="Directory containing extracted features") - wsi_dir: Path = Field(description="Directory containing whole slide images") - checkpoint_path: Path = Field(description="Path to model checkpoint file") + feature_dir: Path + wsi_dir: Path + checkpoint_path: Path - slide_paths: list[Path] | None = Field( - default=None, - description="Specific slide paths to process. If None, processes all slides in wsi_dir", - ) + slide_paths: list[Path] | None = None - device: str = Field( - default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu", - description="Device to use for computation", - ) + device: str = "cuda" if torch.cuda.is_available() else "cpu" opacity: float = Field( default=0.6, @@ -32,19 +26,8 @@ class HeatmapConfig(BaseModel): le=1, ) - topk: int = Field( - default=0, - description="Number of top patches to highlight. 
0 means no highlighting.", - ge=0, - ) - - bottomk: int = Field( - default=0, - description="Number of bottom patches to highlight. 0 means no highlighting.", - ge=0, - ) + topk: int = 0 + bottomk: int = 0 - default_slide_mpp: SlideMPP | None = Field( - default=None, - description="MPP of the slide to use if none can be inferred from the WSI", - ) + default_slide_mpp: SlideMPP | None = None + """MPP of the slide to use if none can be inferred from the WSI""" diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 6cabec64..2f121f27 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -37,6 +37,7 @@ ) _logger = logging.getLogger("stamp") +_logged_stamp_v1_warning = False __author__ = "Marko van Treeck, Minh Duc Nguyen" @@ -568,9 +569,13 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: == 224 ): # Historic STAMP format - _logger.debug( - f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" - ) + # TODO: find a better way to get this warning just once + global _logged_stamp_v1_warning + if not _logged_stamp_v1_warning: + _logger.info( + f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)" + ) + _logged_stamp_v1_warning = True tile_size_um = Microns(256.0) tile_size_px = TilePixels(224) coords_um = coords / 224 * 256 From 16d38d11965db99eb7303cf1a93d50d9eb9420be Mon Sep 17 00:00:00 2001 From: mducducd Date: Tue, 27 Jan 2026 15:54:19 +0000 Subject: [PATCH 19/19] fix splitting for survival gt --- src/stamp/modeling/crossval.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 43e76f01..0ff037cf 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -296,19 +296,25 @@ def _get_splits( *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter ) -> _Splits: patients = np.array(list(patient_to_data.keys())) - skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) + + # Detect survival GT: "time status" + tokens = [str(p.ground_truth).split() for p in patient_to_data.values()] + + if all(len(t) == 2 for t in tokens): + y = np.array([int(t[1]) for t in tokens], dtype=int) + skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0) + iterator = skf.split(patients, y) + else: + skf = KFold(n_splits=n_splits, shuffle=True, random_state=0) + iterator = skf.split(patients) + splits = _Splits( splits=[ _Split( - train_patients=set(patients[train_indices]), - test_patients=set(patients[test_indices]), - ) - for train_indices, test_indices in skf.split( - patients, - np.array( - [patient.ground_truth for patient in patient_to_data.values()] - ), + train_patients=set(patients[train_idx]), + test_patients=set(patients[test_idx]), ) + for train_idx, test_idx in iterator ] ) return splits
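
The coordinate-based re-alignment that PATCH 17 adds to eagle.py can be sanity-checked in isolation. The sketch below is illustrative only and is not part of STAMP: the helper name align_by_coords and the synthetic four-tile grid are invented for this example, while the round-to-five-decimals keying mirrors the idea of _align_vir2_to_ctp_by_coords.

# Illustrative check of coordinate-based feature alignment (not STAMP code).
# Same idea as _align_vir2_to_ctp_by_coords: key tiles by their rounded
# coordinates and reorder the second feature matrix to follow the reference order.
import numpy as np
import torch


def align_by_coords(ref_um: np.ndarray, other_um: np.ndarray, other_feats: torch.Tensor) -> torch.Tensor:
    ref = np.round(np.asarray(ref_um, dtype=np.float64), 5)
    oth = np.round(np.asarray(other_um, dtype=np.float64), 5)
    index_of = {tuple(c): i for i, c in enumerate(oth)}  # assumes unique coordinates
    perm = np.array([index_of[tuple(c)] for c in ref], dtype=np.int64)
    return other_feats[torch.as_tensor(perm)]


# Synthetic example: four tiles, the second set is a shuffled copy of the first.
ref_coords = np.array([[0.0, 0.0], [0.0, 256.0], [256.0, 0.0], [256.0, 256.0]])
shuffle = torch.as_tensor([2, 0, 3, 1])
aligned = align_by_coords(ref_coords, ref_coords[shuffle.numpy()], torch.eye(4)[shuffle])
assert torch.equal(aligned, torch.eye(4))  # rows are back in reference order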
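
PATCH 17 also replaces TITAN's hard-coded 256 / mpp patch size with the smallest positive stride found in the tile coordinates. A minimal standalone sketch of that idea, assuming a roughly regular grid with at most a few missing tiles, is shown below; the function name smallest_stride and the example grid are made up, and the unit of the result (microns vs. pixels) is simply the unit of the coordinates passed in.

# Sketch of inferring the tile stride from tile coordinates (illustrative only).
import torch


def smallest_stride(coords: torch.Tensor) -> float:
    """Smallest positive spacing along x and y for coords of shape [n, 2]."""
    strides = []
    for axis in (0, 1):
        vals = torch.unique(coords[:, axis])  # unique() returns sorted values
        diffs = vals[1:] - vals[:-1]
        diffs = diffs[diffs > 0]
        if diffs.numel() > 0:
            strides.append(float(diffs.min()))
    if not strides:
        raise ValueError("need at least two distinct coordinates along one axis")
    return min(strides)


# A 3x3 grid with 256 um spacing and one missing corner tile still yields 256.0.
grid = torch.tensor([[x, y] for x in (0.0, 256.0, 512.0) for y in (0.0, 256.0, 512.0)][:-1])
print(smallest_stride(grid))  # 256.0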
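
PATCH 19 switches the cross-validation split to stratify on the event indicator whenever the ground truth looks like a survival string of the form "time status". The snippet below replays that decision on synthetic patients; the patient IDs and labels are invented, and scikit-learn's KFold/StratifiedKFold are assumed to be available, as in the patch.

# Sketch of survival-aware fold splitting on synthetic data (not STAMP code).
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

ground_truth = {  # patient id -> label; "time status" marks a survival target
    "P1": "12.5 1", "P2": "30.0 0", "P3": "7.2 1",
    "P4": "45.1 0", "P5": "22.9 1", "P6": "18.3 0",
}

patients = np.array(list(ground_truth))
tokens = [str(label).split() for label in ground_truth.values()]

if all(len(t) == 2 for t in tokens):
    # Survival ground truth: stratify folds on the binary event/censoring status.
    status = np.array([int(t[1]) for t in tokens])
    folds = StratifiedKFold(n_splits=3, shuffle=True, random_state=0).split(patients, status)
else:
    folds = KFold(n_splits=3, shuffle=True, random_state=0).split(patients)

for i, (train_idx, test_idx) in enumerate(folds):
    print(f"fold {i}: train={patients[train_idx].tolist()} test={patients[test_idx].tolist()}")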