feat: integrate multi-task AttnMIL regression model#2
Open
Conversation
…ng pipeline - Fix merge conflicts in src/stamp/modeling/data.py (restore STAMP BagDataset) - Add AttnMILMultiTask model (attention pooling + multi-head regression) - Add multitask training package (dataset, lightning module, config, train) - Register 'train_multitask' CLI command (non-breaking) - Add MultitaskTrainConfig to StampConfig - Add example config and smoke tests Co-authored-by: drgmo <65294284+drgmo@users.noreply.github.com>
Copilot
AI
changed the title
[WIP] Integrate Multi-Task AttnMIL model into feature/attn-mil
feat: integrate multi-task AttnMIL regression model
Feb 6, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a multi-task attention MIL model for simultaneous regression on multiple targets (e.g. HRD, TMB, CLOVAR subtypes) from H5 tile features. Resolves merge conflicts in
data.pythat broke the existingBagDataset. All new code lives in separate modules—no existing STAMP codepaths are modified beyond wiring.Merge conflict resolution
src/stamp/modeling/data.py: Restored original STAMPBagDatasetdataclass, removed conflicting stash artifactsNew model
src/stamp/modeling/models/attn_mil_multitask.py:AttnMILMultiTask— phi→attention→softmax→weighted-sum aggregation with configurable per-target linear heads. Returnsdict[str, Tensor]includingattnweights.New training package (
src/stamp/modeling/multitask/)dataset.py:MultitaskBagDataset— groups slides per patient, concatenates H5 features, bag-samples to fixed size, auto-detects H5 key (feats/patch_embeddings)lightning_module.py:LitAttnMILMultiTask— weighted Huber/MSE multi-task loss, per-fold z-score normalization with save/load, per-head metric loggingconfig.py:MultitaskTrainConfig— Pydantic v2 config for all hyperparameterstrain.py: Standalone training entry point with train/val split, normalizer fitting, Lightning TrainerCLI & config integration
StampConfig.multitask_training: MultitaskTrainConfig | None = None(non-breaking)train_multitasksubcommandUsage
Tests
Original prompt
Prompt (zum Kopieren) für Claude Opus 4.6: attMIL Multi-Task in feature/attn-mil STAMP-nah integrieren
Du bist Senior ML/PyTorch Engineer und kennst dich mit STAMP/WSI Feature-H5 Pipelines aus.
Kontext-Zeit (für Logs/Entscheidungen): Friday, February 06, 2026 (UTC).
Ich arbeite in einem bestehenden Repo mit STAMP-ähnlicher Struktur. Ziel: In die Branch feature/attn-mil ein neues Modell “B: eigenes Multi-Task AttnMIL (PyTorch)” integrieren – so STAMP-nah wie möglich, damit bestehende STAMP-Funktionalität unverändert bleibt und ich sofort Experimente mit dem neuen Ansatz starten kann.
Trainiere direkt auf .h5 Feature-Matrizen: (n_tiles × d) pro Slide, ggf. mehrere Slides pro Patient → concat.
Bag sampling: wähle bag_size Tiles (mit/ohne replacement).
Attention pooling → case embedding z.
Multi-Task Regressions-Heads:
hrd: scalar
tmb: scalar
clovar: 4-dim
Output soll optional auch attn (Attention-Gewichte pro Tile) liefern.
STAMP darf nicht kaputt gehen: Default-Verhalten/Entry-Points bleiben identisch.
Neues attMIL-Training muss hinter einem klaren Flag/Config hängen (z. B. model.name: attn_mil_multitask).
So wenig Code-Duplizierung wie möglich: Nutze vorhandene Utilities (Config, Logging, Seeds, Splits, Metrics, Trainer-Scaffold), wo immer sinnvoll.
Möglichst kleine Diffs: Lieber neue Module hinzufügen als bestehende umzubauen.
Reproduzierbarkeit: deterministic seeds, fold-spezifische Standardisierung, gespeicherte mean/std je Target.
Robustes .h5 Laden: STAMP-Features können unterschiedliche Keys haben → Key autodetect (oder konfigurierbar) + Debug-Ausgabe (list(f.keys())) falls unklar.
Ich möchte konzeptionell sowas integrieren (du sollst es an Repo-Struktur/Style anpassen):
Dataset (Bag sampling) – Kernidee:
CSV slide_table_csv enthält Spalten PATIENT, FILENAME (evtl. weitere).
clini_csv enthält Targets, indexbar über PATIENT.
H5 enthält dataset features oder ähnliches mit shape (n_tiles, d).
Modell – Kernidee:
phi: Linear(in_dim→emb_dim)+ReLU
attn: Linear(emb_dim→emb_dim)+Tanh+Linear(emb_dim→1)
softmax über tiles
z = sum(H * A)
Heads: hrd, tmb, clovar (+ attn return)
Loss/Training:
Huber oder MSE
Weighted multitask: z. B. w_hrd=1.0, w_clovar=0.3, w_tmb=0.1
Targets: z-score nur auf Train-Fold, apply auf val/test, speichere stats.
Repo-Inspektion: Frage mich nach einem tree/Dateiliste ODER (wenn du Zugriff in Claude Code hast) scanne:
existierende STAMP training scripts / trainer
dataset/dataloader patterns
config system (yaml/hydra/argparse)
feature loading utilities
split/fold handling
logging (tensorboard/wandb)
Sag mir, wo du andocken willst (z. B. neue models/attn_mil.py, neue datasets/bag_dataset.py, neuer Trainer oder Erweiterung eines bestehenden).
Stelle max. 10 präzise Rückfragen, nur wenn nötig (z. B. exakte CSV-Spaltennamen, H5 key, d=feature dim, bestehende CLI entrypoints).
Bitte liefere:
A) Code-Änderungen (minimal und sauber)
Neues Dataset: BagDataset (oder repo-konforme Benennung)
gruppiert Slides pro Patient
lädt H5 features (autodetect key oder config h5_feature_key)
concat über Slides
bag sampling
gibt pid, X (Tensor [n, d] oder batchbar [B,n,d]), targets dict zurück
Neues Modell: AttnMILMultiTask
forward(X) akzeptiert [B,n,d] und auch [n,d] (single-bag) robust
Outputs dict {hrd, tmb, clovar, attn}
Loss-Modul / Trainer-Integration
Weighted multitask huber/mse (konfigurierbar)
z-score normalizer pro fold (fit on train, apply on val/test)
Logging: losses pro head + total, optional corr/MAE pro head
Konfig & Entry-Point
Beispiel-config für schnelles Experiment (paths, bag_size, emb_dim, lr, weights, huber delta, epochs)
CLI command, der ohne Eingriffe in bestehende STAMP-Läufe funktioniert
B) “STAMP-nah” Design-Details
nutze vorhandene STAMP conventions: z. B. LightningModule falls vorhanden, oder vorhandenen Trainer.
falls STAMP schon FeatureDataset hat: erweitere/komponiere statt neu erfinden.
Speichere Checkpoints in derselben Struktur wie STAMP (oder optional parallel, aber konsistent).
C) Debug/Validation Hooks
kleine “smoke test” Funktion:
lädt 1 patient bag
macht forward pass
prüft shapes (attn sum ~ 1)
optional: speichere Top-k Tiles nach Attention (indices) für Interpretierbarkeit (nur wenn leicht integrierbar)
Kurzer Plan (max. 15 bullets), welche Files du anfasst/neu anlegst.
Konkrete Patches als git diff-Blöcke (oder file-by-file code), sodass ich sie direkt übernehmen kann.
Runbook: exakte Commands zum Trainieren + Beispiel-Config.
Ablations/Experimente: 3 schnelle Experimente (z. B. bag_size, loss weights, emb_dim) und was ich beobachten soll.
Halte die Implementierung pragmatisch: zuerst “läuft end-to-end”,...
💬 We'd love your input! Share your thoughts on Copilot coding agent in our 2 minute survey.