From ed17d7de48ef43886b572d7c7484eee0d804eda6 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Thu, 5 Feb 2026 17:33:07 +0100 Subject: [PATCH 1/5] fix bug in loading masking image --- src/ngio/images/_ome_zarr_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ngio/images/_ome_zarr_container.py b/src/ngio/images/_ome_zarr_container.py index 57d31d37..fa812538 100644 --- a/src/ngio/images/_ome_zarr_container.py +++ b/src/ngio/images/_ome_zarr_container.py @@ -565,7 +565,7 @@ def get_masked_image( masking_label, masking_table = self._find_matching_masking_label( masking_label_name=masking_label_name, masking_table_name=masking_table_name, - pixel_size=pixel_size, + pixel_size=image.pixel_size, ) return MaskedImage( group_handler=image._group_handler, From a3e51f7631b9635dc161778e56c09903467e1c0e Mon Sep 17 00:00:00 2001 From: lorenzo Date: Fri, 6 Feb 2026 10:02:41 +0100 Subject: [PATCH 2/5] fix bug in loading masking roi --- pyproject.toml | 2 ++ src/ngio/io_pipes/_io_pipes_masked.py | 9 ++++----- src/ngio/io_pipes/_match_shape.py | 4 ++-- src/ngio/io_pipes/_zoom_transform.py | 4 ++-- tests/unit/images/test_masked_images.py | 20 +++++++++++++++----- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 819e5b7c..36bacde4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ dependencies = [ "numpy", "filelock", "zarr>3", + "scipy", + "fsspec", "anndata", "pydantic", "pandas>=1.2.0,<3.0.0", diff --git a/src/ngio/io_pipes/_io_pipes_masked.py b/src/ngio/io_pipes/_io_pipes_masked.py index 388baa53..c2523db5 100644 --- a/src/ngio/io_pipes/_io_pipes_masked.py +++ b/src/ngio/io_pipes/_io_pipes_masked.py @@ -84,15 +84,14 @@ def _setup_numpy_getters( if allow_rescaling: _zoom_transform = BaseZoomTransform( - input_dimensions=dimensions, - target_dimensions=label_dimensions, + input_dimensions=label_dimensions, + target_dimensions=dimensions, order="nearest", ) if label_transforms is None or len(label_transforms) == 0: label_transforms = [_zoom_transform] else: label_transforms = [_zoom_transform, *label_transforms] - label_slicing_dict = roi_to_slicing_dict( roi=roi, pixel_size=label_dimensions.pixel_size, @@ -317,8 +316,8 @@ def _setup_dask_getters( if allow_rescaling: _zoom_transform = BaseZoomTransform( - input_dimensions=dimensions, - target_dimensions=label_dimensions, + input_dimensions=label_dimensions, + target_dimensions=dimensions, order="nearest", ) if label_transforms is None or len(label_transforms) == 0: diff --git a/src/ngio/io_pipes/_match_shape.py b/src/ngio/io_pipes/_match_shape.py index a010a9c3..972d64d1 100644 --- a/src/ngio/io_pipes/_match_shape.py +++ b/src/ngio/io_pipes/_match_shape.py @@ -1,6 +1,6 @@ import logging from collections.abc import Sequence -from enum import Enum +from enum import StrEnum import dask.array as da import numpy as np @@ -10,7 +10,7 @@ logger = logging.getLogger(f"ngio:{__name__}") -class Action(str, Enum): +class Action(StrEnum): NONE = "none" PAD = "pad" TRIM = "trim" diff --git a/src/ngio/io_pipes/_zoom_transform.py b/src/ngio/io_pipes/_zoom_transform.py index b33d4de4..b309d388 100644 --- a/src/ngio/io_pipes/_zoom_transform.py +++ b/src/ngio/io_pipes/_zoom_transform.py @@ -28,7 +28,7 @@ def __init__( self._order: InterpolationOrder = order def _normalize_shape( - self, slice_: slice | int | tuple, scale: float, max_dim: int + self, slice_: slice | int | list[int], scale: float, max_dim: int ) -> int: if isinstance(slice_, slice): _start = slice_.start or 0 @@ -43,7 +43,7 @@ def _normalize_shape( elif isinstance(slice_, int): target_shape = 1 - elif isinstance(slice_, tuple): + elif isinstance(slice_, list): target_shape = len(slice_) * scale else: raise ValueError(f"Unsupported slice type: {type(slice_)}") diff --git a/tests/unit/images/test_masked_images.py b/tests/unit/images/test_masked_images.py index 48bbffdc..6201d71c 100644 --- a/tests/unit/images/test_masked_images.py +++ b/tests/unit/images/test_masked_images.py @@ -17,9 +17,11 @@ def _draw_random_labels(shape: tuple[int, ...], num_regions: int): for i, (y, x) in enumerate(seeds_list, start=1): markers[y, x] = i - image = ndimage.distance_transform_edt(markers == 0).astype("uint32") - labels = watershed(image, markers).astype("uint32") - return image, labels + dt_image = ndimage.distance_transform_edt(markers == 0) + assert isinstance(dt_image, np.ndarray) + dt_image = dt_image.astype("float32") + labels = watershed(dt_image, markers).astype("uint32") + return dt_image, labels @pytest.mark.parametrize( @@ -99,7 +101,10 @@ def test_masking( _roi_mask = masked_image.get_roi_masked(label=1, mode=array_mode) # Check that the mask is binary after masking - np.testing.assert_allclose(np.unique(_roi_mask), [0, 1]) + if isinstance(_roi_mask, np.ndarray): + np.testing.assert_allclose(np.unique(_roi_mask), [0, 1]) + else: + np.testing.assert_allclose(np.unique(_roi_mask.compute()), [0, 1]) # Just test the API masked_image.set_roi(label=1, patch=np.zeros_like(_roi_array), zoom_factor=1.123) @@ -113,12 +118,17 @@ def test_masking( for label_id in labels_stats.keys(): label_mask = masked_new_label.get_roi(label_id, mode=array_mode) + if not isinstance(label_mask, np.ndarray): + label_mask = label_mask.compute() label_mask = np.full(label_mask.shape, label_id, dtype=label_mask.dtype) # Set the label only inside the mask masked_new_label.set_roi_masked(label_id, label_mask) # rerun the stats on the new masked label - unique_labels, counts = np.unique(masked_new_label.get_array(), return_counts=True) + masked_array = masked_new_label.get_array() + if not isinstance(masked_array, np.ndarray): + masked_array = masked_array.compute() + unique_labels, counts = np.unique(masked_array, return_counts=True) labels_stats_masked = dict(zip(unique_labels, counts, strict=True)) assert labels_stats == labels_stats_masked From 3b6df443a4bb0e8a58fdb0542f00bdf08d122f8a Mon Sep 17 00:00:00 2001 From: lorenzo Date: Fri, 6 Feb 2026 10:50:26 +0100 Subject: [PATCH 3/5] fix major bug in masking loading when label was of different dimension --- src/ngio/io_pipes/_io_pipes_masked.py | 14 ----- src/ngio/io_pipes/_match_shape.py | 74 +++------------------------ 2 files changed, 6 insertions(+), 82 deletions(-) diff --git a/src/ngio/io_pipes/_io_pipes_masked.py b/src/ngio/io_pipes/_io_pipes_masked.py index c2523db5..3d4be723 100644 --- a/src/ngio/io_pipes/_io_pipes_masked.py +++ b/src/ngio/io_pipes/_io_pipes_masked.py @@ -34,7 +34,6 @@ def _numpy_label_to_bool_mask( data_shape: tuple[int, ...], label_axes: tuple[str, ...], data_axes: tuple[str, ...], - allow_rescaling: bool = True, ) -> np.ndarray: """Convert label data to a boolean mask.""" if label is not None: @@ -47,7 +46,6 @@ def _numpy_label_to_bool_mask( reference_shape=data_shape, array_axes=label_axes, reference_axes=data_axes, - allow_rescaling=allow_rescaling, ) return bool_mask @@ -146,7 +144,6 @@ def __init__( self._label_id = roi.label self._fill_value = fill_value - self._allow_rescaling = allow_rescaling super().__init__( zarr_array=zarr_array, slicing_ops=self._data_getter.slicing_ops, @@ -163,14 +160,12 @@ def get(self) -> np.ndarray: """Get the masked data as a numpy array.""" data = self._data_getter() label_data = self._label_data_getter() - bool_mask = _numpy_label_to_bool_mask( label_data=label_data, label=self.label_id, data_shape=data.shape, label_axes=self._label_data_getter.axes_ops.output_axes, data_axes=self._data_getter.axes_ops.output_axes, - allow_rescaling=self._allow_rescaling, ) if bool_mask.shape != data.shape: bool_mask = np.broadcast_to(bool_mask, data.shape) @@ -213,7 +208,6 @@ def __init__( self._data_getter = _data_getter self._label_data_getter = _label_data_getter self._label_id = roi.label - self._allow_rescaling = allow_rescaling self._data_setter = NumpySetter( zarr_array=zarr_array, @@ -245,7 +239,6 @@ def set(self, patch: np.ndarray) -> None: data_shape=data.shape, label_axes=self._label_data_getter.axes_ops.output_axes, data_axes=self._data_getter.axes_ops.output_axes, - allow_rescaling=self._allow_rescaling, ) if bool_mask.shape != data.shape: bool_mask = np.broadcast_to(bool_mask, data.shape) @@ -266,7 +259,6 @@ def _dask_label_to_bool_mask( data_shape: tuple[int, ...], label_axes: tuple[str, ...], data_axes: tuple[str, ...], - allow_rescaling: bool = True, ) -> DaskArray: """Convert label data to a boolean mask.""" if label is not None: @@ -279,7 +271,6 @@ def _dask_label_to_bool_mask( reference_shape=data_shape, array_axes=label_axes, reference_axes=data_axes, - allow_rescaling=allow_rescaling, ) return bool_mask @@ -324,7 +315,6 @@ def _setup_dask_getters( label_transforms = [_zoom_transform] else: label_transforms = [_zoom_transform, *label_transforms] - label_slicing_dict = roi_to_slicing_dict( roi=roi, pixel_size=label_dimensions.pixel_size, @@ -378,7 +368,6 @@ def __init__( self._label_data_getter = _label_data_getter self._label_id = roi.label self._fill_value = fill_value - self._allow_rescaling = allow_rescaling super().__init__( zarr_array=zarr_array, slicing_ops=self._data_getter.slicing_ops, @@ -401,7 +390,6 @@ def get(self) -> DaskArray: data_shape=data_shape, label_axes=self._label_data_getter.axes_ops.output_axes, data_axes=self._data_getter.axes_ops.output_axes, - allow_rescaling=self._allow_rescaling, ) if bool_mask.shape != data.shape: bool_mask = da.broadcast_to(bool_mask, data.shape) @@ -445,7 +433,6 @@ def __init__( self._label_data_getter = _label_data_getter self._label_id = roi.label - self._allow_rescaling = allow_rescaling self._data_setter = DaskSetter( zarr_array=zarr_array, @@ -479,7 +466,6 @@ def set(self, patch: DaskArray) -> None: data_shape=data_shape, label_axes=self._label_data_getter.axes_ops.output_axes, data_axes=self._data_getter.axes_ops.output_axes, - allow_rescaling=self._allow_rescaling, ) if bool_mask.shape != data.shape: bool_mask = da.broadcast_to(bool_mask, data.shape) diff --git a/src/ngio/io_pipes/_match_shape.py b/src/ngio/io_pipes/_match_shape.py index 972d64d1..2a7c100b 100644 --- a/src/ngio/io_pipes/_match_shape.py +++ b/src/ngio/io_pipes/_match_shape.py @@ -14,7 +14,6 @@ class Action(StrEnum): NONE = "none" PAD = "pad" TRIM = "trim" - RESCALING = "rescaling" def _compute_pad_widths( @@ -103,51 +102,6 @@ def _dask_trim( return array[tuple(slices)] -def _compute_rescaling_shape( - array_shape: tuple[int, ...], - actions: list[Action], - target_shape: tuple[int, ...], -) -> tuple[int, ...]: - rescaling_shape = [] - factor = [] - for act, s, ts in zip(actions, array_shape, target_shape, strict=True): - if act == Action.RESCALING: - rescaling_shape.append(ts) - factor.append(ts / s) - else: - rescaling_shape.append(s) - factor.append(1.0) - - logger.warning( - f"Images have a different shape ({array_shape} vs {target_shape}). " - f"Resolving by scaling with factors {factor}." - ) - return tuple(rescaling_shape) - - -def _numpy_rescaling( - array: np.ndarray, actions: list[Action], target_shape: tuple[int, ...] -) -> np.ndarray: - if all(act != Action.RESCALING for act in actions): - return array - from ngio.common._zoom import numpy_zoom - - rescaling_shape = _compute_rescaling_shape(array.shape, actions, target_shape) - return numpy_zoom(source_array=array, target_shape=rescaling_shape, order="nearest") - - -def _dask_rescaling( - array: da.Array, actions: list[Action], target_shape: tuple[int, ...] -) -> da.Array: - if all(act != Action.RESCALING for act in actions): - return array - from ngio.common._zoom import dask_zoom - - shape = tuple(int(s) for s in array.shape) - rescaling_shape = _compute_rescaling_shape(shape, actions, target_shape) - return dask_zoom(source_array=array, target_shape=rescaling_shape, order="nearest") - - def _check_axes(array_shape, reference_shape, array_axes, reference_axes): if len(array_shape) != len(array_axes): raise NgioValueError( @@ -183,7 +137,6 @@ def _compute_reshape_and_actions( array_axes: list[str], reference_axes: list[str], tolerance: int = 1, - allow_rescaling: bool = True, ) -> tuple[tuple[int, ...], list[Action]]: # Reshape array to match reference shape # And determine actions to be taken @@ -206,24 +159,20 @@ def _compute_reshape_and_actions( elif s2 < ref_shape: if (ref_shape - s2) <= tolerance: actions.append(Action.PAD) - elif allow_rescaling: - actions.append(Action.RESCALING) else: errors.append( - f"Cannot pad axis={ref_ax}:{s2}->{ref_shape} " + f"Cannot pad axis={ref_ax} from {s2} to {ref_shape} " "because shape difference is outside tolerance " - f"{tolerance}." + f"of {tolerance} pixels." ) elif s2 > ref_shape: if (s2 - ref_shape) <= tolerance: actions.append(Action.TRIM) - elif allow_rescaling: - actions.append(Action.RESCALING) else: errors.append( - f"Cannot trim axis={ref_ax}:{s2}->{ref_shape} " + f"Cannot trim axis={ref_ax} from {s2} to {ref_shape} " "because shape difference is outside tolerance " - f"{tolerance}." + f"of {tolerance} pixels." ) else: raise RuntimeError("Unreachable code reached.") @@ -233,8 +182,9 @@ def _compute_reshape_and_actions( "Cannot match shapes if the order is different." ) if errors: + error_msg = "\n - ".join(errors) raise NgioValueError( - "Array shape cannot be matched to reference shape:\n\n".join(errors) + f"Array shape cannot be matched to reference shape:\n - {error_msg}" ) return tuple(reshape_tuple), actions @@ -247,7 +197,6 @@ def numpy_match_shape( tolerance: int = 1, pad_mode: str = "constant", pad_values: int | float = 0, - allow_rescaling: bool = True, ): """Match the shape of a numpy array to a reference shape. @@ -268,9 +217,6 @@ def numpy_match_shape( pad_mode (str): The mode to use for padding. See numpy.pad for options. pad_values (int | float): The constant value to use for padding if pad_mode is 'constant'. - allow_rescaling (bool): If True, when the array differs more than the - tolerance, it will be rescaled to the reference shape. If False, - an error will be raised. """ _check_axes( array_shape=array.shape, @@ -291,10 +237,8 @@ def numpy_match_shape( array_axes=array_axes, reference_axes=reference_axes, tolerance=tolerance, - allow_rescaling=allow_rescaling, ) array = array.reshape(reshape_tuple) - array = _numpy_rescaling(array=array, actions=actions, target_shape=reference_shape) array = _numpy_pad( array=array, actions=actions, @@ -314,7 +258,6 @@ def dask_match_shape( tolerance: int = 1, pad_mode: str = "constant", pad_values: int | float = 0, - allow_rescaling: bool = True, ) -> da.Array: """Match the shape of a dask array to a reference shape. @@ -335,9 +278,6 @@ def dask_match_shape( pad_mode (str): The mode to use for padding. See numpy.pad for options. pad_values (int | float): The constant value to use for padding if pad_mode is 'constant'. - allow_rescaling (bool): If True, when the array differs more than the - tolerance, it will be rescalingd to the reference shape. If False, - an error will be raised. """ array_shape = tuple(int(s) for s in array.shape) _check_axes( @@ -358,10 +298,8 @@ def dask_match_shape( array_axes=array_axes, reference_axes=reference_axes, tolerance=tolerance, - allow_rescaling=allow_rescaling, ) array = da.reshape(array, reshape_tuple) - array = _dask_rescaling(array=array, actions=actions, target_shape=reference_shape) array = _dask_pad( array=array, actions=actions, From 7b2381469282928f95dcd2bca16a186c9401071f Mon Sep 17 00:00:00 2001 From: lorenzo Date: Fri, 6 Feb 2026 11:21:47 +0100 Subject: [PATCH 4/5] expand testing of mask loading --- tests/unit/images/test_masked_images.py | 105 +++++++++++++++++++++++- 1 file changed, 103 insertions(+), 2 deletions(-) diff --git a/tests/unit/images/test_masked_images.py b/tests/unit/images/test_masked_images.py index 6201d71c..0dba75c6 100644 --- a/tests/unit/images/test_masked_images.py +++ b/tests/unit/images/test_masked_images.py @@ -6,8 +6,14 @@ from scipy import ndimage from skimage.segmentation import watershed -from ngio import create_ome_zarr_from_array, open_ome_zarr_container +from ngio import ( + create_empty_ome_zarr, + create_ome_zarr_from_array, + open_ome_zarr_container, +) +from ngio.images._masked_image import MaskedImage from ngio.transforms import ZoomTransform +from ngio.utils._errors import NgioValueError def _draw_random_labels(shape: tuple[int, ...], num_regions: int): @@ -137,7 +143,102 @@ def test_masking( masked_new_label.set_roi(label_id, x, zoom_factor=1.1) -@pytest.mark.filterwarnings("ignore::anndata._warnings.ImplicitModificationWarning") +@pytest.mark.parametrize( + "image_scale, label_scale", + [("0", "1"), ("1", "0"), ("0", "4"), ("4", "0")], +) +def test_masking_at_different_res(image_scale: str, label_scale: str): + # Test on a real example + ome_zarr = create_empty_ome_zarr( + store={}, + shape=(2, 3, 200, 200), + pixelsize=1.0, + ) + masked_label = ome_zarr.derive_label("mask") + masked_label.set_array(np.ones(shape=masked_label.shape, dtype=masked_label.dtype)) + masked_label.consolidate() + + # Image at higher res and label at lower res, with zooming + image = ome_zarr.get_image(path=image_scale) + masking_label = ome_zarr.get_label(name="mask", path=label_scale) + masking_table = masked_label.build_masking_roi_table() + masked_image = MaskedImage( + group_handler=image._group_handler, + path=image_scale, + meta_handler=image.meta_handler, + label=masking_label, + masking_roi_table=masking_table, + ) + roi = masked_image.get_roi_masked_as_numpy(label=1) + assert roi.shape == image.shape + roi = masked_image.get_roi_masked_as_dask(label=1) + assert roi.shape == image.shape + + with pytest.raises(NgioValueError): + masked_image.get_roi_masked_as_numpy(label=1, allow_rescaling=False) + + with pytest.raises(NgioValueError): + masked_image.get_roi_masked_as_dask(label=1, allow_rescaling=False) + + +def test_masking_oneoff_handling(): + # Test on a real example + ome_zarr = create_empty_ome_zarr( + store={}, + shape=(2, 3, 200, 200), + pixelsize=1.0, + ) + + masked_label2 = ome_zarr.derive_label("mask2", shape=(1, 3, 200, 200)) + masked_label2.set_array( + np.ones(shape=masked_label2.shape, dtype=masked_label2.dtype) + ) + masked_label2.consolidate() + + masked_label = ome_zarr.derive_label("mask", shape=(1, 3, 199, 201)) + masked_label.set_array(np.ones(shape=masked_label.shape, dtype=masked_label.dtype)) + masked_label.consolidate() + + # Image at higher res and label at lower res, with zooming + image = ome_zarr.get_image(path="0") + masking_label = ome_zarr.get_label(name="mask", path="0") + # Use the second label to allow testing the one-off handling in the zoom transform + masking_table = masked_label2.build_masking_roi_table() + masked_image = MaskedImage( + group_handler=image._group_handler, + path="0", + meta_handler=image.meta_handler, + label=masking_label, + masking_roi_table=masking_table, + ) + roi = masked_image.get_roi_masked_as_numpy(label=1, allow_rescaling=False) + assert roi.shape == (2, 3, 200, 200) + roi = masked_image.get_roi_masked_as_dask(label=1, allow_rescaling=False) + assert roi.shape == (2, 3, 200, 200) + + # Fail if more than 1 pixel difference and allow_rescaling is False + masked_label = ome_zarr.derive_label("mask", shape=(1, 3, 198, 202), overwrite=True) + masked_label.set_array(np.ones(shape=masked_label.shape, dtype=masked_label.dtype)) + masked_label.consolidate() + + # Image at higher res and label at lower res, with zooming + image = ome_zarr.get_image(path="0") + masking_label = ome_zarr.get_label(name="mask", path="0") + # Use the second label to allow testing the one-off handling in the zoom transform + masking_table = masked_label2.build_masking_roi_table() + masked_image = MaskedImage( + group_handler=image._group_handler, + path="0", + meta_handler=image.meta_handler, + label=masking_label, + masking_roi_table=masking_table, + ) + with pytest.raises(NgioValueError): + roi = masked_image.get_roi_masked_as_numpy(label=1, allow_rescaling=False) + with pytest.raises(NgioValueError): + roi = masked_image.get_roi_masked_as_dask(label=1, allow_rescaling=False) + + @pytest.mark.parametrize( ("label", "c", "zoom_factor", "expected_shape"), [ From 8cbd99bf14d86bafb62ab7a2e995fc5dff9b164f Mon Sep 17 00:00:00 2001 From: lorenzo Date: Fri, 6 Feb 2026 11:29:53 +0100 Subject: [PATCH 5/5] update change log --- CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b96af26..116a9e18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,12 @@ # Changelog -## [Unreleased] +## [0.5.2] + +### Fix +- Fix critical bug in masking roi image handling causing incorrect results when image and mask have different pixel sizes. +- Fix bug in loading masking roi images when paths other than default are used. + +## [0.5.1] ### Fix - Fix bug causing incorrect channel metadata when creating an image.