Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/pr-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,28 @@ on:
workflow_dispatch:

jobs:
regression-guards:
runs-on: ubuntu-latest
timeout-minutes: 10
permissions:
contents: read

steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Run regression guard tests
run: |
set -euo pipefail
python -m unittest discover -s test -p 'test_regression_bugfixes.py'

docker-test:
needs: regression-guards
runs-on: ubuntu-latest
timeout-minutes: 60
permissions:
Expand Down
46 changes: 46 additions & 0 deletions docs/2026-02-10-bugfixes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Bugfixes (2026-02-10)

## Scope

Applied fixes in `slide2vec/` only.
No edits were made in `slide2vec/hs2p` (submodule).

## Changes

- Fixed region-level model factory wiring so all supported region models initialize `tile_encoder` before wrapping with `RegionFeatureExtractor`:
- `slide2vec/models/models.py`
- Fixed subprocess process-group handling for CTRL+C termination:
- use `start_new_session=True` in child `Popen`
- ensure `killpg` on CTRL+C terminates only the child's own process group, never the parent's
- `slide2vec/main.py`
- Improved CLI path robustness in orchestration:
- invoke embed/aggregate via modules (`-m slide2vec.embed`, `-m slide2vec.aggregate`)
- resolve hs2p path from package location
- `slide2vec/main.py`
- Fixed direct-script output path handling:
- avoid `Path(cfg.output_dir, None)` when `--output-dir` is not provided
- `slide2vec/embed.py`
- `slide2vec/aggregate.py`
- Reduced distributed deadlock risk in embedding:
- moved rank synchronization outside `try` block
- added rank-failure synchronization and propagation
- clean tmp feature shards on failure
- `slide2vec/embed.py`

## Regression Tests

Added deterministic source-level regression tests:

- `test/test_regression_bugfixes.py`

Covered checks:

- child process session isolation for `killpg`
- safe output-dir composition
- no `barrier()` calls inside `try` blocks in per-slide embed loop
- correct region-model `tile_encoder` assignments

## hs2p Suggestions (no edits applied)

- Consider converting `hs2p` script invocation in orchestration to module entry points (mirroring embed/aggregate) to reduce path coupling.
- Consider adding a small smoke test for resumed runs to validate output directory nesting behavior (`resume` vs `skip-datetime` combinations).
11 changes: 10 additions & 1 deletion slide2vec/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,20 @@ def scale_coordinates(wsi_fp, coordinates, spacing, backend):
return scaled_coordinates


def resolve_output_dir(config_output_dir: str, cli_output_dir: str | None) -> Path:
if cli_output_dir is None:
return Path(config_output_dir)
cli_path = Path(cli_output_dir)
if cli_path.is_absolute():
return cli_path
return Path(config_output_dir, cli_output_dir)


def main(args):
# setup configuration
run_on_cpu = args.run_on_cpu
cfg = get_cfg_from_file(args.config_file)
output_dir = Path(cfg.output_dir, args.output_dir)
output_dir = resolve_output_dir(cfg.output_dir, args.output_dir)
cfg.output_dir = str(output_dir)

coordinates_dir = Path(cfg.output_dir, "coordinates")
Expand Down
109 changes: 86 additions & 23 deletions slide2vec/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,27 @@ def load_sort_and_deduplicate_features(tmp_dir, name, expected_len=None):
return features_unique


def resolve_output_dir(config_output_dir: str, cli_output_dir: str | None) -> Path:
if cli_output_dir is None:
return Path(config_output_dir)
cli_path = Path(cli_output_dir)
if cli_path.is_absolute():
return cli_path
return Path(config_output_dir, cli_output_dir)


def cleanup_tmp_features(tmp_dir: Path, name: str):
    """Delete every per-rank temporary feature shard for slide ``name``.

    Iterates over all distributed ranks and removes the corresponding
    ``{name}-rank_{rank}.h5`` shard from ``tmp_dir`` when present.
    """
    world_size = distributed.get_global_size()
    for rank_id in range(world_size):
        shard = tmp_dir / f"{name}-rank_{rank_id}.h5"
        if shard.exists():
            os.remove(shard)


def main(args):
# setup configuration
run_on_cpu = args.run_on_cpu
cfg = get_cfg_from_file(args.config_file)
output_dir = Path(cfg.output_dir, args.output_dir)
output_dir = resolve_output_dir(cfg.output_dir, args.output_dir)
cfg.output_dir = str(output_dir)

if not run_on_cpu:
Expand Down Expand Up @@ -250,6 +266,14 @@ def main(args):
disable=not distributed.is_main_process(),
position=1,
):
name = wsi_fp.stem.replace(" ", "_")
feature_path = features_dir / f"{name}.pt"
if cfg.model.save_tile_embeddings:
feature_path = features_dir / f"{name}-tiles.pt"
tmp_feature_path = tmp_dir / f"{name}-rank_{distributed.get_global_rank()}.h5"

status_info = {"status": "success"}
local_failed = False
try:
dataset = create_dataset(
wsi_path=wsi_fp,
Expand Down Expand Up @@ -280,12 +304,6 @@ def main(args):
pin_memory=True,
)

name = wsi_fp.stem.replace(" ", "_")
feature_path = features_dir / f"{name}.pt"
if cfg.model.save_tile_embeddings:
feature_path = features_dir / f"{name}-tiles.pt"
tmp_feature_path = tmp_dir / f"{name}-rank_{distributed.get_global_rank()}.h5"

# get feature dimension and dtype using a dry run
with torch.inference_mode(), autocast_context:
sample_batch = next(iter(dataloader))
Expand All @@ -307,30 +325,75 @@ def main(args):
run_on_cpu,
)

