Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ If `visualize` is set to `true`, a `visualization/` folder is created containing
- **`mask/`**: visualizations of the provided tissue (or annotation) mask
- **`tiling/`** (for `tiling.py`) or **`sampling/`** (for `sampling.py`): visualizations of the extracted or sampled tiles overlaid on the slide. For `sampling.py`, this includes subfolders for each category defined in the sampling parameters (e.g., tumor, stroma, etc.)

For sampling visualizations, overlays are drawn only for annotations that have a non-null color in `sampling_params.color_mapping`. Annotations with null color are left untouched (raw slide pixels, no darkening overlay).

These visualizations are useful for double-checking that the tiling or sampling process ran as expected.

### Process summary
Expand Down
2 changes: 1 addition & 1 deletion hs2p/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tiling:
drop_holes: false # whether or not to drop tiles whose center pixel falls withing an identified holes
use_padding: true # whether to pad the border of the slide
seg_params:
downsample: 64 # find the closest downsample in the slide for tissue segmentation
downsample: 16 # find the closest downsample in the slide for tissue segmentation
sthresh: 8 # segmentation threshold (positive integer, using a higher threshold leads to less foreground and more background detection) (not used when use_otsu=True)
sthresh_up: 255 # upper threshold value for scaling the binary mask
mthresh: 7 # median filter size (positive, odd integer)
Expand Down
44 changes: 42 additions & 2 deletions hs2p/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@
from hs2p.wsi import extract_coordinates, filter_coordinates, sample_coordinates, save_coordinates, visualize_coordinates, overlay_mask_on_slide, SamplingParameters


def _validate_visualization_color_mapping(
*,
pixel_mapping: dict[str, int],
color_mapping: dict[str, list[int] | None],
):
missing_annotations = sorted(set(pixel_mapping.keys()) - set(color_mapping.keys()))
if missing_annotations:
raise ValueError(
"color_mapping is missing annotation keys required by pixel_mapping: "
+ ", ".join(missing_annotations)
)

for annotation, color in color_mapping.items():
if color is None:
continue
if not isinstance(color, (list, tuple)) or len(color) != 3:
raise ValueError(
f"color_mapping['{annotation}'] must be None or a length-3 RGB list/tuple"
)
if any((not isinstance(c, (int, np.integer)) or c < 0 or c > 255) for c in color):
raise ValueError(
f"color_mapping['{annotation}'] must contain integers in [0, 255]"
)


