From 62422030494ad226c771e085f0d0d2ff96b4a1a9 Mon Sep 17 00:00:00 2001 From: clemsgrs Date: Thu, 26 Feb 2026 01:43:16 +0000 Subject: [PATCH 1/2] Improve embedding data loading robustness and dataset structure --- slide2vec/configs/default_model.yaml | 8 +- slide2vec/data/__init__.py | 3 +- slide2vec/data/augmentations.py | 3 + slide2vec/data/dataset.py | 206 ++++++++++++- slide2vec/data/tile_catalog.py | 137 +++++++++ slide2vec/embed.py | 415 ++++++++++++++++++++++----- slide2vec/utils/parquet.py | 7 + tests/test_tile_catalog.py | 38 +++ 8 files changed, 735 insertions(+), 82 deletions(-) create mode 100644 slide2vec/data/tile_catalog.py create mode 100644 slide2vec/utils/parquet.py create mode 100644 tests/test_tile_catalog.py diff --git a/slide2vec/configs/default_model.yaml b/slide2vec/configs/default_model.yaml index 70068ae..d31ad05 100644 --- a/slide2vec/configs/default_model.yaml +++ b/slide2vec/configs/default_model.yaml @@ -23,7 +23,13 @@ model: speed: fp16: false # use mixed precision during model inference - num_workers_embedding: 8 # number of workers for data loading when embedding slides + num_workers_embedding: 16 # number of workers for data loading when embedding slides + persistent_workers_embedding: true # keep DataLoader workers alive across iterator re-creation for a slide + prefetch_factor_embedding: 4 # number of batches prefetched per worker + use_parquet: true # if false, use legacy TileDataset and read per-slide .npy coordinates directly + max_open_slides_per_worker: 16 # worker-local LRU cache size for open WSI readers + deterministic_inference: false # if true, force deterministic cuDNN behavior + cudnn_benchmark: true # enable cuDNN autotuner for faster fixed-shape inference wandb: enable: false diff --git a/slide2vec/data/__init__.py b/slide2vec/data/__init__.py index 98809c3..857d37f 100644 --- a/slide2vec/data/__init__.py +++ b/slide2vec/data/__init__.py @@ -1,2 +1,3 @@ -from .dataset import TileDataset +from .dataset import TileDataset, TileCatalogDataset +from .tile_catalog import ensure_tile_catalogs from .augmentations import RegionUnfolding diff --git a/slide2vec/data/augmentations.py b/slide2vec/data/augmentations.py index e18c9db..20c9e6e 100644 --- a/slide2vec/data/augmentations.py +++ b/slide2vec/data/augmentations.py @@ -55,3 +55,6 @@ def __call__(self, x): x, "c p1 p2 w h -> (p1 p2) c w h" ) # [num_tilees, 3, tile_size, tile_size] return x + + def __repr__(self): + return f"{self.__class__.__name__}({self.tile_size})" diff --git a/slide2vec/data/dataset.py b/slide2vec/data/dataset.py index e19d0b4..d740f35 100644 --- a/slide2vec/data/dataset.py +++ b/slide2vec/data/dataset.py @@ -1,8 +1,10 @@ +import os import cv2 import torch import numpy as np import wholeslidedata as wsd +from collections import OrderedDict from transformers.image_processing_utils import BaseImageProcessor from PIL import Image from pathlib import Path @@ -10,9 +12,13 @@ from slide2vec.hs2p.hs2p.wsi import WholeSlideImage, SegmentationParameters, SamplingParameters, FilterParameters from slide2vec.hs2p.hs2p.wsi.utils import HasEnoughTissue +from slide2vec.utils.parquet import require_pyarrow class TileDataset(torch.utils.data.Dataset): + # Worker-local cache (process scoped because each worker is a separate process). + _WSI_CACHE_BY_PID: dict[int, OrderedDict[tuple[str, str], wsd.WholeSlideImage]] = {} + def __init__( self, wsi_path: Path, @@ -26,12 +32,14 @@ def __init__( filter_params: FilterParameters | None = None, transforms: BaseImageProcessor | Callable | None = None, restrict_to_tissue: bool = False, + max_open_slides_per_worker: int = 16, ): self.path = wsi_path self.mask_path = mask_path self.target_spacing = target_spacing self.backend = backend self.name = wsi_path.stem.replace(" ", "_") + self.max_open_slides_per_worker = max(1, int(max_open_slides_per_worker)) self.load_coordinates(coordinates_dir) self.transforms = transforms self.restrict_to_tissue = restrict_to_tissue @@ -57,6 +65,32 @@ def __init__( self.seg_spacing = _wsi.get_level_spacing(_wsi.seg_level) self.spacing_at_level_0 = _wsi.get_level_spacing(0) + @classmethod + def _get_worker_cache( + cls, + ) -> OrderedDict[tuple[str, str], wsd.WholeSlideImage]: + pid = os.getpid() + if pid not in cls._WSI_CACHE_BY_PID: + cls._WSI_CACHE_BY_PID[pid] = OrderedDict() + return cls._WSI_CACHE_BY_PID[pid] + + def _get_wsi(self) -> wsd.WholeSlideImage: + key = (str(self.path), str(self.backend)) + cache = self._get_worker_cache() + cached = cache.pop(key, None) + if cached is not None: + cache[key] = cached + return cached + + reader = wsd.WholeSlideImage(self.path, backend=self.backend) + cache[key] = reader + while len(cache) > self.max_open_slides_per_worker: + _, evicted = cache.popitem(last=False) + close_fn = getattr(evicted, "close", None) + if callable(close_fn): + close_fn() + return reader + def load_coordinates(self, coordinates_dir): coordinates = np.load(Path(coordinates_dir, f"{self.name}.npy"), allow_pickle=True) self.x = coordinates["x"] @@ -73,7 +107,7 @@ def load_coordinates(self, coordinates_dir): def scale_coordinates(self): # coordinates are defined w.r.t. level 0 # i need to scale them to target_spacing - wsi = wsd.WholeSlideImage(self.path, backend=self.backend) + wsi = self._get_wsi() min_spacing = wsi.spacings[0] scale = min_spacing / self.target_spacing # create a [N, 2] array with x and y coordinates @@ -84,9 +118,7 @@ def __len__(self): return len(self.x) def __getitem__(self, idx): - wsi = wsd.WholeSlideImage( - self.path, backend=self.backend - ) # cannot be defined in __init__ because of multiprocessing + wsi = self._get_wsi() tile_level = self.tile_level[idx] tile_spacing = wsi.spacings[tile_level] tile_arr = wsi.get_patch( @@ -125,3 +157,169 @@ def __getitem__(self, idx): else: # general callable such as torchvision transforms tile = self.transforms(tile) return idx, tile + + +class TileCatalogDataset(torch.utils.data.Dataset): + # Worker-local cache (process scoped because each worker is a separate process). + _WSI_CACHE_BY_PID: dict[int, OrderedDict[tuple[str, str], wsd.WholeSlideImage]] = {} + + def __init__( + self, + *, + catalog_path: Path, + wsi_path: Path, + mask_path: Path | None, + target_spacing: float, + tolerance: float, + backend: str, + segment_params: SegmentationParameters | None = None, + sampling_params: SamplingParameters | None = None, + filter_params: FilterParameters | None = None, + transforms: BaseImageProcessor | Callable | None = None, + restrict_to_tissue: bool = False, + max_open_slides_per_worker: int = 16, + ): + self.catalog_path = Path(catalog_path) + self.path = wsi_path + self.mask_path = mask_path + self.target_spacing = target_spacing + self.backend = backend + self.name = wsi_path.stem.replace(" ", "_") + self.transforms = transforms + self.restrict_to_tissue = restrict_to_tissue + self.max_open_slides_per_worker = max(1, int(max_open_slides_per_worker)) + self._load_catalog() + + if restrict_to_tissue: + _wsi = WholeSlideImage( + path=self.path, + mask_path=self.mask_path, + backend=self.backend, + segment=self.mask_path is None, + segment_params=segment_params, + sampling_params=sampling_params, + ) + contours, holes = _wsi.detect_contours( + target_spacing=target_spacing, + tolerance=tolerance, + filter_params=filter_params, + ) + scale = _wsi.level_downsamples[_wsi.seg_level] + self.contours = _wsi.scaleContourDim( + contours, (1.0 / scale[0], 1.0 / scale[1]) + ) + self.holes = _wsi.scaleHolesDim(holes, (1.0 / scale[0], 1.0 / scale[1])) + self.tissue_mask = _wsi.annotation_mask["tissue"] + self.seg_spacing = _wsi.get_level_spacing(_wsi.seg_level) + self.spacing_at_level_0 = _wsi.get_level_spacing(0) + + def _load_catalog(self): + _, pq, _ = require_pyarrow() + table = pq.read_table( + str(self.catalog_path), + columns=[ + "coord_index", + "x", + "y", + "contour_index", + "target_tile_size", + "tile_level", + "resize_factor", + "tile_size_resized", + "tile_size_lv0", + ], + memory_map=True, + ) + columns = table.to_pydict() + self.coord_index = np.asarray(columns["coord_index"], dtype=np.int64) + self.x = np.asarray(columns["x"], dtype=np.int64) + self.y = np.asarray(columns["y"], dtype=np.int64) + self.contour_index = np.asarray(columns["contour_index"], dtype=np.int64) + self.target_tile_size = np.asarray(columns["target_tile_size"], dtype=np.int64) + self.tile_level = np.asarray(columns["tile_level"], dtype=np.int64) + self.resize_factor = np.asarray(columns["resize_factor"], dtype=np.float64) + self.tile_size_resized = np.asarray(columns["tile_size_resized"], dtype=np.int64) + self.tile_size_lv0 = np.asarray(columns["tile_size_lv0"], dtype=np.int64) + + expected = np.arange(len(self.coord_index), dtype=np.int64) + if not np.array_equal(self.coord_index, expected): + raise ValueError( + f"Catalog coord_index must be contiguous 0..N-1 for {self.catalog_path}" + ) + + @classmethod + def _get_worker_cache( + cls, + ) -> OrderedDict[tuple[str, str], wsd.WholeSlideImage]: + pid = os.getpid() + if pid not in cls._WSI_CACHE_BY_PID: + cls._WSI_CACHE_BY_PID[pid] = OrderedDict() + return cls._WSI_CACHE_BY_PID[pid] + + def _get_wsi(self) -> wsd.WholeSlideImage: + key = (str(self.path), str(self.backend)) + cache = self._get_worker_cache() + cached = cache.pop(key, None) + if cached is not None: + cache[key] = cached + return cached + + reader = wsd.WholeSlideImage(self.path, backend=self.backend) + cache[key] = reader + while len(cache) > self.max_open_slides_per_worker: + _, evicted = cache.popitem(last=False) + close_fn = getattr(evicted, "close", None) + if callable(close_fn): + close_fn() + return reader + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + row_idx = int(idx) + wsi = self._get_wsi() + tile_level = int(self.tile_level[row_idx]) + tile_spacing = wsi.spacings[tile_level] + tile_arr = wsi.get_patch( + int(self.x[row_idx]), + int(self.y[row_idx]), + int(self.tile_size_resized[row_idx]), + int(self.tile_size_resized[row_idx]), + spacing=tile_spacing, + center=False, + ) + if self.restrict_to_tissue: + contour_idx = int(self.contour_index[row_idx]) + contour = self.contours[contour_idx] + holes = self.holes[contour_idx] + tissue_checker = HasEnoughTissue( + contour=contour, + contour_holes=holes, + tissue_mask=self.tissue_mask, + tile_size=int(self.target_tile_size[row_idx]), + tile_spacing=tile_spacing, + resize_factor=float(self.resize_factor[row_idx]), + seg_spacing=self.seg_spacing, + spacing_at_level_0=self.spacing_at_level_0, + ) + tissue_mask = tissue_checker.get_tile_mask( + int(self.x[row_idx]), int(self.y[row_idx]) + ) + if tissue_mask.shape[:2] != tile_arr.shape[:2]: + raise ValueError("Mask and tile shapes do not match") + tile_arr = cv2.bitwise_and(tile_arr, tile_arr, mask=tissue_mask) + + tile = Image.fromarray(tile_arr).convert("RGB") + target_size = int(self.target_tile_size[row_idx]) + resized_size = int(self.tile_size_resized[row_idx]) + if target_size != resized_size: + tile = tile.resize((target_size, target_size)) + if self.transforms: + if isinstance(self.transforms, BaseImageProcessor): + tile = self.transforms(tile, return_tensors="pt")[ + "pixel_values" + ].squeeze(0) + else: + tile = self.transforms(tile) + return int(self.coord_index[row_idx]), tile diff --git a/slide2vec/data/tile_catalog.py b/slide2vec/data/tile_catalog.py new file mode 100644 index 0000000..1023a2f --- /dev/null +++ b/slide2vec/data/tile_catalog.py @@ -0,0 +1,137 @@ +import json +from pathlib import Path + +import numpy as np + +from slide2vec.utils.parquet import require_pyarrow + + +CATALOG_COLUMNS = ( + "slide_id", + "wsi_path", + "mask_path", + "coord_index", + "x", + "y", + "contour_index", + "target_tile_size", + "tile_level", + "resize_factor", + "tile_size_resized", + "tile_size_lv0", +) + + +def _safe_name(wsi_path: Path) -> str: + return wsi_path.stem.replace(" ", "_") + + +def get_slide_catalog_path(catalog_dir: Path, wsi_path: Path) -> Path: + return Path(catalog_dir, f"{_safe_name(wsi_path)}.parquet") + + +def build_tile_catalog_for_slide( + *, + coordinates_path: Path, + wsi_path: Path, + mask_path: Path | None, + catalog_path: Path, +) -> Path: + pa, pq, _ = require_pyarrow() + coords = np.load(coordinates_path, allow_pickle=False) + n_tile = int(len(coords)) + + coord_index = np.arange(n_tile, dtype=np.int64) + x = np.asarray(coords["x"], dtype=np.int64) + y = np.asarray(coords["y"], dtype=np.int64) + contour_index = np.asarray(coords["contour_index"], dtype=np.int64) + target_tile_size = np.asarray(coords["target_tile_size"], dtype=np.int64) + tile_level = np.asarray(coords["tile_level"], dtype=np.int64) + resize_factor = np.asarray(coords["resize_factor"], dtype=np.float64) + tile_size_resized = np.asarray(coords["tile_size_resized"], dtype=np.int64) + tile_size_lv0 = np.asarray(coords["tile_size_lv0"], dtype=np.int64) + + slide_id = _safe_name(wsi_path) + table = pa.table( + { + "slide_id": np.full(n_tile, slide_id, dtype=object), + "wsi_path": np.full(n_tile, str(wsi_path), dtype=object), + "mask_path": np.full( + n_tile, + str(mask_path) if mask_path is not None else None, + dtype=object, + ), + "coord_index": coord_index, + "x": x, + "y": y, + "contour_index": contour_index, + "target_tile_size": target_tile_size, + "tile_level": tile_level, + "resize_factor": resize_factor, + "tile_size_resized": tile_size_resized, + "tile_size_lv0": tile_size_lv0, + } + ) + catalog_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, str(catalog_path), compression="zstd") + return catalog_path + + +def _should_rebuild_catalog(catalog_path: Path, coordinates_path: Path) -> bool: + if not catalog_path.exists(): + return True + return catalog_path.stat().st_mtime < coordinates_path.stat().st_mtime + + +def ensure_tile_catalogs( + *, + slide_mask_pairs: list[tuple[Path, Path | None]], + coordinates_dir: Path, + catalog_dir: Path, + force_rebuild: bool = False, +) -> dict[str, Path]: + slide_to_catalog: dict[str, Path] = {} + manifest_rows: list[dict[str, str | int]] = [] + for wsi_path, mask_path in slide_mask_pairs: + name = _safe_name(wsi_path) + coordinates_path = Path(coordinates_dir, f"{name}.npy") + if not coordinates_path.exists(): + raise FileNotFoundError(f"Missing coordinates file: {coordinates_path}") + catalog_path = get_slide_catalog_path(catalog_dir, wsi_path) + if force_rebuild or _should_rebuild_catalog(catalog_path, coordinates_path): + build_tile_catalog_for_slide( + coordinates_path=coordinates_path, + wsi_path=wsi_path, + mask_path=mask_path, + catalog_path=catalog_path, + ) + + coords = np.load(coordinates_path, allow_pickle=False) + slide_to_catalog[str(wsi_path)] = catalog_path + manifest_rows.append( + { + "slide_id": name, + "wsi_path": str(wsi_path), + "mask_path": str(mask_path) if mask_path is not None else None, + "coordinates_path": str(coordinates_path), + "catalog_path": str(catalog_path), + "tiles": int(len(coords)), + } + ) + + manifest = { + "schema_version": 1, + "columns": list(CATALOG_COLUMNS), + "slides": manifest_rows, + } + manifest_path = Path(catalog_dir, "manifest.json") + manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8") + return slide_to_catalog + + +__all__ = [ + "CATALOG_COLUMNS", + "get_slide_catalog_path", + "build_tile_catalog_for_slide", + "ensure_tile_catalogs", +] diff --git a/slide2vec/embed.py b/slide2vec/embed.py index 89b3acf..f99a611 100644 --- a/slide2vec/embed.py +++ b/slide2vec/embed.py @@ -1,13 +1,16 @@ import gc import os -import h5py import tqdm import torch import argparse import traceback import torchvision import pandas as pd +import numpy as np import multiprocessing as mp +import inspect +import time +import stat from pathlib import Path from contextlib import nullcontext @@ -17,7 +20,12 @@ from slide2vec.utils import fix_random_seeds from slide2vec.utils.config import get_cfg_from_file, setup_distributed from slide2vec.models import ModelFactory -from slide2vec.data import TileDataset, RegionUnfolding +from slide2vec.data import ( + TileDataset, + TileCatalogDataset, + RegionUnfolding, + ensure_tile_catalogs, +) from slide2vec.hs2p.hs2p.wsi import SamplingParameters torchvision.disable_beta_transforms_warning() @@ -65,6 +73,7 @@ def create_dataset( wsi_path, mask_path, coordinates_dir, + catalog_path, target_spacing, tolerance, backend, @@ -73,7 +82,29 @@ def create_dataset( filter_params, transforms, restrict_to_tissue: bool, + use_parquet: bool, + max_open_slides_per_worker: int, ): + if use_parquet: + if catalog_path is None: + raise ValueError("catalog_path must be provided when speed.use_parquet=true") + return TileCatalogDataset( + catalog_path=catalog_path, + wsi_path=wsi_path, + mask_path=mask_path, + target_spacing=target_spacing, + tolerance=tolerance, + backend=backend, + segment_params=segment_params, + sampling_params=sampling_params, + filter_params=filter_params, + transforms=transforms, + restrict_to_tissue=restrict_to_tissue, + max_open_slides_per_worker=max_open_slides_per_worker, + ) + + if coordinates_dir is None: + raise ValueError("coordinates_dir must be provided when speed.use_parquet=false") return TileDataset( wsi_path=wsi_path, mask_path=mask_path, @@ -86,33 +117,89 @@ def create_dataset( filter_params=filter_params, transforms=transforms, restrict_to_tissue=restrict_to_tissue, + max_open_slides_per_worker=max_open_slides_per_worker, ) -def run_inference(dataloader, model, device, autocast_context, unit, batch_size, feature_path, feature_dim, dtype, run_on_cpu: False): +def run_inference( + dataloader, + model, + device, + autocast_context, + unit, + batch_size, + shard_prefix, + expected_num_samples, + run_on_cpu: bool, +): device_name = f"GPU {distributed.get_global_rank()}" if not run_on_cpu else "CPU" - with h5py.File(feature_path, "w") as f: - features = f.create_dataset("features", shape=(0, *feature_dim), maxshape=(None, *feature_dim), dtype=dtype, chunks=(batch_size, *feature_dim)) - indices = f.create_dataset("indices", shape=(0,), maxshape=(None,), dtype='int64', chunks=(batch_size,)) - with torch.inference_mode(), autocast_context: - for batch in tqdm.tqdm( - dataloader, - desc=f"Inference on {device_name}", - unit=unit, - unit_scale=batch_size, - leave=False, - position=2 + distributed.get_global_rank(), - ): - idx, image = batch - image = image.to(device, non_blocking=True) - feature = model(image)["embedding"].cpu().numpy() - features.resize(features.shape[0] + feature.shape[0], axis=0) - features[-feature.shape[0]:] = feature - indices.resize(indices.shape[0] + idx.shape[0], axis=0) - indices[-idx.shape[0]:] = idx.cpu().numpy() - - # cleanup - del image, feature + features_path = Path(f"{shard_prefix}.features.npy") + indices_path = Path(f"{shard_prefix}.indices.npy") + for fp in (features_path, indices_path): + if fp.exists(): + os.remove(fp) + write_offset = 0 + features_mm = None + indices_mm = None + with torch.inference_mode(), autocast_context: + for batch in tqdm.tqdm( + dataloader, + desc=f"Inference on {device_name}", + unit=unit, + unit_scale=batch_size, + leave=False, + position=2 + distributed.get_global_rank(), + ): + idx, image = batch + image = image.to(device, non_blocking=True) + feature = model(image)["embedding"].cpu().numpy() + idx_np = idx.cpu().numpy() + batch_len = int(feature.shape[0]) + + if features_mm is None: + features_mm = np.lib.format.open_memmap( + str(features_path), + mode="w+", + dtype=feature.dtype, + shape=(int(expected_num_samples), *feature.shape[1:]), + ) + indices_mm = np.lib.format.open_memmap( + str(indices_path), + mode="w+", + dtype=np.int64, + shape=(int(expected_num_samples),), + ) + + end = write_offset + batch_len + if end > expected_num_samples: + raise RuntimeError( + f"Received {end} samples but expected {expected_num_samples} for {features_path}" + ) + features_mm[write_offset:end] = feature + indices_mm[write_offset:end] = idx_np + write_offset = end + + # cleanup + del image, feature, idx, idx_np, batch + + if features_mm is None: + raise RuntimeError(f"No batches were produced for {features_path}") + if write_offset != expected_num_samples: + raise RuntimeError( + f"Wrote {write_offset} samples but expected {expected_num_samples} for {features_path}" + ) + features_mm.flush() + indices_mm.flush() + del features_mm, indices_mm + for fp in (features_path, indices_path): + try: + os.chmod( + fp, + stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH, + ) + except OSError: + # Best effort only; some filesystems ignore chmod/ACL updates. + pass # cleanup if not run_on_cpu: @@ -120,28 +207,132 @@ def run_inference(dataloader, model, device, autocast_context, unit, batch_size, gc.collect() -def load_sort_and_deduplicate_features(tmp_dir, name, expected_len=None): - features_list, indices_list = [], [] +def _open_npy_for_read_with_retry( + fp: Path, + max_attempts: int = 60, + initial_delay_s: float = 0.1, +): + delay = max(0.01, float(initial_delay_s)) + last_exc = None + for attempt in range(max_attempts): + try: + return np.load(str(fp), mmap_mode="r", allow_pickle=False) + except (BlockingIOError, FileNotFoundError, PermissionError, OSError) as exc: + last_exc = exc + errno_val = getattr(exc, "errno", None) + msg = str(exc).lower() + if isinstance(exc, PermissionError): + try: + os.chmod(fp, 0o666) + except OSError: + pass + retryable = ( + isinstance(exc, (BlockingIOError, FileNotFoundError, PermissionError)) + or errno_val in (11, 13) + or "unable to lock file" in msg + or "resource temporarily unavailable" in msg + or "permission denied" in msg + ) + if (not retryable) or (attempt >= max_attempts - 1): + break + time.sleep(delay) + delay = min(delay * 1.5, 1.0) + + if isinstance(last_exc, FileNotFoundError): + raise FileNotFoundError( + f"Missing shard file after retries: {fp}. " + "If running multi-node, ensure shard temp dir is shared across ranks." + ) from last_exc + if last_exc is not None: + raise RuntimeError(f"Unable to open shard file for reading: {fp}") from last_exc + raise RuntimeError(f"Unable to open shard file for reading: {fp}") + + +def _tmp_shard_paths(tmp_dir: Path, name: str, rank: int) -> tuple[Path, Path]: + prefix = Path(tmp_dir, f"{name}-rank_{rank}") + return Path(f"{prefix}.features.npy"), Path(f"{prefix}.indices.npy") + + +def load_features_with_indexed_fill(tmp_dir, name, expected_len: int): + if expected_len < 1: + raise ValueError(f"expected_len must be >= 1, got {expected_len}") + + merged_features = None + seen = torch.zeros(expected_len, dtype=torch.bool) + for rank in range(distributed.get_global_size()): - fp = tmp_dir / f"{name}-rank_{rank}.h5" - with h5py.File(fp, "r") as f: - features_list.append(torch.from_numpy(f["features"][:])) - indices_list.append(torch.from_numpy(f["indices"][:])) - os.remove(fp) - features = torch.cat(features_list, dim=0) - indices = torch.cat(indices_list, dim=0) - order = torch.argsort(indices) - indices = indices[order] - features = features[order] - # deduplicate - keep = torch.ones_like(indices, dtype=torch.bool) - keep[1:] = indices[1:] != indices[:-1] - indices_unique = indices[keep] - features_unique = features[keep] - if expected_len is not None: - assert len(indices_unique) == expected_len, f"Got {len(indices_unique)} items, expected {expected_len}" - assert torch.unique(indices_unique).numel() == len(indices_unique), "Indices are not unique after sorting" - return features_unique + feat_path, idx_path = _tmp_shard_paths(tmp_dir, name, rank) + feat_ds = _open_npy_for_read_with_retry(feat_path) + idx_ds = _open_npy_for_read_with_retry(idx_path) + if feat_ds.shape[0] != idx_ds.shape[0]: + raise RuntimeError( + f"Mismatched features/indices rows for rank {rank}: " + f"{feat_ds.shape[0]} vs {idx_ds.shape[0]}" + ) + + if merged_features is None: + probe = np.empty((), dtype=feat_ds.dtype) + merged_features = torch.empty( + (expected_len, *feat_ds.shape[1:]), + dtype=torch.from_numpy(probe).dtype, + ) + elif tuple(feat_ds.shape[1:]) != tuple(merged_features.shape[1:]): + raise RuntimeError( + f"Inconsistent feature shape for rank {rank}: got {feat_ds.shape[1:]}, " + f"expected {tuple(merged_features.shape[1:])}" + ) + + chunk_rows = 8192 + total_rows = int(idx_ds.shape[0]) + for start in range(0, total_rows, chunk_rows): + end = min(start + chunk_rows, total_rows) + idx_np = np.asarray(idx_ds[start:end], dtype=np.int64) + feat_np = np.asarray(feat_ds[start:end]) + + if idx_np.shape[0] != feat_np.shape[0]: + raise RuntimeError( + f"Mismatched chunk rows for rank {rank}: " + f"{idx_np.shape[0]} vs {feat_np.shape[0]}" + ) + if idx_np.size == 0: + continue + if np.any(idx_np < 0) or np.any(idx_np >= expected_len): + bad = idx_np[(idx_np < 0) | (idx_np >= expected_len)][0] + raise RuntimeError( + f"Out-of-range tile index {int(bad)} for expected_len={expected_len}" + ) + + _, first_pos = np.unique(idx_np, return_index=True) + first_mask_np = np.zeros(idx_np.shape[0], dtype=bool) + first_mask_np[first_pos] = True + + idx_first = torch.from_numpy(idx_np[first_mask_np]) + feat_first = torch.from_numpy(feat_np[first_mask_np]) + + unseen_mask = ~seen[idx_first] + if unseen_mask.any(): + idx_write = idx_first[unseen_mask] + merged_features[idx_write] = feat_first[unseen_mask] + seen[idx_write] = True + + del feat_ds, idx_ds + if feat_path.exists(): + os.remove(feat_path) + if idx_path.exists(): + os.remove(idx_path) + + if merged_features is None: + raise RuntimeError(f"No shard data found for {name}") + + missing = torch.nonzero(~seen, as_tuple=False) + if missing.numel() > 0: + missing_count = int(missing.numel()) + first_missing = int(missing[0].item()) + raise RuntimeError( + f"Missing {missing_count} tile embeddings after merge for {name}. " + f"First missing index: {first_missing}" + ) + return merged_features def resolve_output_dir(config_output_dir: str, cli_output_dir: str | None) -> Path: @@ -155,9 +346,27 @@ def resolve_output_dir(config_output_dir: str, cli_output_dir: str | None) -> Pa def cleanup_tmp_features(tmp_dir: Path, name: str): for rank in range(distributed.get_global_size()): - fp = tmp_dir / f"{name}-rank_{rank}.h5" - if fp.exists(): - os.remove(fp) + feat_path, idx_path = _tmp_shard_paths(tmp_dir, name, rank) + if feat_path.exists(): + os.remove(feat_path) + if idx_path.exists(): + os.remove(idx_path) + + +def cleanup_tmp_feature_dir(tmp_dir: Path): + if not tmp_dir.exists(): + return + for pattern in ("*.features.npy", "*.indices.npy"): + for fp in tmp_dir.glob(pattern): + try: + fp.unlink() + except OSError: + pass + try: + tmp_dir.rmdir() + except OSError: + # Keep directory if it's not empty or cannot be removed on this filesystem. + pass def main(args): @@ -175,14 +384,25 @@ def main(args): else: coordinates_dir = Path(cfg.output_dir, "coordinates") fix_random_seeds(cfg.seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + deterministic_inference = bool(cfg.speed.get("deterministic_inference", False)) + cudnn_benchmark = bool(cfg.speed.get("cudnn_benchmark", not deterministic_inference)) + torch.backends.cudnn.deterministic = deterministic_inference + torch.backends.cudnn.benchmark = cudnn_benchmark and not deterministic_inference unit = "tile" if cfg.model.level != "region" else "region" num_workers = min(mp.cpu_count(), cfg.speed.num_workers_embedding) if "SLURM_JOB_CPUS_PER_NODE" in os.environ: num_workers = min(num_workers, int(os.environ["SLURM_JOB_CPUS_PER_NODE"])) + persistent_workers = bool(cfg.speed.get("persistent_workers_embedding", True)) + prefetch_factor = max(1, int(cfg.speed.get("prefetch_factor_embedding", 4))) + use_parquet = bool(cfg.speed.get("use_parquet", True)) + dataloader_supports_in_order = ( + "in_order" in inspect.signature(torch.utils.data.DataLoader.__init__).parameters + ) + max_open_slides_per_worker = max( + 1, int(cfg.speed.get("max_open_slides_per_worker", 16)) + ) process_list = Path(cfg.output_dir, "process_list.csv") assert ( @@ -196,6 +416,8 @@ def main(args): process_df["mask_path"] = [None] * len(process_df) cols = ["wsi_name", "wsi_path", "mask_path", "tiling_status", "feature_status", "error", "traceback"] process_df = process_df[cols] + process_df["error"] = process_df["error"].astype("object") + process_df["traceback"] = process_df["traceback"].astype("object") skip_feature_extraction = process_df["feature_status"].str.contains("success").all() @@ -237,15 +459,37 @@ def main(args): wsi_paths_to_process = [Path(x) for x in process_stack.wsi_path.values.tolist()] mask_paths_to_process = [Path(x) if x is not None and not pd.isna(x) else None for x in process_stack.mask_path.values.tolist()] - combined_paths = zip(wsi_paths_to_process, mask_paths_to_process) + slide_mask_pairs = list(zip(wsi_paths_to_process, mask_paths_to_process)) features_dir = Path(cfg.output_dir, "features") if distributed.is_main_process(): features_dir.mkdir(exist_ok=True, parents=True) - tmp_dir = Path("/tmp") - if distributed.is_main_process(): - tmp_dir.mkdir(exist_ok=True, parents=True) + if use_parquet: + catalog_dir = Path(cfg.output_dir, "tile_catalog") + if distributed.is_main_process(): + ensure_tile_catalogs( + slide_mask_pairs=slide_mask_pairs, + coordinates_dir=coordinates_dir, + catalog_dir=catalog_dir, + ) + if distributed.is_enabled_and_multiple_gpus(): + torch.distributed.barrier() + slide_to_catalog = { + str(wsi_fp): Path(catalog_dir, f"{wsi_fp.stem.replace(' ', '_')}.parquet") + for wsi_fp, _ in slide_mask_pairs + } + else: + slide_to_catalog = {} + + tmp_dir = Path(cfg.output_dir, "tmp_feature_shards") + tmp_dir.mkdir(exist_ok=True, parents=True) + try: + os.chmod(tmp_dir, 0o777) + except OSError: + pass + if distributed.is_enabled_and_multiple_gpus(): + torch.distributed.barrier() autocast_context = ( torch.autocast(device_type="cuda", dtype=torch.float16) @@ -258,7 +502,7 @@ def main(args): print(f"transforms: {transforms}") for wsi_fp, mask_fp in tqdm.tqdm( - combined_paths, + slide_mask_pairs, desc="Inference", unit="slide", total=total, @@ -270,15 +514,17 @@ def main(args): feature_path = features_dir / f"{name}.pt" if cfg.model.save_tile_embeddings: feature_path = features_dir / f"{name}-tiles.pt" - tmp_feature_path = tmp_dir / f"{name}-rank_{distributed.get_global_rank()}.h5" + tmp_feature_prefix = tmp_dir / f"{name}-rank_{distributed.get_global_rank()}" status_info = {"status": "success"} local_failed = False try: + catalog_path = slide_to_catalog[str(wsi_fp)] if use_parquet else None dataset = create_dataset( wsi_path=wsi_fp, mask_path=mask_fp, coordinates_dir=coordinates_dir, + catalog_path=catalog_path, target_spacing=cfg.tiling.params.spacing, tolerance=cfg.tiling.params.tolerance, backend=cfg.tiling.backend, @@ -287,7 +533,18 @@ def main(args): filter_params=cfg.tiling.filter_params, transforms=transforms, restrict_to_tissue=cfg.model.restrict_to_tissue, + use_parquet=use_parquet, + max_open_slides_per_worker=max_open_slides_per_worker, ) + if len(dataset) == 0: + source_desc = ( + f"catalog {catalog_path}" + if use_parquet + else f"coordinates file {Path(coordinates_dir, f'{name}.npy')}" + ) + raise ValueError( + f"No tiles found for slide {wsi_fp} ({source_desc})" + ) if distributed.is_enabled_and_multiple_gpus(): sampler = torch.utils.data.DistributedSampler( dataset, @@ -296,21 +553,20 @@ def main(args): ) else: sampler = None - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=cfg.model.batch_size, - sampler=sampler, - num_workers=num_workers, - pin_memory=True, - ) - - # get feature dimension and dtype using a dry run - with torch.inference_mode(), autocast_context: - sample_batch = next(iter(dataloader)) - sample_image = sample_batch[1].to(model.device) - sample_feature = model(sample_image)["embedding"].cpu().numpy() - feature_dim = sample_feature.shape[1:] - dtype = sample_feature.dtype + loader_kwargs = { + "dataset": dataset, + "batch_size": cfg.model.batch_size, + "sampler": sampler, + "num_workers": num_workers, + "pin_memory": not run_on_cpu, + } + if num_workers > 0: + loader_kwargs["persistent_workers"] = persistent_workers + loader_kwargs["prefetch_factor"] = prefetch_factor + if dataloader_supports_in_order: + loader_kwargs["in_order"] = False + dataloader = torch.utils.data.DataLoader(**loader_kwargs) + expected_num_samples = len(sampler) if sampler is not None else len(dataset) run_inference( dataloader, @@ -319,9 +575,8 @@ def main(args): autocast_context, unit, cfg.model.batch_size, - tmp_feature_path, - feature_dim, - dtype, + tmp_feature_prefix, + expected_num_samples, run_on_cpu, ) @@ -356,7 +611,7 @@ def main(args): } elif distributed.is_main_process(): try: - wsi_feature = load_sort_and_deduplicate_features( + wsi_feature = load_features_with_indexed_fill( tmp_dir, name, expected_len=len(dataset) ) torch.save(wsi_feature, feature_path) @@ -408,6 +663,13 @@ def main(args): process_df.loc[ process_df["wsi_path"] == str(wsi_fp), "traceback" ] = status_info["traceback"] + else: + process_df.loc[ + process_df["wsi_path"] == str(wsi_fp), "error" + ] = None + process_df.loc[ + process_df["wsi_path"] == str(wsi_fp), "traceback" + ] = None process_df.to_csv(process_list, index=False) if distributed.is_enabled_and_multiple_gpus(): @@ -427,6 +689,7 @@ def main(args): f"Completed {unit}-level feature extraction: {slides_with_tiles - len(failed_feature_extraction)}/{slides_with_tiles}" ) print("=+=" * 10) + cleanup_tmp_feature_dir(tmp_dir) if distributed.is_enabled(): torch.distributed.destroy_process_group() diff --git a/slide2vec/utils/parquet.py b/slide2vec/utils/parquet.py new file mode 100644 index 0000000..9b46d8a --- /dev/null +++ b/slide2vec/utils/parquet.py @@ -0,0 +1,7 @@ +import pyarrow as pa +import pyarrow.dataset as ds +import pyarrow.parquet as pq + + +def require_pyarrow(): + return pa, pq, ds diff --git a/tests/test_tile_catalog.py b/tests/test_tile_catalog.py new file mode 100644 index 0000000..9635468 --- /dev/null +++ b/tests/test_tile_catalog.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import numpy as np + +from slide2vec.data.tile_catalog import ensure_tile_catalogs +from slide2vec.utils.parquet import require_pyarrow + + +def test_tile_catalog_preserves_npy_row_order(tmp_path): + repo_root = Path(__file__).resolve().parents[1] + input_wsi = (repo_root / "tests" / "fixtures" / "input" / "test-wsi.tif").resolve() + gt_coords = np.load( + repo_root / "tests" / "fixtures" / "gt" / "test-wsi.npy", + allow_pickle=False, + ) + + coordinates_dir = tmp_path / "coordinates" + coordinates_dir.mkdir(parents=True, exist_ok=True) + np.save(coordinates_dir / "test-wsi.npy", gt_coords) + + catalog_dir = tmp_path / "tile_catalog" + mapping = ensure_tile_catalogs( + slide_mask_pairs=[(input_wsi, None)], + coordinates_dir=coordinates_dir, + catalog_dir=catalog_dir, + ) + + _, pq, _ = require_pyarrow() + table = pq.read_table(str(mapping[str(input_wsi)])) + columns = table.to_pydict() + + expected_idx = np.arange(len(gt_coords), dtype=np.int64) + np.testing.assert_array_equal(np.asarray(columns["coord_index"], dtype=np.int64), expected_idx) + np.testing.assert_array_equal(np.asarray(columns["x"], dtype=np.int64), gt_coords["x"].astype(np.int64)) + np.testing.assert_array_equal(np.asarray(columns["y"], dtype=np.int64), gt_coords["y"].astype(np.int64)) + + manifest_path = catalog_dir / "manifest.json" + assert manifest_path.exists() From 4b863f0e0d510a55c7d646ea6a5c91815ae3e87b Mon Sep 17 00:00:00 2001 From: clemsgrs Date: Thu, 26 Feb 2026 01:44:14 +0000 Subject: [PATCH 2/2] add pyarrow to requirements --- requirements.in | 1 + requirements.txt | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.in b/requirements.in index 64dc594..6bb5667 100644 --- a/requirements.in +++ b/requirements.in @@ -6,6 +6,7 @@ pandas pillow tqdm wandb +pyarrow torch>=2.3,<2.8 torchvision>=0.18.0 opencv-python diff --git a/requirements.txt b/requirements.txt index 68459d5..ad3f4e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,11 @@ wandb numpy==1.26.1 pandas pillow +pyarrow einops tqdm omegaconf wholeslidedata huggingface_hub torch==2.1.0 -torchvision==0.16.0 \ No newline at end of file +torchvision==0.16.0