if not run_on_cpu:
torch.distributed.barrier()
except Exception as e:
local_failed = True
status_info = {
"status": "failed",
"error": str(e),
"traceback": str(traceback.format_exc()),
}

any_rank_failed = local_failed
if not run_on_cpu:
# Ensure every rank reaches sync points, even when one rank failed.
torch.distributed.barrier()
failure_flag = torch.tensor(
1 if local_failed else 0, device=model.device, dtype=torch.int32
)
torch.distributed.all_reduce(
failure_flag, op=torch.distributed.ReduceOp.MAX
)
any_rank_failed = bool(failure_flag.item())

if any_rank_failed:
if distributed.is_main_process():
wsi_feature = load_sort_and_deduplicate_features(tmp_dir, name, expected_len=len(dataset))
cleanup_tmp_features(tmp_dir, name)
if status_info["status"] != "failed":
status_info = {
"status": "failed",
"error": "Feature extraction failed on at least one distributed rank.",
"traceback": "",
}
elif distributed.is_main_process():
try:
wsi_feature = load_sort_and_deduplicate_features(
tmp_dir, name, expected_len=len(dataset)
)
torch.save(wsi_feature, feature_path)

# cleanup
del wsi_feature
except Exception as e:
any_rank_failed = True
cleanup_tmp_features(tmp_dir, name)
status_info = {
"status": "failed",
"error": str(e),
"traceback": str(traceback.format_exc()),
}
finally:
if "wsi_feature" in locals():
del wsi_feature
if not run_on_cpu:
torch.cuda.empty_cache()
gc.collect()

if not run_on_cpu:
torch.distributed.barrier()

feature_extraction_updates[str(wsi_fp)] = {"status": "success"}
if not run_on_cpu:
# Propagate post-processing failures from rank 0 to all ranks.
failure_flag = torch.tensor(
1 if (distributed.is_main_process() and any_rank_failed) else 0,
device=model.device,
dtype=torch.int32,
)
torch.distributed.broadcast(failure_flag, src=0)
torch.distributed.barrier()
any_rank_failed = bool(failure_flag.item())

except Exception as e:
feature_extraction_updates[str(wsi_fp)] = {
"status": "failed",
"error": str(e),
"traceback": str(traceback.format_exc()),
}
if distributed.is_main_process():
if any_rank_failed and status_info["status"] != "failed":
status_info = {
"status": "failed",
"error": "Feature extraction failed on at least one distributed rank.",
"traceback": "",
}
feature_extraction_updates[str(wsi_fp)] = status_info

# update process_df
if distributed.is_main_process():
Expand Down
17 changes: 11 additions & 6 deletions slide2vec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from slide2vec.utils.config import hf_login, setup

PACKAGE_ROOT = Path(__file__).resolve().parent


def get_args_parser(add_help: bool = True):
parser = argparse.ArgumentParser("slide2vec", add_help=add_help)
Expand Down Expand Up @@ -91,7 +93,8 @@ def run_feature_extraction(config_file, output_dir, run_on_cpu: False):
"torch.distributed.run",
f"--master_port={free_port}",
"--nproc_per_node=gpu",
"slide2vec/embed.py",
"-m",
"slide2vec.embed",
"--config-file",
os.path.abspath(config_file),
"--output-dir",
Expand All @@ -100,15 +103,16 @@ def run_feature_extraction(config_file, output_dir, run_on_cpu: False):
if run_on_cpu:
cmd = [
sys.executable,
"slide2vec/embed.py",
"-m",
"slide2vec.embed",
"--config-file",
os.path.abspath(config_file),
"--output-dir",
os.path.abspath(output_dir),
"--run-on-cpu",
]
# launch in its own process group.
proc = subprocess.Popen(cmd)
proc = subprocess.Popen(cmd, start_new_session=True)
try:
proc.wait()
except KeyboardInterrupt:
Expand All @@ -126,7 +130,8 @@ def run_feature_aggregation(config_file, output_dir, run_on_cpu: False):
# find a free port
cmd = [
sys.executable,
"slide2vec/aggregate.py",
"-m",
"slide2vec.aggregate",
"--config-file",
os.path.abspath(config_file),
"--output-dir",
Expand All @@ -135,7 +140,7 @@ def run_feature_aggregation(config_file, output_dir, run_on_cpu: False):
if run_on_cpu:
cmd.append("--run-on-cpu")
# launch in its own process group.
proc = subprocess.Popen(cmd)
proc = subprocess.Popen(cmd, start_new_session=True)
try:
proc.wait()
except KeyboardInterrupt:
Expand All @@ -156,7 +161,7 @@ def main(args):

hf_login()

root_dir = "slide2vec/hs2p"
root_dir = PACKAGE_ROOT / "hs2p"
if cfg.resume:
# need to remove the dirname to avoid nested output directories
hs2p_output_dir = output_dir.parent
Expand Down
12 changes: 6 additions & 6 deletions slide2vec/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,17 @@ def __init__(
elif options.name == "h-optimus-1":
tile_encoder = Hoptimus1()
elif options.name == "conch":
model = CONCH()
tile_encoder = CONCH()
elif options.name == "musk":
model = MUSK()
tile_encoder = MUSK()
elif options.name == "phikonv2":
model = PhikonV2()
tile_encoder = PhikonV2()
elif options.name == "hibou":
model = Hibou()
tile_encoder = Hibou()
elif options.name == "kaiko":
model = Kaiko(arch=options.arch)
tile_encoder = Kaiko(arch=options.arch)
elif options.name == "kaiko-midnight":
model = Midnight12k()
tile_encoder = Midnight12k()
elif options.name == "rumc-vit-s-50k":
tile_encoder = CustomViT(
arch="vit_small",
Expand Down
Loading
Loading