def get_args_parser(add_help: bool = True):
parser = argparse.ArgumentParser("hs2p", add_help=add_help)
parser.add_argument(
Expand All @@ -20,6 +45,9 @@ def get_args_parser(add_help: bool = True):
parser.add_argument(
"--skip-datetime", action="store_true", help="skip run id datetime prefix"
)
parser.add_argument(
"--skip-logging", action="store_true", help="skip logging configuration"
)
parser.add_argument(
"--output-dir",
type=str,
Expand Down Expand Up @@ -73,6 +101,10 @@ def process_slide(
}
else:
color_mapping = sampling_params.color_mapping
_validate_visualization_color_mapping(
pixel_mapping=sampling_params.pixel_mapping,
color_mapping=color_mapping,
)
p = [0] * 3 * len(color_mapping)
for k, v in sampling_params.pixel_mapping.items():
if color_mapping[k] is not None:
Expand Down Expand Up @@ -139,6 +171,8 @@ def process_slide(
mask_path=mask_path,
annotation=annotation,
palette=preview_palette,
pixel_mapping=sampling_params.pixel_mapping,
color_mapping=color_mapping,
)
else:
for annotation in sampling_params.pixel_mapping.keys():
Expand Down Expand Up @@ -188,6 +222,8 @@ def process_slide(
mask_path=mask_path,
annotation=annotation,
palette=preview_palette,
pixel_mapping=sampling_params.pixel_mapping,
color_mapping=color_mapping,
)
if cfg.visualize and mask_visualize_dir is not None:
mask_visu_path = Path(mask_visualize_dir, f"{wsi_name}.png")
Expand All @@ -211,7 +247,7 @@ def process_slide(


def main(args):

cfg = setup(args)
output_dir = Path(cfg.output_dir)

Expand Down Expand Up @@ -243,13 +279,15 @@ def main(args):

pixel_mapping = {k: v for e in cfg.tiling.sampling_params.pixel_mapping for k, v in e.items()}
tissue_percentage = {k: v for e in cfg.tiling.sampling_params.tissue_percentage for k, v in e.items()}
tissue_key_present = True
if "tissue" not in tissue_percentage:
tissue_key_present = False
tissue_percentage["tissue"] = cfg.tiling.params.min_tissue_percentage
if cfg.tiling.sampling_params.color_mapping is not None:
color_mapping = {k: v for e in cfg.tiling.sampling_params.color_mapping for k, v in e.items()}
else:
color_mapping = None

sampling_params = SamplingParameters(
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
Expand Down Expand Up @@ -333,6 +371,8 @@ def main(args):
for annotation, pct in tissue_percentage.items():
if pct is None:
continue
if not tissue_key_present and annotation == "tissue":
continue
slides_with_tiles = [
str(p)
for p in wsi_paths
Expand Down
71 changes: 50 additions & 21 deletions hs2p/wsi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,18 @@ def overlay_mask_on_tile(
tile: Image.Image,
mask: Image.Image,
palette: dict[str, int],
pixel_mapping: dict[str, int],
color_mapping: dict[str, list[int] | None],
alpha=0.5,
):

# create alpha mask
mask_arr = np.array(mask)
alpha_int = int(round(255 * alpha))
alpha_content = np.less_equal(mask_arr, 0).astype("uint8") * alpha_int + (
255 - alpha_int
alpha_content = _build_overlay_alpha(
mask_arr=mask_arr,
alpha=alpha,
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
)
alpha_content = Image.fromarray(alpha_content)

mask.putpalette(data=palette.tolist())
mask_rgb = mask.convert(mode="RGB")
Expand All @@ -324,13 +326,33 @@ def overlay_mask_on_tile(
return overlayed_image


def _build_overlay_alpha(
*,
mask_arr: np.ndarray,
alpha: float,
pixel_mapping: dict[str, int],
color_mapping: dict[str, list[int] | None],
) -> Image.Image:
alpha_int = int(round(255 * alpha))
active_labels = set()
for annotation, label_value in pixel_mapping.items():
if color_mapping.get(annotation) is not None:
active_labels.add(label_value)

overlay_mask = np.isin(mask_arr, list(active_labels)).astype("uint8")
alpha_content = np.less(overlay_mask, 1).astype("uint8") * alpha_int + (
255 - alpha_int
)
return Image.fromarray(alpha_content)


def overlay_mask_on_slide(
wsi_path: Path,
annotation_mask_path: Path,
downsample: int,
palette: dict[str, int],
pixel_mapping: dict[str, int],
color_mapping: dict[str, list[int]] | None = None,
color_mapping: dict[str, list[int] | None],
alpha: float = 0.5,
):
"""
Expand Down Expand Up @@ -373,21 +395,12 @@ def overlay_mask_on_slide(
)
mask = Image.fromarray(mask_arr)

# create alpha mask
alpha_int = int(round(255 * alpha))
if color_mapping is not None:
alpha_content = np.zeros_like(mask_arr)
for k, v in pixel_mapping.items():
if color_mapping[k] is not None:
alpha_content += mask_arr == v
alpha_content = np.less(alpha_content, 1).astype("uint8") * alpha_int + (
255 - alpha_int
)
else:
alpha_content = np.less_equal(mask_arr, 0).astype("uint8") * alpha_int + (
255 - alpha_int
)
alpha_content = Image.fromarray(alpha_content)
alpha_content = _build_overlay_alpha(
mask_arr=mask_arr,
alpha=alpha,
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
)

mask.putpalette(data=palette.tolist())
mask_rgb = mask.convert(mode="RGB")
Expand Down Expand Up @@ -417,6 +430,8 @@ def draw_grid_from_coordinates(
indices: list[int] | None = None,
mask = None,
palette: dict[str, int] | None = None,
pixel_mapping: dict[str, int] | None = None,
color_mapping: dict[str, list[int] | None] | None = None,
):
downsamples = wsi.level_downsamples[vis_level]
if indices is None:
Expand Down Expand Up @@ -478,6 +493,14 @@ def draw_grid_from_coordinates(
valid_tile = Image.fromarray(valid_tile).convert("RGB")

if mask is not None:
if (
palette is None
or pixel_mapping is None
or color_mapping is None
):
raise ValueError(
"palette, pixel_mapping, and color_mapping are required when mask overlay is enabled"
)
# need to scale (x, y) defined w.r.t. slide level 0
# to mask level 0
downsample = wsi.spacings[0] / mask.spacings[0]
Expand Down Expand Up @@ -509,6 +532,8 @@ def draw_grid_from_coordinates(
valid_tile,
masked_tile,
palette,
pixel_mapping,
color_mapping,
)

# paste the valid part into the white tile
Expand Down Expand Up @@ -561,6 +586,8 @@ def visualize_coordinates(
mask_path: Path | None = None,
annotation: str | None = None,
palette: dict[str, int] | None = None,
pixel_mapping: dict[str, int] | None = None,
color_mapping: dict[str, list[int] | None] | None = None,
):
wsi = WholeSlideImage(wsi_path, backend=backend)
vis_level = wsi.get_best_level_for_downsample_custom(downsample)
Expand Down Expand Up @@ -601,6 +628,8 @@ def visualize_coordinates(
thickness=grid_thickness,
mask=mask,
palette=palette,
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
)
wsi_name = wsi_path.stem.replace(" ", "_")
if annotation is not None:
Expand Down
5 changes: 4 additions & 1 deletion hs2p/wsi/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,10 +1017,13 @@ def process_contour(
[x_coords.flatten(), y_coords.flatten()]
).transpose()

# filter coordinates based on tissue coverage
# filter coordinates based on tissue coverage (reads tissue mask for active contour only)
keep_flags, tissue_pcts = tissue_checker.check_coordinates(coord_candidates)

# further filter coordinates based on black/white tile filtering
# (reads RGB values from the wsi at the tile level)
# (note that this step is after the tissue mask filtering, so it only applies to tiles that have enough tissue coverage)
# (speed could improved by working at a lower resolution)
keep_flags = self.filter_black_and_white_tiles(
keep_flags,
coord_candidates,
Expand Down
91 changes: 91 additions & 0 deletions tests/test_overlay_semantics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from pathlib import Path

import numpy as np
import pytest
from PIL import Image


cv2 = pytest.importorskip("cv2")
wsi_mod = pytest.importorskip("hs2p.wsi")


def _build_palette(mapping: dict[int, tuple[int, int, int]]) -> np.ndarray:
palette = np.zeros(shape=768, dtype=int)
for label, color in mapping.items():
palette[label * 3 : label * 3 + 3] = np.array(color, dtype=int)
return palette


def test_overlay_mask_on_tile_only_colored_labels_are_blended():
tile_arr = np.full((2, 2, 3), 120, dtype=np.uint8)
tile = Image.fromarray(tile_arr)
mask_arr = np.array([[0, 3], [4, 3]], dtype=np.uint8)
mask = Image.fromarray(mask_arr)

pixel_mapping = {"background": 0, "gleason3": 3, "gleason4": 4}
color_mapping = {
"background": None,
"gleason3": [255, 0, 0],
"gleason4": None,
}
palette = _build_palette({3: (255, 0, 0)})

overlay = wsi_mod.overlay_mask_on_tile(
tile=tile,
mask=mask,
palette=palette,
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
alpha=0.5,
)
overlay_arr = np.array(overlay)

assert np.array_equal(overlay_arr[0, 0], tile_arr[0, 0]) # background untouched
assert np.array_equal(overlay_arr[1, 0], tile_arr[1, 0]) # uncolored label untouched
assert not np.array_equal(overlay_arr[0, 1], tile_arr[0, 1]) # colored label blended


def test_overlay_mask_on_slide_matches_tile_semantics(monkeypatch):
slide_arr = np.full((2, 2, 3), 120, dtype=np.uint8)
mask_labels = np.array([[0, 3], [4, 3]], dtype=np.uint8)
mask_arr = np.stack([mask_labels, mask_labels, mask_labels], axis=-1)

class FakeWSI:
def __init__(self, path, backend="asap"):
self.path = Path(path)
self.spacings = [0.5]
self.level_dimensions = [(2, 2)]
self.level_downsamples = [(1.0, 1.0)]

def get_best_level_for_downsample_custom(self, downsample):
return 0

def get_slide(self, spacing):
if "mask" in self.path.name:
return mask_arr
return slide_arr

monkeypatch.setattr(wsi_mod, "WholeSlideImage", FakeWSI)

pixel_mapping = {"background": 0, "gleason3": 3, "gleason4": 4}
color_mapping = {
"background": None,
"gleason3": [255, 0, 0],
"gleason4": None,
}
palette = _build_palette({3: (255, 0, 0)})

overlay = wsi_mod.overlay_mask_on_slide(
wsi_path=Path("fake-wsi.tif"),
annotation_mask_path=Path("fake-mask.tif"),
downsample=1,
palette=palette,
pixel_mapping=pixel_mapping,
color_mapping=color_mapping,
alpha=0.5,
)
overlay_arr = np.array(overlay.convert("RGB"))

assert np.array_equal(overlay_arr[0, 0], slide_arr[0, 0]) # background untouched
assert np.array_equal(overlay_arr[1, 0], slide_arr[1, 0]) # uncolored label untouched
assert not np.array_equal(overlay_arr[0, 1], slide_arr[0, 1]) # colored label blended
Loading