diff --git a/.github/workflows/pr-test.yaml b/.github/workflows/pr-test.yaml index 748abf3..dc7e492 100644 --- a/.github/workflows/pr-test.yaml +++ b/.github/workflows/pr-test.yaml @@ -24,7 +24,7 @@ jobs: - name: Run regression guard tests run: | set -euo pipefail - python -m unittest discover -s test -p 'test_regression_bugfixes.py' + python -m unittest discover -s test -p 'test_*.py' docker-test: needs: regression-guards diff --git a/docs/2026-02-11-embed-v2.md b/docs/2026-02-11-embed-v2.md new file mode 100644 index 0000000..793b58a --- /dev/null +++ b/docs/2026-02-11-embed-v2.md @@ -0,0 +1,51 @@ +# Embed V2 Data-Loading Optimization + +This update introduces a new optional embedding execution path (`speed.embedding_pipeline: "v2"`) focused on higher-throughput tile/region loading for single-node multi-GPU machines. + +## Highlights + +- Worker-local WSI caching in `TileDataset`: + - avoids reopening `wholeslidedata.WholeSlideImage` per tile +- Adaptive sharding in `embed.py`: + - `rank_sharding_mode: auto` chooses slide-level sharding when pending slides >= world size + - falls back to tile-level sharding when pending slides are too few +- Configurable DataLoader tuning: + - `num_workers_embedding` supports `"auto"` heuristics + - `prefetch_factor_embedding`, `persistent_workers_embedding`, `pin_memory_embedding` +- Lazy output writer initialization: + - removes warmup dry-run batch and infers feature shape from first real batch +- Optional perf logging: + - per-slide throughput and data-wait fraction via `speed.log_perf_embedding` + +## Compatibility + +- Default pipeline remains `v1`. +- Output artifacts remain unchanged by default (`features/<name>.pt` or `<name>-tiles.pt`). 
+ +## New profile config + +- `slide2vec/configs/h100-v2.yaml` provides a recommended H100-oriented preset with: + - `embedding_pipeline: "v2"` + - `rank_sharding_mode: "auto"` + - `storage_mode: "network"` + - `num_workers_embedding: "auto"` + - `prefetch_factor_embedding: 4` + +## Benchmark harness + +- `scripts/benchmark_embed_v1_v2.py` benchmarks embed-only throughput for `v1` vs `v2` using an existing tiling output directory. +- Expected input directory contains: + - `process_list.csv` + - `coordinates/` + +Example: + +```bash +python3 scripts/benchmark_embed_v1_v2.py \ + --config-file slide2vec/configs/h100-v2.yaml \ + --baseline-output-dir /path/to/tiling-output \ + --benchmark-output-dir /path/to/embed-benchmark \ + --gpu-counts 1,4,8 \ + --pipelines v1,v2 \ + --repeats 2 +``` diff --git a/scripts/benchmark_embed_v1_v2.py b/scripts/benchmark_embed_v1_v2.py new file mode 100644 index 0000000..5f0b74e --- /dev/null +++ b/scripts/benchmark_embed_v1_v2.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +import argparse +import json +import shutil +import subprocess +import sys +import time +from pathlib import Path + + +def parse_args(): + parser = argparse.ArgumentParser( + description=( + "Benchmark slide2vec.embed v1 vs v2 using an existing tiling output " + "(process_list.csv + coordinates)." + ) + ) + parser.add_argument("--config-file", required=True, help="Path to embedding config YAML.") + parser.add_argument( + "--baseline-output-dir", + required=True, + help=( + "Directory that already contains tiling artifacts: process_list.csv and coordinates/." 
+ ), + ) + parser.add_argument( + "--benchmark-output-dir", + default="outputs/embed-benchmark", + help="Directory where benchmark run folders and summary files are written.", + ) + parser.add_argument( + "--gpu-counts", + default="1,4,8", + help="Comma-separated GPU counts to benchmark (single-node only).", + ) + parser.add_argument( + "--pipelines", + default="v1,v2", + help="Comma-separated embedding pipelines to benchmark (v1,v2).", + ) + parser.add_argument( + "--repeats", + type=int, + default=1, + help="Number of repetitions per (pipeline, gpu_count).", + ) + parser.add_argument( + "--python-exe", + default=sys.executable, + help="Python executable used to launch embed.py.", + ) + parser.add_argument( + "--extra-opt", + action="append", + default=[], + help=( + "Additional embed CLI overrides in path.key=value format. " + "Repeat for multiple options." + ), + ) + parser.add_argument( + "--run-on-cpu", + action="store_true", + help="Run all benchmark commands on CPU (for quick functional validation).", + ) + return parser.parse_args() + + +def parse_csv_list(value: str): + return [item.strip() for item in value.split(",") if item.strip()] + + +def compute_total_tiles(process_df, coordinates_dir: Path): + import numpy as np + + tiled_df = process_df[process_df["tiling_status"] == "success"] + total_tiles = 0 + for wsi_path in tiled_df["wsi_path"].tolist(): + name = Path(wsi_path).stem.replace(" ", "_") + coord_file = coordinates_dir / f"{name}.npy" + if not coord_file.exists(): + continue + arr = np.load(coord_file, allow_pickle=True) + total_tiles += int(len(arr["x"])) + return total_tiles, len(tiled_df) + + +def reset_process_list_for_embed(process_df): + df = process_df.copy() + if "feature_status" not in df.columns: + df["feature_status"] = ["tbp"] * len(df) + else: + df["feature_status"] = ["tbp" if x == "success" else "tbp" for x in df["feature_status"]] + if "error" in df.columns: + df["error"] = ["" for _ in range(len(df))] + if "traceback" in 
df.columns: + df["traceback"] = ["" for _ in range(len(df))] + return df + + +def build_command( + *, + python_exe: str, + gpu_count: int, + run_on_cpu: bool, + config_file: Path, + run_dir: Path, + coords_dir: Path, + pipeline: str, + extra_opts: list[str], +): + if gpu_count > 1 and not run_on_cpu: + cmd = [ + python_exe, + "-m", + "torch.distributed.run", + f"--nproc_per_node={gpu_count}", + "slide2vec/embed.py", + ] + else: + cmd = [python_exe, "slide2vec/embed.py"] + + cmd.extend( + [ + "--config-file", + str(config_file.resolve()), + "--output-dir", + str(run_dir.resolve()), + ] + ) + if run_on_cpu: + cmd.append("--run-on-cpu") + + opts = [ + f"tiling.read_coordinates_from={coords_dir.resolve()}", + f"speed.embedding_pipeline={pipeline}", + "speed.rank_sharding_mode=auto", + "speed.log_perf_embedding=true", + ] + opts.extend(extra_opts) + cmd.extend(opts) + return cmd + + +def render_markdown_table(results): + lines = [] + lines.append("| pipeline | gpus | repeat | wall_sec | tiles | tiles_per_sec | status |") + lines.append("|---|---:|---:|---:|---:|---:|---|") + for row in results: + lines.append( + "| {pipeline} | {gpu_count} | {repeat} | {wall_sec:.2f} | {tiles} | {tiles_per_sec:.2f} | {status} |".format( + **row + ) + ) + return "\n".join(lines) + + +def main(): + args = parse_args() + import pandas as pd + + config_file = Path(args.config_file) + baseline_output_dir = Path(args.baseline_output_dir) + benchmark_output_dir = Path(args.benchmark_output_dir) + process_list = baseline_output_dir / "process_list.csv" + coordinates_dir = baseline_output_dir / "coordinates" + + if not config_file.exists(): + raise FileNotFoundError(f"Config file not found: {config_file}") + if not process_list.exists(): + raise FileNotFoundError(f"process_list.csv not found: {process_list}") + if not coordinates_dir.exists(): + raise FileNotFoundError(f"coordinates dir not found: {coordinates_dir}") + + benchmark_output_dir.mkdir(parents=True, exist_ok=True) + + gpu_counts = 
[int(x) for x in parse_csv_list(args.gpu_counts)] + pipelines = parse_csv_list(args.pipelines) + if not pipelines: + raise ValueError("No pipelines requested.") + + base_df = pd.read_csv(process_list) + total_tiles, num_slides = compute_total_tiles(base_df, coordinates_dir) + if total_tiles == 0: + raise RuntimeError( + "No tiles found in baseline coordinates. Ensure baseline output dir is a valid tiling output." + ) + + print( + f"Benchmarking {pipelines} on GPU counts={gpu_counts}, repeats={args.repeats}. " + f"Slides={num_slides}, total_tiles={total_tiles}." + ) + + results = [] + + for pipeline in pipelines: + if pipeline not in {"v1", "v2"}: + raise ValueError(f"Unsupported pipeline: {pipeline}") + for gpu_count in gpu_counts: + for repeat in range(1, args.repeats + 1): + run_name = f"{pipeline}-g{gpu_count}-r{repeat}" + run_dir = benchmark_output_dir / run_name + if run_dir.exists(): + shutil.rmtree(run_dir) + run_dir.mkdir(parents=True, exist_ok=True) + + run_process_df = reset_process_list_for_embed(base_df) + run_process_df.to_csv(run_dir / "process_list.csv", index=False) + + cmd = build_command( + python_exe=args.python_exe, + gpu_count=gpu_count, + run_on_cpu=args.run_on_cpu, + config_file=config_file, + run_dir=run_dir, + coords_dir=coordinates_dir, + pipeline=pipeline, + extra_opts=args.extra_opt, + ) + + log_path = run_dir / "embed.log" + print(f"\n[{run_name}] Running command:\n {' '.join(cmd)}\n") + + start = time.perf_counter() + with log_path.open("w", encoding="utf-8") as log_f: + proc = subprocess.run( + cmd, + stdout=log_f, + stderr=subprocess.STDOUT, + cwd=Path(__file__).resolve().parents[1], + ) + wall_sec = time.perf_counter() - start + + status = "ok" if proc.returncode == 0 else f"failed({proc.returncode})" + tiles_per_sec = total_tiles / wall_sec if proc.returncode == 0 else 0.0 + row = { + "pipeline": pipeline, + "gpu_count": gpu_count, + "repeat": repeat, + "wall_sec": wall_sec, + "tiles": total_tiles, + "tiles_per_sec": tiles_per_sec, 
+ "status": status, + "log_path": str(log_path.resolve()), + } + results.append(row) + + print( + f"[{run_name}] status={status}, wall={wall_sec:.2f}s, " + f"tiles/sec={tiles_per_sec:.2f}, log={log_path}" + ) + + summary_json = benchmark_output_dir / "summary.json" + summary_md = benchmark_output_dir / "summary.md" + + with summary_json.open("w", encoding="utf-8") as f: + json.dump(results, f, indent=2) + + table = render_markdown_table(results) + with summary_md.open("w", encoding="utf-8") as f: + f.write("# Embed v1 vs v2 Benchmark Summary\n\n") + f.write(table) + f.write("\n") + + print("\nBenchmark summary:") + print(table) + print(f"\nSaved summary JSON: {summary_json}") + print(f"Saved summary Markdown: {summary_md}") + + +if __name__ == "__main__": + main() diff --git a/slide2vec/configs/default_model.yaml b/slide2vec/configs/default_model.yaml index 70068ae..e788031 100644 --- a/slide2vec/configs/default_model.yaml +++ b/slide2vec/configs/default_model.yaml @@ -23,7 +23,15 @@ model: speed: fp16: false # use mixed precision during model inference - num_workers_embedding: 8 # number of workers for data loading when embedding slides + embedding_pipeline: "v1" # embedding execution pipeline ("v1" keeps synchronized tile sharding, "v2" enables adaptive sharding) + rank_sharding_mode: "auto" # sharding strategy for v2 ("auto", "slide", "tile") + storage_mode: "auto" # storage profile used for loader heuristics ("auto", "network", "local") + num_workers_embedding: "auto" # number of data loading workers, or "auto" to derive from CPU cores per rank + prefetch_factor_embedding: # DataLoader prefetch factor; leave empty to use storage-aware defaults + persistent_workers_embedding: true # keep DataLoader workers alive between batches when num_workers > 0 + pin_memory_embedding: true # use pinned host memory for faster H2D transfers + loader_batch_timeout_sec: 0 # DataLoader timeout in seconds + log_perf_embedding: false # log per-slide data loading/compute timing 
breakdown wandb: enable: false diff --git a/slide2vec/configs/h100-v2.yaml b/slide2vec/configs/h100-v2.yaml new file mode 100644 index 0000000..f24a9a7 --- /dev/null +++ b/slide2vec/configs/h100-v2.yaml @@ -0,0 +1,32 @@ +csv: # path to csv containing slide paths + +output_dir: "output" + +tiling: + params: + spacing: 0.5 + tolerance: 0.05 + tile_size: 256 + min_tissue_percentage: 0.1 + filter_params: + ref_tile_size: 256 + +model: + level: "tile" + name: "conch" + batch_size: 64 + +speed: + fp16: true + embedding_pipeline: "v2" + rank_sharding_mode: "auto" + storage_mode: "network" + num_workers_embedding: "auto" + prefetch_factor_embedding: 4 + persistent_workers_embedding: true + pin_memory_embedding: true + loader_batch_timeout_sec: 0 + log_perf_embedding: true + +wandb: + enable: false diff --git a/slide2vec/data/dataset.py b/slide2vec/data/dataset.py index e19d0b4..50a4022 100644 --- a/slide2vec/data/dataset.py +++ b/slide2vec/data/dataset.py @@ -32,6 +32,8 @@ def __init__( self.target_spacing = target_spacing self.backend = backend self.name = wsi_path.stem.replace(" ", "_") + self._wsi = None + self._worker_id = None self.load_coordinates(coordinates_dir) self.transforms = transforms self.restrict_to_tissue = restrict_to_tissue @@ -83,10 +85,16 @@ def scale_coordinates(self): def __len__(self): return len(self.x) + def _get_worker_wsi(self): + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else -1 + if self._wsi is None or self._worker_id != worker_id: + self._wsi = wsd.WholeSlideImage(self.path, backend=self.backend) + self._worker_id = worker_id + return self._wsi + def __getitem__(self, idx): - wsi = wsd.WholeSlideImage( - self.path, backend=self.backend - ) # cannot be defined in __init__ because of multiprocessing + wsi = self._get_worker_wsi() tile_level = self.tile_level[idx] tile_spacing = wsi.spacings[tile_level] tile_arr = wsi.get_patch( diff --git a/slide2vec/embed.py b/slide2vec/embed.py 
index 89b3acf..857fe31 100644 --- a/slide2vec/embed.py +++ b/slide2vec/embed.py @@ -1,24 +1,26 @@ +import argparse import gc +import multiprocessing as mp import os +import time +import traceback +from contextlib import nullcontext +from pathlib import Path + import h5py -import tqdm +import numpy as np +import pandas as pd import torch -import argparse -import traceback import torchvision -import pandas as pd -import multiprocessing as mp - -from pathlib import Path -from contextlib import nullcontext +import tqdm import slide2vec.distributed as distributed +from slide2vec.data import RegionUnfolding, TileDataset +from slide2vec.hs2p.hs2p.wsi import SamplingParameters +from slide2vec.models import ModelFactory from slide2vec.utils import fix_random_seeds from slide2vec.utils.config import get_cfg_from_file, setup_distributed -from slide2vec.models import ModelFactory -from slide2vec.data import TileDataset, RegionUnfolding -from slide2vec.hs2p.hs2p.wsi import SamplingParameters torchvision.disable_beta_transforms_warning() @@ -34,12 +36,10 @@ def get_args_parser(add_help: bool = True): default=None, help="output directory to save logs and checkpoints", ) - parser.add_argument( - "--run-on-cpu", action="store_true", help="run inference on cpu" - ) + parser.add_argument("--run-on-cpu", action="store_true", help="run inference on cpu") parser.add_argument( "opts", - help="Modify config options at the end of the command using \"path.key=value\".", + help='Modify config options at the end of the command using "path.key=value".', default=None, nargs=argparse.REMAINDER, ) @@ -49,7 +49,7 @@ def get_args_parser(add_help: bool = True): def create_transforms(cfg, model): if cfg.model.level in ["tile", "slide"]: return model.get_transforms() - elif cfg.model.level == "region": + if cfg.model.level == "region": return torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), @@ -57,8 +57,7 @@ def create_transforms(cfg, model): model.get_transforms(), ] ) - else: - 
raise ValueError(f"Unknown model level: {cfg.model.level}") + raise ValueError(f"Unknown model level: {cfg.model.level}") def create_dataset( @@ -89,36 +88,283 @@ def create_dataset( ) -def run_inference(dataloader, model, device, autocast_context, unit, batch_size, feature_path, feature_dim, dtype, run_on_cpu: False): +def get_speed_option(cfg, key: str, default): + speed_cfg = getattr(cfg, "speed", None) + if speed_cfg is None: + return default + if not hasattr(speed_cfg, key): + return default + value = getattr(speed_cfg, key) + if value is None: + return default + return value + + +def parse_slurm_cpus(value: str | None): + if value is None: + return None + head = value.split("(")[0] + digits = "".join(ch for ch in head if ch.isdigit()) + if not digits: + return None + return int(digits) + + +def resolve_loader_settings(cfg, run_on_cpu: bool): + world_size = max(1, distributed.get_global_size()) + cpu_count = mp.cpu_count() + + workers_cfg = get_speed_option(cfg, "num_workers_embedding", "auto") + auto_workers = isinstance(workers_cfg, str) and workers_cfg.lower() == "auto" + if auto_workers: + workers_per_rank = cpu_count // world_size + workers_per_rank = max(4, min(16, workers_per_rank)) + else: + workers_per_rank = int(workers_cfg) + + slurm_cpus = parse_slurm_cpus(os.environ.get("SLURM_JOB_CPUS_PER_NODE")) + if slurm_cpus is not None: + workers_per_rank = min(workers_per_rank, max(1, slurm_cpus // world_size)) + + workers_per_rank = max(0, workers_per_rank) + + storage_mode = str(get_speed_option(cfg, "storage_mode", "auto")).lower() + prefetch_cfg = get_speed_option(cfg, "prefetch_factor_embedding", None) + if prefetch_cfg is None: + prefetch_factor = 4 if storage_mode in {"network", "auto"} else 2 + else: + prefetch_factor = int(prefetch_cfg) + + persistent_workers = bool(get_speed_option(cfg, "persistent_workers_embedding", True)) + pin_memory = bool(get_speed_option(cfg, "pin_memory_embedding", True)) + loader_timeout_sec = int(get_speed_option(cfg, 
"loader_batch_timeout_sec", 0)) + + if run_on_cpu: + # CPU inference in containerized CI frequently has very limited /dev/shm. + # Force single-process loading to avoid worker shared-memory crashes. + workers_per_rank = 0 + pin_memory = False + persistent_workers = False + + if workers_per_rank <= 0: + persistent_workers = False + prefetch_factor = None + + pipeline = str(get_speed_option(cfg, "embedding_pipeline", "v1")).lower() + rank_sharding_mode = str(get_speed_option(cfg, "rank_sharding_mode", "auto")).lower() + log_perf_embedding = bool(get_speed_option(cfg, "log_perf_embedding", False)) + + return { + "embedding_pipeline": pipeline, + "rank_sharding_mode": rank_sharding_mode, + "storage_mode": storage_mode, + "num_workers": workers_per_rank, + "prefetch_factor": prefetch_factor, + "persistent_workers": persistent_workers, + "pin_memory": pin_memory, + "loader_timeout_sec": loader_timeout_sec, + "log_perf_embedding": log_perf_embedding, + } + + +def create_dataloader(dataset, cfg, runtime, sampler=None): + kwargs = { + "batch_size": cfg.model.batch_size, + "sampler": sampler, + "num_workers": runtime["num_workers"], + "pin_memory": runtime["pin_memory"], + "timeout": runtime["loader_timeout_sec"], + } + if runtime["num_workers"] > 0: + kwargs["persistent_workers"] = runtime["persistent_workers"] + if runtime["prefetch_factor"] is not None: + kwargs["prefetch_factor"] = runtime["prefetch_factor"] + return torch.utils.data.DataLoader(dataset, **kwargs) + + +def collect_pending_slides(process_df, coordinates_dir): + tiled_df = process_df[process_df.tiling_status == "success"] + pending_df = tiled_df[tiled_df["feature_status"] != "success"] + + slides = [] + for _, row in pending_df.iterrows(): + wsi_path = Path(row.wsi_path) + name = wsi_path.stem.replace(" ", "_") + coordinates_file = coordinates_dir / f"{name}.npy" + tile_count = 0 + if coordinates_file.is_file(): + try: + coordinates = np.load(coordinates_file, allow_pickle=True) + tile_count = 
int(len(coordinates["x"])) + except Exception: + tile_count = 0 + mask_path = None + if "mask_path" in row and row.mask_path is not None and not pd.isna(row.mask_path): + mask_path = str(row.mask_path) + slides.append( + { + "wsi_path": str(wsi_path), + "mask_path": mask_path, + "name": name, + "tile_count": tile_count, + } + ) + return slides, tiled_df + + +def assign_slides_lpt(slides, world_size): + assignments = {rank: [] for rank in range(world_size)} + loads = {rank: 0 for rank in range(world_size)} + for slide in sorted(slides, key=lambda x: x["tile_count"], reverse=True): + rank = min(loads, key=lambda r: (loads[r], r)) + assignments[rank].append(slide) + loads[rank] += max(1, int(slide["tile_count"])) + return assignments + + +def decide_sharding_mode(cfg, pending_count, world_size): + mode = str(get_speed_option(cfg, "rank_sharding_mode", "auto")).lower() + if mode == "tile": + return "tile" + if mode == "slide": + return "slide" + if mode == "auto": + return "slide" if pending_count >= world_size else "tile" + raise ValueError(f"Unknown rank sharding mode: {mode}") + + +def get_feature_path(features_dir: Path, name: str, cfg): + feature_path = features_dir / f"{name}.pt" + if cfg.model.save_tile_embeddings: + feature_path = features_dir / f"{name}-tiles.pt" + return feature_path + + +def log_perf_summary(name: str, stats: dict, unit: str): + total_time = stats["data_wait_s"] + stats["h2d_s"] + stats["forward_s"] + stats["write_s"] + data_wait_pct = 100.0 * stats["data_wait_s"] / max(total_time, 1e-8) + tiles_per_sec = stats["samples"] / max(stats["elapsed_s"], 1e-8) + print( + f"[perf] {name}: {stats['samples']} {unit}s, {tiles_per_sec:.2f} {unit}s/s, " + f"data_wait={stats['data_wait_s']:.2f}s ({data_wait_pct:.1f}%), " + f"h2d={stats['h2d_s']:.2f}s, forward={stats['forward_s']:.2f}s, write={stats['write_s']:.2f}s" + ) + + +def run_inference_to_h5( + dataloader, + model, + device, + autocast_context, + unit, + batch_size, + feature_path, + 
collect_indices, + run_on_cpu, + show_progress, +): device_name = f"GPU {distributed.get_global_rank()}" if not run_on_cpu else "CPU" + + stats = { + "data_wait_s": 0.0, + "h2d_s": 0.0, + "forward_s": 0.0, + "write_s": 0.0, + "samples": 0, + "batches": 0, + "elapsed_s": 0.0, + } + + start = time.perf_counter() with h5py.File(feature_path, "w") as f: - features = f.create_dataset("features", shape=(0, *feature_dim), maxshape=(None, *feature_dim), dtype=dtype, chunks=(batch_size, *feature_dim)) - indices = f.create_dataset("indices", shape=(0,), maxshape=(None,), dtype='int64', chunks=(batch_size,)) + features = None + indices = None + + iterator = iter(dataloader) + progress = tqdm.tqdm( + total=len(dataloader), + desc=f"Inference on {device_name}", + unit=unit, + unit_scale=batch_size, + leave=False, + position=2 + distributed.get_global_rank(), + disable=not show_progress, + ) + with torch.inference_mode(), autocast_context: - for batch in tqdm.tqdm( - dataloader, - desc=f"Inference on {device_name}", - unit=unit, - unit_scale=batch_size, - leave=False, - position=2 + distributed.get_global_rank(), - ): + while True: + data_wait_start = time.perf_counter() + try: + batch = next(iterator) + except StopIteration: + break + stats["data_wait_s"] += time.perf_counter() - data_wait_start + idx, image = batch - image = image.to(device, non_blocking=True) - feature = model(image)["embedding"].cpu().numpy() - features.resize(features.shape[0] + feature.shape[0], axis=0) - features[-feature.shape[0]:] = feature - indices.resize(indices.shape[0] + idx.shape[0], axis=0) - indices[-idx.shape[0]:] = idx.cpu().numpy() - # cleanup - del image, feature + h2d_start = time.perf_counter() + image = image.to(device, non_blocking=not run_on_cpu) + stats["h2d_s"] += time.perf_counter() - h2d_start + + forward_start = time.perf_counter() + feature_cpu = model(image)["embedding"].cpu() + stats["forward_s"] += time.perf_counter() - forward_start + + write_start = time.perf_counter() + 
feature_np = feature_cpu.numpy() + if features is None: + feature_dim = feature_np.shape[1:] + dtype = feature_np.dtype + features = f.create_dataset( + "features", + shape=(0, *feature_dim), + maxshape=(None, *feature_dim), + dtype=dtype, + chunks=(batch_size, *feature_dim), + ) + if collect_indices: + indices = f.create_dataset( + "indices", + shape=(0,), + maxshape=(None,), + dtype="int64", + chunks=(batch_size,), + ) + + features.resize(features.shape[0] + feature_np.shape[0], axis=0) + features[-feature_np.shape[0] :] = feature_np + + if collect_indices: + idx_np = idx.cpu().numpy() if hasattr(idx, "cpu") else np.asarray(idx) + indices.resize(indices.shape[0] + idx_np.shape[0], axis=0) + indices[-idx_np.shape[0] :] = idx_np + + stats["write_s"] += time.perf_counter() - write_start + stats["samples"] += int(feature_np.shape[0]) + stats["batches"] += 1 + progress.update(1) + + del image, feature_cpu, feature_np + + progress.close() + + if stats["batches"] == 0: + raise RuntimeError("No batches yielded by DataLoader.") + + stats["elapsed_s"] = time.perf_counter() - start - # cleanup if not run_on_cpu: torch.cuda.empty_cache() gc.collect() + return stats + + +def load_features_from_h5(feature_path): + with h5py.File(feature_path, "r") as f: + features = torch.from_numpy(f["features"][:]) + return features + def load_sort_and_deduplicate_features(tmp_dir, name, expected_len=None): features_list, indices_list = [], [] @@ -133,7 +379,7 @@ def load_sort_and_deduplicate_features(tmp_dir, name, expected_len=None): order = torch.argsort(indices) indices = indices[order] features = features[order] - # deduplicate + keep = torch.ones_like(indices, dtype=torch.bool) keep[1:] = indices[1:] != indices[:-1] indices_unique = indices[keep] @@ -160,8 +406,344 @@ def cleanup_tmp_features(tmp_dir: Path, name: str): os.remove(fp) +def run_embed_v1( + *, + slides, + process_df, + process_list, + model, + cfg, + coordinates_dir, + sampling_params, + transforms, + runtime, + 
autocast_context, + features_dir, + tmp_dir, + run_on_cpu, + unit, +): + feature_extraction_updates = {} + + for slide in tqdm.tqdm( + slides, + desc="Inference", + unit="slide", + total=len(slides), + leave=True, + disable=not distributed.is_main_process(), + position=1, + ): + wsi_fp = Path(slide["wsi_path"]) + mask_fp = Path(slide["mask_path"]) if slide["mask_path"] is not None else None + name = slide["name"] + + feature_path = get_feature_path(features_dir, name, cfg) + tmp_feature_path = tmp_dir / f"{name}-rank_{distributed.get_global_rank()}.h5" + + status_info = {"status": "success"} + local_failed = False + + try: + dataset = create_dataset( + wsi_path=wsi_fp, + mask_path=mask_fp, + coordinates_dir=coordinates_dir, + target_spacing=cfg.tiling.params.spacing, + tolerance=cfg.tiling.params.tolerance, + backend=cfg.tiling.backend, + segment_params=cfg.tiling.seg_params, + sampling_params=sampling_params, + filter_params=cfg.tiling.filter_params, + transforms=transforms, + restrict_to_tissue=cfg.model.restrict_to_tissue, + ) + if distributed.is_enabled_and_multiple_gpus(): + sampler = torch.utils.data.DistributedSampler( + dataset, + shuffle=False, + drop_last=False, + ) + else: + sampler = None + + dataloader = create_dataloader(dataset, cfg, runtime, sampler=sampler) + perf_stats = run_inference_to_h5( + dataloader=dataloader, + model=model, + device=model.device, + autocast_context=autocast_context, + unit=unit, + batch_size=cfg.model.batch_size, + feature_path=tmp_feature_path, + collect_indices=True, + run_on_cpu=run_on_cpu, + show_progress=distributed.is_main_process(), + ) + if runtime["log_perf_embedding"] and distributed.is_main_process(): + log_perf_summary(name, perf_stats, unit) + + except Exception as e: + local_failed = True + status_info = { + "status": "failed", + "error": str(e), + "traceback": str(traceback.format_exc()), + } + + any_rank_failed = local_failed + if not run_on_cpu: + torch.distributed.barrier() + failure_flag = torch.tensor( + 
1 if local_failed else 0, device=model.device, dtype=torch.int32 + ) + torch.distributed.all_reduce(failure_flag, op=torch.distributed.ReduceOp.MAX) + any_rank_failed = bool(failure_flag.item()) + + if any_rank_failed: + if distributed.is_main_process(): + cleanup_tmp_features(tmp_dir, name) + if status_info["status"] != "failed": + status_info = { + "status": "failed", + "error": "Feature extraction failed on at least one distributed rank.", + "traceback": "", + } + elif distributed.is_main_process(): + try: + wsi_feature = load_sort_and_deduplicate_features( + tmp_dir, name, expected_len=len(dataset) + ) + torch.save(wsi_feature, feature_path) + except Exception as e: + any_rank_failed = True + cleanup_tmp_features(tmp_dir, name) + status_info = { + "status": "failed", + "error": str(e), + "traceback": str(traceback.format_exc()), + } + finally: + if "wsi_feature" in locals(): + del wsi_feature + if not run_on_cpu: + torch.cuda.empty_cache() + gc.collect() + + if not run_on_cpu: + failure_flag = torch.tensor( + 1 if (distributed.is_main_process() and any_rank_failed) else 0, + device=model.device, + dtype=torch.int32, + ) + torch.distributed.broadcast(failure_flag, src=0) + torch.distributed.barrier() + any_rank_failed = bool(failure_flag.item()) + + if distributed.is_main_process(): + if any_rank_failed and status_info["status"] != "failed": + status_info = { + "status": "failed", + "error": "Feature extraction failed on at least one distributed rank.", + "traceback": "", + } + feature_extraction_updates[str(wsi_fp)] = status_info + + process_df.loc[ + process_df["wsi_path"] == str(wsi_fp), "feature_status" + ] = status_info["status"] + if "error" in status_info: + process_df.loc[ + process_df["wsi_path"] == str(wsi_fp), "error" + ] = status_info["error"] + process_df.loc[ + process_df["wsi_path"] == str(wsi_fp), "traceback" + ] = status_info["traceback"] + process_df.to_csv(process_list, index=False) + + if distributed.is_enabled_and_multiple_gpus(): + 
torch.distributed.barrier() + + +def run_embed_v2( + *, + slides, + process_df, + process_list, + model, + cfg, + coordinates_dir, + sampling_params, + transforms, + runtime, + autocast_context, + features_dir, + tmp_dir, + run_on_cpu, + unit, +): + world_size = distributed.get_global_size() + sharding_mode = decide_sharding_mode(cfg, pending_count=len(slides), world_size=world_size) + + if sharding_mode == "tile": + if distributed.is_main_process(): + print( + "Embedding v2 requested but switching to tile-level sharding " + f"(pending_slides={len(slides)} < world_size={world_size})." + ) + return run_embed_v1( + slides=slides, + process_df=process_df, + process_list=process_list, + model=model, + cfg=cfg, + coordinates_dir=coordinates_dir, + sampling_params=sampling_params, + transforms=transforms, + runtime=runtime, + autocast_context=autocast_context, + features_dir=features_dir, + tmp_dir=tmp_dir, + run_on_cpu=run_on_cpu, + unit=unit, + ) + + if distributed.is_main_process(): + slides_to_assign = slides + else: + slides_to_assign = None + + if distributed.is_enabled(): + payload = [slides_to_assign] + torch.distributed.broadcast_object_list(payload, src=0) + slides_to_assign = payload[0] + + assignments = assign_slides_lpt(slides_to_assign, world_size=world_size) + rank = distributed.get_global_rank() + local_slides = assignments[rank] + + if distributed.is_main_process(): + print( + f"Embedding v2 slide-sharding enabled. " + f"Rank 0 assigned {len(local_slides)} / {len(slides_to_assign)} slides." 
+ ) + + local_updates = {} + + for slide in tqdm.tqdm( + local_slides, + desc=f"Inference (rank {rank})", + unit="slide", + total=len(local_slides), + leave=True, + disable=not distributed.is_main_process(), + position=1, + ): + wsi_fp = Path(slide["wsi_path"]) + mask_fp = Path(slide["mask_path"]) if slide["mask_path"] is not None else None + name = slide["name"] + + feature_path = get_feature_path(features_dir, name, cfg) + tmp_feature_path = tmp_dir / f"{name}-rank_{rank}-v2.h5" + + status_info = {"status": "success"} + try: + dataset = create_dataset( + wsi_path=wsi_fp, + mask_path=mask_fp, + coordinates_dir=coordinates_dir, + target_spacing=cfg.tiling.params.spacing, + tolerance=cfg.tiling.params.tolerance, + backend=cfg.tiling.backend, + segment_params=cfg.tiling.seg_params, + sampling_params=sampling_params, + filter_params=cfg.tiling.filter_params, + transforms=transforms, + restrict_to_tissue=cfg.model.restrict_to_tissue, + ) + dataloader = create_dataloader(dataset, cfg, runtime, sampler=None) + perf_stats = run_inference_to_h5( + dataloader=dataloader, + model=model, + device=model.device, + autocast_context=autocast_context, + unit=unit, + batch_size=cfg.model.batch_size, + feature_path=tmp_feature_path, + collect_indices=False, + run_on_cpu=run_on_cpu, + show_progress=distributed.is_main_process(), + ) + + wsi_feature = load_features_from_h5(tmp_feature_path) + torch.save(wsi_feature, feature_path) + os.remove(tmp_feature_path) + del wsi_feature + + if runtime["log_perf_embedding"]: + print(f"[rank {rank}]", end=" ") + log_perf_summary(name, perf_stats, unit) + + if not run_on_cpu: + torch.cuda.empty_cache() + gc.collect() + + except Exception as e: + status_info = { + "status": "failed", + "error": str(e), + "traceback": str(traceback.format_exc()), + } + if tmp_feature_path.exists(): + os.remove(tmp_feature_path) + + local_updates[str(wsi_fp)] = status_info + + if distributed.is_enabled(): + gathered_updates = [None for _ in range(world_size)] + 
torch.distributed.all_gather_object(gathered_updates, local_updates) + else: + gathered_updates = [local_updates] + + if distributed.is_main_process(): + merged_updates = {} + for update in gathered_updates: + merged_updates.update(update) + + for wsi_path, status_info in merged_updates.items(): + process_df.loc[process_df["wsi_path"] == str(wsi_path), "feature_status"] = status_info[ + "status" + ] + if "error" in status_info: + process_df.loc[process_df["wsi_path"] == str(wsi_path), "error"] = status_info[ + "error" + ] + process_df.loc[ + process_df["wsi_path"] == str(wsi_path), "traceback" + ] = status_info["traceback"] + + process_df.to_csv(process_list, index=False) + + if distributed.is_enabled_and_multiple_gpus(): + torch.distributed.barrier() + + +def print_feature_summary(process_df, tiled_df, unit): + slides_with_tiles = len(tiled_df) + total_slides = len(process_df) + failed_feature_extraction = process_df[process_df["feature_status"] == "failed"] + print("=+=" * 10) + print(f"Total number of slides with {unit}s: {slides_with_tiles}/{total_slides}") + print( + f"Failed {unit}-level feature extraction: {len(failed_feature_extraction)}/{slides_with_tiles}" + ) + print( + f"Completed {unit}-level feature extraction: {slides_with_tiles - len(failed_feature_extraction)}/{slides_with_tiles}" + ) + print("=+=" * 10) + + def main(args): - # setup configuration run_on_cpu = args.run_on_cpu cfg = get_cfg_from_file(args.config_file) output_dir = resolve_output_dir(cfg.output_dir, args.output_dir) @@ -174,28 +756,40 @@ def main(args): coordinates_dir = Path(cfg.tiling.read_coordinates_from) else: coordinates_dir = Path(cfg.output_dir, "coordinates") + fix_random_seeds(cfg.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False unit = "tile" if cfg.model.level != "region" else "region" - - num_workers = min(mp.cpu_count(), cfg.speed.num_workers_embedding) - if "SLURM_JOB_CPUS_PER_NODE" in os.environ: - num_workers = min(num_workers, 
int(os.environ["SLURM_JOB_CPUS_PER_NODE"])) + runtime = resolve_loader_settings(cfg, run_on_cpu=run_on_cpu) + if runtime["embedding_pipeline"] not in {"v1", "v2"}: + raise ValueError( + f"Unknown embedding pipeline: {runtime['embedding_pipeline']}. " + "Expected one of: v1, v2." + ) process_list = Path(cfg.output_dir, "process_list.csv") - assert ( - process_list.is_file() - ), "Process list CSV not found. Ensure tiling has been run." + assert process_list.is_file(), "Process list CSV not found. Ensure tiling has been run." process_df = pd.read_csv(process_list) - cols = ["wsi_name", "wsi_path", "tiling_status", "error", "traceback"] + if "feature_status" not in process_df.columns: process_df["feature_status"] = ["tbp"] * len(process_df) if "mask_path" not in process_df.columns: process_df["mask_path"] = [None] * len(process_df) - cols = ["wsi_name", "wsi_path", "mask_path", "tiling_status", "feature_status", "error", "traceback"] + + cols = [ + "wsi_name", + "wsi_path", + "mask_path", + "tiling_status", + "feature_status", + "error", + "traceback", + ] process_df = process_df[cols] + process_df["error"] = process_df["error"].astype("object") + process_df["traceback"] = process_df["traceback"].astype("object") skip_feature_extraction = process_df["feature_status"].str.contains("success").all() @@ -206,230 +800,94 @@ def main(args): print("=+=" * 10) if distributed.is_enabled(): torch.distributed.destroy_process_group() + return + model = ModelFactory(cfg.model).get_model() + if distributed.is_main_process(): + print(f"Starting {unit}-level feature extraction...") + if not run_on_cpu: + torch.distributed.barrier() + + pixel_mapping = {k: v for e in cfg.tiling.sampling_params.pixel_mapping for k, v in e.items()} + tissue_percentage = {k: v for e in cfg.tiling.sampling_params.tissue_percentage for k, v in e.items()} + if "tissue" not in tissue_percentage: + tissue_percentage["tissue"] = cfg.tiling.params.min_tissue_percentage + if 
cfg.tiling.sampling_params.color_mapping is not None: + color_mapping = {k: v for e in cfg.tiling.sampling_params.color_mapping for k, v in e.items()} else: - model = ModelFactory(cfg.model).get_model() - if distributed.is_main_process(): - print(f"Starting {unit}-level feature extraction...") - if not run_on_cpu: - torch.distributed.barrier() - - pixel_mapping = {k: v for e in cfg.tiling.sampling_params.pixel_mapping for k, v in e.items()} - tissue_percentage = {k: v for e in cfg.tiling.sampling_params.tissue_percentage for k, v in e.items()} - if "tissue" not in tissue_percentage: - tissue_percentage["tissue"] = cfg.tiling.params.min_tissue_percentage - if cfg.tiling.sampling_params.color_mapping is not None: - color_mapping = {k: v for e in cfg.tiling.sampling_params.color_mapping for k, v in e.items()} - else: - color_mapping = None - - sampling_params = SamplingParameters( - pixel_mapping=pixel_mapping, - color_mapping=color_mapping, - tissue_percentage=tissue_percentage, - ) + color_mapping = None - # select slides that were successfully tiled but not yet processed for feature extraction - tiled_df = process_df[process_df.tiling_status == "success"] - mask = tiled_df["feature_status"] != "success" - process_stack = tiled_df[mask] - total = len(process_stack) + sampling_params = SamplingParameters( + pixel_mapping=pixel_mapping, + color_mapping=color_mapping, + tissue_percentage=tissue_percentage, + ) - wsi_paths_to_process = [Path(x) for x in process_stack.wsi_path.values.tolist()] - mask_paths_to_process = [Path(x) if x is not None and not pd.isna(x) else None for x in process_stack.mask_path.values.tolist()] - combined_paths = zip(wsi_paths_to_process, mask_paths_to_process) + slides, tiled_df = collect_pending_slides(process_df, coordinates_dir) - features_dir = Path(cfg.output_dir, "features") - if distributed.is_main_process(): - features_dir.mkdir(exist_ok=True, parents=True) + features_dir = Path(cfg.output_dir, "features") + 
features_dir.mkdir(exist_ok=True, parents=True) - tmp_dir = Path("/tmp") - if distributed.is_main_process(): - tmp_dir.mkdir(exist_ok=True, parents=True) + tmp_dir = Path("/tmp") + tmp_dir.mkdir(exist_ok=True, parents=True) - autocast_context = ( - torch.autocast(device_type="cuda", dtype=torch.float16) - if (cfg.speed.fp16 and not run_on_cpu) - else nullcontext() - ) - feature_extraction_updates = {} + autocast_context = ( + torch.autocast(device_type="cuda", dtype=torch.float16) + if (cfg.speed.fp16 and not run_on_cpu) + else nullcontext() + ) - transforms = create_transforms(cfg, model) + transforms = create_transforms(cfg, model) + if distributed.is_main_process(): print(f"transforms: {transforms}") + print( + "loader settings: " + f"pipeline={runtime['embedding_pipeline']}, " + f"sharding={runtime['rank_sharding_mode']}, workers={runtime['num_workers']}, " + f"prefetch={runtime['prefetch_factor']}, persistent_workers={runtime['persistent_workers']}, " + f"pin_memory={runtime['pin_memory']}" + ) - for wsi_fp, mask_fp in tqdm.tqdm( - combined_paths, - desc="Inference", - unit="slide", - total=total, - leave=True, - disable=not distributed.is_main_process(), - position=1, - ): - name = wsi_fp.stem.replace(" ", "_") - feature_path = features_dir / f"{name}.pt" - if cfg.model.save_tile_embeddings: - feature_path = features_dir / f"{name}-tiles.pt" - tmp_feature_path = tmp_dir / f"{name}-rank_{distributed.get_global_rank()}.h5" - - status_info = {"status": "success"} - local_failed = False - try: - dataset = create_dataset( - wsi_path=wsi_fp, - mask_path=mask_fp, - coordinates_dir=coordinates_dir, - target_spacing=cfg.tiling.params.spacing, - tolerance=cfg.tiling.params.tolerance, - backend=cfg.tiling.backend, - segment_params=cfg.tiling.seg_params, - sampling_params=sampling_params, - filter_params=cfg.tiling.filter_params, - transforms=transforms, - restrict_to_tissue=cfg.model.restrict_to_tissue, - ) - if distributed.is_enabled_and_multiple_gpus(): - sampler = 
torch.utils.data.DistributedSampler( - dataset, - shuffle=False, - drop_last=False, - ) - else: - sampler = None - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=cfg.model.batch_size, - sampler=sampler, - num_workers=num_workers, - pin_memory=True, - ) - - # get feature dimension and dtype using a dry run - with torch.inference_mode(), autocast_context: - sample_batch = next(iter(dataloader)) - sample_image = sample_batch[1].to(model.device) - sample_feature = model(sample_image)["embedding"].cpu().numpy() - feature_dim = sample_feature.shape[1:] - dtype = sample_feature.dtype - - run_inference( - dataloader, - model, - model.device, - autocast_context, - unit, - cfg.model.batch_size, - tmp_feature_path, - feature_dim, - dtype, - run_on_cpu, - ) - - except Exception as e: - local_failed = True - status_info = { - "status": "failed", - "error": str(e), - "traceback": str(traceback.format_exc()), - } - - any_rank_failed = local_failed - if not run_on_cpu: - # Ensure every rank reaches sync points, even when one rank failed. 
- torch.distributed.barrier() - failure_flag = torch.tensor( - 1 if local_failed else 0, device=model.device, dtype=torch.int32 - ) - torch.distributed.all_reduce( - failure_flag, op=torch.distributed.ReduceOp.MAX - ) - any_rank_failed = bool(failure_flag.item()) - - if any_rank_failed: - if distributed.is_main_process(): - cleanup_tmp_features(tmp_dir, name) - if status_info["status"] != "failed": - status_info = { - "status": "failed", - "error": "Feature extraction failed on at least one distributed rank.", - "traceback": "", - } - elif distributed.is_main_process(): - try: - wsi_feature = load_sort_and_deduplicate_features( - tmp_dir, name, expected_len=len(dataset) - ) - torch.save(wsi_feature, feature_path) - except Exception as e: - any_rank_failed = True - cleanup_tmp_features(tmp_dir, name) - status_info = { - "status": "failed", - "error": str(e), - "traceback": str(traceback.format_exc()), - } - finally: - if "wsi_feature" in locals(): - del wsi_feature - if not run_on_cpu: - torch.cuda.empty_cache() - gc.collect() - - if not run_on_cpu: - # Propagate post-processing failures from rank 0 to all ranks. 
- failure_flag = torch.tensor( - 1 if (distributed.is_main_process() and any_rank_failed) else 0, - device=model.device, - dtype=torch.int32, - ) - torch.distributed.broadcast(failure_flag, src=0) - torch.distributed.barrier() - any_rank_failed = bool(failure_flag.item()) - - if distributed.is_main_process(): - if any_rank_failed and status_info["status"] != "failed": - status_info = { - "status": "failed", - "error": "Feature extraction failed on at least one distributed rank.", - "traceback": "", - } - feature_extraction_updates[str(wsi_fp)] = status_info - - # update process_df - if distributed.is_main_process(): - status_info = feature_extraction_updates[str(wsi_fp)] - process_df.loc[ - process_df["wsi_path"] == str(wsi_fp), "feature_status" - ] = status_info["status"] - if "error" in status_info: - process_df.loc[ - process_df["wsi_path"] == str(wsi_fp), "error" - ] = status_info["error"] - process_df.loc[ - process_df["wsi_path"] == str(wsi_fp), "traceback" - ] = status_info["traceback"] - process_df.to_csv(process_list, index=False) - - if distributed.is_enabled_and_multiple_gpus(): - torch.distributed.barrier() + if runtime["embedding_pipeline"] == "v2": + run_embed_v2( + slides=slides, + process_df=process_df, + process_list=process_list, + model=model, + cfg=cfg, + coordinates_dir=coordinates_dir, + sampling_params=sampling_params, + transforms=transforms, + runtime=runtime, + autocast_context=autocast_context, + features_dir=features_dir, + tmp_dir=tmp_dir, + run_on_cpu=run_on_cpu, + unit=unit, + ) + else: + run_embed_v1( + slides=slides, + process_df=process_df, + process_list=process_list, + model=model, + cfg=cfg, + coordinates_dir=coordinates_dir, + sampling_params=sampling_params, + transforms=transforms, + runtime=runtime, + autocast_context=autocast_context, + features_dir=features_dir, + tmp_dir=tmp_dir, + run_on_cpu=run_on_cpu, + unit=unit, + ) - if distributed.is_main_process(): - # summary logging - slides_with_tiles = len(tiled_df) - 
total_slides = len(process_df)
-            failed_feature_extraction = process_df[
-                process_df["feature_status"] == "failed"
-            ]
-            print("=+=" * 10)
-            print(f"Total number of slides with {unit}s: {slides_with_tiles}/{total_slides}")
-            print(f"Failed {unit}-level feature extraction: {len(failed_feature_extraction)}/{slides_with_tiles}")
-            print(
-                f"Completed {unit}-level feature extraction: {slides_with_tiles - len(failed_feature_extraction)}/{slides_with_tiles}"
-            )
-            print("=+=" * 10)
+    if distributed.is_main_process():
+        print_feature_summary(process_df, tiled_df, unit)
 
-        if distributed.is_enabled():
-            torch.distributed.destroy_process_group()
+    if distributed.is_enabled():
+        torch.distributed.destroy_process_group()
 
 
 if __name__ == "__main__":
diff --git a/tests/fixtures/test_dataset_worker_cache.py b/tests/fixtures/test_dataset_worker_cache.py
new file mode 100644
index 0000000..655fb28
--- /dev/null
+++ b/tests/fixtures/test_dataset_worker_cache.py
@@ -0,0 +1,76 @@
+import ast
+import unittest
+from pathlib import Path
+
+
+ROOT = Path(__file__).resolve().parents[2]
+DATASET_FILE = ROOT / "slide2vec/data/dataset.py"
+
+
+class DatasetWorkerCacheTests(unittest.TestCase):
+    def setUp(self):
+        self.src = DATASET_FILE.read_text(encoding="utf-8")
+        self.tree = ast.parse(self.src)
+        self.tile_dataset = next(
+            node
+            for node in self.tree.body
+            if isinstance(node, ast.ClassDef) and node.name == "TileDataset"
+        )
+
+    def _get_method(self, name: str):
+        return next(
+            node
+            for node in self.tile_dataset.body
+            if isinstance(node, ast.FunctionDef) and node.name == name
+        )
+
+    def test_private_worker_cache_members_initialized(self):
+        init_fn = self._get_method("__init__")
+        attrs = set()
+        for node in ast.walk(init_fn):
+            if (
+                isinstance(node, ast.Assign)
+                and len(node.targets) == 1
+                and isinstance(node.targets[0], ast.Attribute)
+                and isinstance(node.targets[0].value, ast.Name)
+                and node.targets[0].value.id == "self"
+                and isinstance(node.value, ast.Constant)
+                and 
node.value.value is None
+            ):
+                attrs.add(node.targets[0].attr)
+        self.assertIn("_wsi", attrs)
+        self.assertIn("_worker_id", attrs)
+
+    def test_worker_wsi_helper_exists_and_uses_worker_info(self):
+        helper_fn = self._get_method("_get_worker_wsi")
+        helper_src = ast.get_source_segment(self.src, helper_fn)
+        self.assertIn("torch.utils.data.get_worker_info()", helper_src)
+        self.assertIn("wsd.WholeSlideImage(self.path, backend=self.backend)", helper_src)
+
+    def test_getitem_uses_helper_and_not_direct_constructor(self):
+        getitem_fn = self._get_method("__getitem__")
+        helper_calls = 0
+        direct_ctor_calls = 0
+        for node in ast.walk(getitem_fn):
+            if isinstance(node, ast.Call):
+                if (
+                    isinstance(node.func, ast.Attribute)
+                    and isinstance(node.func.value, ast.Name)
+                    and node.func.value.id == "self"
+                    and node.func.attr == "_get_worker_wsi"
+                ):
+                    helper_calls += 1
+                if (
+                    isinstance(node.func, ast.Attribute)
+                    and isinstance(node.func.value, ast.Name)
+                    and node.func.value.id == "wsd"
+                    and node.func.attr == "WholeSlideImage"
+                ):
+                    direct_ctor_calls += 1
+
+        self.assertGreater(helper_calls, 0)
+        self.assertEqual(direct_ctor_calls, 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/fixtures/test_embed_pipeline_mode.py b/tests/fixtures/test_embed_pipeline_mode.py
new file mode 100644
index 0000000..66e6079
--- /dev/null
+++ b/tests/fixtures/test_embed_pipeline_mode.py
@@ -0,0 +1,67 @@
+import ast
+import types
+import unittest
+from pathlib import Path
+
+
+ROOT = Path(__file__).resolve().parents[2]
+EMBED_FILE = ROOT / "slide2vec/embed.py"
+
+
+def load_functions(*fn_names):
+    src = EMBED_FILE.read_text(encoding="utf-8")
+    tree = ast.parse(src)
+    fn_nodes = {
+        node.name: node
+        for node in tree.body
+        if isinstance(node, ast.FunctionDef)
+    }
+    namespace = {}
+    for name in fn_names:
+        module = ast.Module(body=[fn_nodes[name]], type_ignores=[])
+        code = compile(module, filename=str(EMBED_FILE), mode="exec")
+        exec(code, namespace)
+    return 
[namespace[name] for name in fn_names]
+
+
+class EmbedPipelineModeTests(unittest.TestCase):
+    def _cfg(self, mode: str):
+        return types.SimpleNamespace(speed=types.SimpleNamespace(rank_sharding_mode=mode))
+
+    def test_explicit_tile_mode(self):
+        get_speed_option, decide_sharding_mode = load_functions(
+            "get_speed_option", "decide_sharding_mode"
+        )
+        _ = get_speed_option
+        cfg = self._cfg("tile")
+        self.assertEqual(decide_sharding_mode(cfg, pending_count=100, world_size=8), "tile")
+
+    def test_explicit_slide_mode(self):
+        get_speed_option, decide_sharding_mode = load_functions(
+            "get_speed_option", "decide_sharding_mode"
+        )
+        _ = get_speed_option
+        cfg = self._cfg("slide")
+        self.assertEqual(decide_sharding_mode(cfg, pending_count=1, world_size=8), "slide")
+
+    def test_auto_mode_threshold(self):
+        get_speed_option, decide_sharding_mode = load_functions(
+            "get_speed_option", "decide_sharding_mode"
+        )
+        _ = get_speed_option
+        cfg = self._cfg("auto")
+        self.assertEqual(decide_sharding_mode(cfg, pending_count=8, world_size=8), "slide")
+        self.assertEqual(decide_sharding_mode(cfg, pending_count=7, world_size=8), "tile")
+
+    def test_invalid_mode_raises(self):
+        get_speed_option, decide_sharding_mode = load_functions(
+            "get_speed_option", "decide_sharding_mode"
+        )
+        _ = get_speed_option
+        cfg = self._cfg("invalid")
+        with self.assertRaises(ValueError):
+            decide_sharding_mode(cfg, pending_count=10, world_size=8)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/fixtures/test_rank_sharding_lpt.py b/tests/fixtures/test_rank_sharding_lpt.py
new file mode 100644
index 0000000..03eddb3
--- /dev/null
+++ b/tests/fixtures/test_rank_sharding_lpt.py
@@ -0,0 +1,62 @@
+import ast
+import copy
+import unittest
+from pathlib import Path
+
+
+ROOT = Path(__file__).resolve().parents[2]
+EMBED_FILE = ROOT / "slide2vec/embed.py"
+
+
+def load_functions(*fn_names):
+    src = EMBED_FILE.read_text(encoding="utf-8")
+    tree = ast.parse(src)
+    fn_nodes = {
+        
node.name: node
+        for node in tree.body
+        if isinstance(node, ast.FunctionDef)
+    }
+    namespace = {}
+    for name in fn_names:
+        module = ast.Module(body=[fn_nodes[name]], type_ignores=[])
+        code = compile(module, filename=str(EMBED_FILE), mode="exec")
+        exec(code, namespace)
+    return [namespace[name] for name in fn_names]
+
+
+class RankShardingLptTests(unittest.TestCase):
+    def test_deterministic_assignment(self):
+        (assign_slides_lpt,) = load_functions("assign_slides_lpt")
+        slides = [
+            {"name": "a", "tile_count": 10},
+            {"name": "b", "tile_count": 9},
+            {"name": "c", "tile_count": 8},
+            {"name": "d", "tile_count": 7},
+        ]
+        result_1 = assign_slides_lpt(copy.deepcopy(slides), world_size=2)
+        result_2 = assign_slides_lpt(copy.deepcopy(slides), world_size=2)
+        self.assertEqual(result_1, result_2)
+
+    def test_balance_on_skewed_distribution(self):
+        (assign_slides_lpt,) = load_functions("assign_slides_lpt")
+        slides = [
+            {"name": "big", "tile_count": 50},
+            {"name": "m1", "tile_count": 10},
+            {"name": "m2", "tile_count": 10},
+            {"name": "m3", "tile_count": 10},
+            {"name": "m4", "tile_count": 10},
+        ]
+        assignments = assign_slides_lpt(copy.deepcopy(slides), world_size=2)
+
+        assigned_names = []
+        rank_loads = {}
+        for rank, rank_slides in assignments.items():
+            assigned_names.extend(slide["name"] for slide in rank_slides)
+            rank_loads[rank] = sum(slide["tile_count"] for slide in rank_slides)
+
+        self.assertCountEqual(assigned_names, [slide["name"] for slide in slides])
+        self.assertLessEqual(abs(rank_loads[0] - rank_loads[1]), 10)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_regression_bugfixes.py b/tests/test_regression_bugfixes.py
index ac6dfc1..83d2792 100644
--- a/tests/test_regression_bugfixes.py
+++ b/tests/test_regression_bugfixes.py
@@ -109,6 +109,39 @@ def test_region_model_factory_uses_tile_encoder_assignments(self):
             f"Region-level branch for {model_name} should assign to tile_encoder",
         )
 
+    def test_embed_reads_new_loader_config_keys(self):
+        src = read_source("slide2vec/embed.py")
+        expected_keys = [
+            "embedding_pipeline",
+            "rank_sharding_mode",
+            "storage_mode",
+            "num_workers_embedding",
+            "prefetch_factor_embedding",
+            "persistent_workers_embedding",
+            "pin_memory_embedding",
+            "loader_batch_timeout_sec",
+            "log_perf_embedding",
+        ]
+        for key in expected_keys:
+            self.assertIn(
+                f"\"{key}\"",
+                src,
+                f"embed.py should reference speed.{key}",
+            )
+
+    def test_embed_cpu_workers_guard_exists(self):
+        src = read_source("slide2vec/embed.py")
+        self.assertIn(
+            "if run_on_cpu:",
+            src,
+            "embed.py should have a dedicated CPU loader branch",
+        )
+        self.assertIn(
+            "workers_per_rank = 0",
+            src,
+            "embed.py should force single-process loading for CPU runs",
+        )
+
     def test_tile_model_factory_has_pathojepa_branch(self):
         src = read_source("slide2vec/models/models.py")
         pattern = r'elif options\.name == "pathojepa":\n\s+model = PathoJEPA\('