CHANGELOG.md (8 changes: 7 additions & 1 deletion)
@@ -1,6 +1,12 @@
# Changelog

## [Unreleased]
## [0.5.2]

### Fix
- Fix critical bug in masking roi image handling causing incorrect results when image and mask have different pixel sizes.
- Fix bug in loading masking roi images when paths other than default are used.

## [0.5.1]

### Fix
- Fix bug causing incorrect channel metadata when creating an image.
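As background for the first fix above: when the masking label is stored at a coarser pixel size than the image, it has to be resampled onto the image grid (nearest-neighbour, so integer label ids survive) before it can act as a boolean mask; applying the coarse label directly selects the wrong physical region. A generic numpy/scipy illustration of that idea, not ngio code, which also lines up with `scipy` being added to the dependencies below:

```python
import numpy as np
from scipy.ndimage import zoom

# Hypothetical data: image at 0.5 um/px, masking label stored at 1.0 um/px.
image = np.random.rand(8, 8)
label = np.zeros((4, 4), dtype=np.uint16)
label[1:3, 1:3] = 7  # ROI with label id 7

# Resample the label onto the image grid; order=0 (nearest) keeps label ids intact.
scale = np.array(image.shape) / np.array(label.shape)
label_on_image_grid = zoom(label, zoom=scale, order=0)

mask = label_on_image_grid == 7
masked = np.where(mask, image, 0)  # values outside the ROI are filled with 0
print(mask.shape, masked.shape)    # (8, 8) (8, 8)
```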
pyproject.toml (2 changes: 2 additions & 0 deletions)
@@ -37,6 +37,8 @@ dependencies = [
"numpy",
"filelock",
"zarr>3",
"scipy",
"fsspec",
"anndata",
"pydantic",
"pandas>=1.2.0,<3.0.0",
src/ngio/images/_ome_zarr_container.py (2 changes: 1 addition & 1 deletion)
@@ -565,7 +565,7 @@ def get_masked_image(
masking_label, masking_table = self._find_matching_masking_label(
masking_label_name=masking_label_name,
masking_table_name=masking_table_name,
pixel_size=pixel_size,
pixel_size=image.pixel_size,
)
return MaskedImage(
group_handler=image._group_handler,
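The one-line change makes the masking-table lookup use the pixel size of the image that was actually resolved, rather than the raw `pixel_size` argument, which may not describe the image actually selected (for example when a non-default path is used). A purely illustrative sketch with invented names, not ngio's API:

```python
# Hypothetical sketch; function and attribute names are invented for illustration only.
def get_masked_image(container, masking_label_name, pixel_size=None, path=None):
    image = container.get_image(pixel_size=pixel_size, path=path)
    # Old behaviour: the lookup used the raw `pixel_size` argument, which can be
    # None when the caller selects the image by `path`, so the masking label
    # could be matched at the wrong resolution.
    masking_label = container.find_matching_masking_label(
        masking_label_name,
        pixel_size=image.pixel_size,  # fixed: use the resolved image's own pixel size
    )
    return image, masking_label
```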
src/ngio/io_pipes/_io_pipes_masked.py (23 changes: 4 additions & 19 deletions)
@@ -34,7 +34,6 @@ def _numpy_label_to_bool_mask(
data_shape: tuple[int, ...],
label_axes: tuple[str, ...],
data_axes: tuple[str, ...],
allow_rescaling: bool = True,
) -> np.ndarray:
"""Convert label data to a boolean mask."""
if label is not None:
@@ -47,7 +46,6 @@
reference_shape=data_shape,
array_axes=label_axes,
reference_axes=data_axes,
allow_rescaling=allow_rescaling,
)
return bool_mask

@@ -84,15 +82,14 @@ def _setup_numpy_getters(

if allow_rescaling:
_zoom_transform = BaseZoomTransform(
input_dimensions=dimensions,
target_dimensions=label_dimensions,
input_dimensions=label_dimensions,
target_dimensions=dimensions,
order="nearest",
)
if label_transforms is None or len(label_transforms) == 0:
label_transforms = [_zoom_transform]
else:
label_transforms = [_zoom_transform, *label_transforms]

label_slicing_dict = roi_to_slicing_dict(
roi=roi,
pixel_size=label_dimensions.pixel_size,
@@ -147,7 +144,6 @@ def __init__(

self._label_id = roi.label
self._fill_value = fill_value
self._allow_rescaling = allow_rescaling
super().__init__(
zarr_array=zarr_array,
slicing_ops=self._data_getter.slicing_ops,
@@ -164,14 +160,12 @@ def get(self) -> np.ndarray:
"""Get the masked data as a numpy array."""
data = self._data_getter()
label_data = self._label_data_getter()

bool_mask = _numpy_label_to_bool_mask(
label_data=label_data,
label=self.label_id,
data_shape=data.shape,
label_axes=self._label_data_getter.axes_ops.output_axes,
data_axes=self._data_getter.axes_ops.output_axes,
allow_rescaling=self._allow_rescaling,
)
if bool_mask.shape != data.shape:
bool_mask = np.broadcast_to(bool_mask, data.shape)
@@ -214,7 +208,6 @@ def __init__(
self._data_getter = _data_getter
self._label_data_getter = _label_data_getter
self._label_id = roi.label
self._allow_rescaling = allow_rescaling

self._data_setter = NumpySetter(
zarr_array=zarr_array,
@@ -246,7 +239,6 @@ def set(self, patch: np.ndarray) -> None:
data_shape=data.shape,
label_axes=self._label_data_getter.axes_ops.output_axes,
data_axes=self._data_getter.axes_ops.output_axes,
allow_rescaling=self._allow_rescaling,
)
if bool_mask.shape != data.shape:
bool_mask = np.broadcast_to(bool_mask, data.shape)
@@ -267,7 +259,6 @@ def _dask_label_to_bool_mask(
data_shape: tuple[int, ...],
label_axes: tuple[str, ...],
data_axes: tuple[str, ...],
allow_rescaling: bool = True,
) -> DaskArray:
"""Convert label data to a boolean mask."""
if label is not None:
@@ -280,7 +271,6 @@
reference_shape=data_shape,
array_axes=label_axes,
reference_axes=data_axes,
allow_rescaling=allow_rescaling,
)
return bool_mask

@@ -317,15 +307,14 @@ def _setup_dask_getters(

if allow_rescaling:
_zoom_transform = BaseZoomTransform(
input_dimensions=dimensions,
target_dimensions=label_dimensions,
input_dimensions=label_dimensions,
target_dimensions=dimensions,
order="nearest",
)
if label_transforms is None or len(label_transforms) == 0:
label_transforms = [_zoom_transform]
else:
label_transforms = [_zoom_transform, *label_transforms]

label_slicing_dict = roi_to_slicing_dict(
roi=roi,
pixel_size=label_dimensions.pixel_size,
@@ -379,7 +368,6 @@ def __init__(
self._label_data_getter = _label_data_getter
self._label_id = roi.label
self._fill_value = fill_value
self._allow_rescaling = allow_rescaling
super().__init__(
zarr_array=zarr_array,
slicing_ops=self._data_getter.slicing_ops,
@@ -402,7 +390,6 @@ def get(self) -> DaskArray:
data_shape=data_shape,
label_axes=self._label_data_getter.axes_ops.output_axes,
data_axes=self._data_getter.axes_ops.output_axes,
allow_rescaling=self._allow_rescaling,
)
if bool_mask.shape != data.shape:
bool_mask = da.broadcast_to(bool_mask, data.shape)
@@ -446,7 +433,6 @@ def __init__(
self._label_data_getter = _label_data_getter

self._label_id = roi.label
self._allow_rescaling = allow_rescaling

self._data_setter = DaskSetter(
zarr_array=zarr_array,
@@ -480,7 +466,6 @@ def set(self, patch: DaskArray) -> None:
data_shape=data_shape,
label_axes=self._label_data_getter.axes_ops.output_axes,
data_axes=self._data_getter.axes_ops.output_axes,
allow_rescaling=self._allow_rescaling,
)
if bool_mask.shape != data.shape:
bool_mask = da.broadcast_to(bool_mask, data.shape)
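Two things change in this file: the `BaseZoomTransform` arguments are swapped so the label patch (input) is zoomed onto the data's dimensions (target), and the now-unused `allow_rescaling` plumbing is dropped from the masked getters and setters. Downstream, the label patch is turned into a boolean mask, broadcast to the data shape (for example across a channel axis), and applied on read and write. A generic numpy sketch of that masking step, assuming a hypothetical (c, y, x) layout:

```python
import numpy as np

label_id, fill_value = 7, 0

data = np.random.rand(3, 8, 8)                # hypothetical (c, y, x) patch
label_patch = np.zeros((1, 8, 8), np.uint16)  # label already zoomed to the data's (y, x)
label_patch[0, 2:6, 2:6] = label_id

# Boolean mask, broadcast across the channel axis to match the data shape.
bool_mask = label_patch == label_id
if bool_mask.shape != data.shape:
    bool_mask = np.broadcast_to(bool_mask, data.shape)

masked_read = np.where(bool_mask, data, fill_value)  # read: fill outside the ROI
patch = np.ones_like(data)
masked_write = np.where(bool_mask, patch, data)      # write: only touch pixels inside the ROI
```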
src/ngio/io_pipes/_match_shape.py (78 changes: 8 additions & 70 deletions)
@@ -1,6 +1,6 @@
import logging
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum

import dask.array as da
import numpy as np
@@ -10,11 +10,10 @@
logger = logging.getLogger(f"ngio:{__name__}")


class Action(str, Enum):
class Action(StrEnum):
NONE = "none"
PAD = "pad"
TRIM = "trim"
RESCALING = "rescaling"


def _compute_pad_widths(
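A side note on the `Action(str, Enum)` to `Action(StrEnum)` change: `StrEnum` exists only on Python 3.11+, so this assumes the package already requires that version. Members still compare equal to plain strings either way; the practical difference is that `str()` and f-strings always yield the bare value with `StrEnum`, while the str-mixin form renders as the qualified member name on Python 3.11+:

```python
from enum import Enum, StrEnum  # StrEnum requires Python >= 3.11


class OldAction(str, Enum):
    NONE = "none"


class NewAction(StrEnum):
    NONE = "none"


assert OldAction.NONE == "none" and NewAction.NONE == "none"
assert f"{NewAction.NONE}" == "none"  # always the plain value
print(f"{OldAction.NONE}")            # 'OldAction.NONE' on Python 3.11+
```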
@@ -103,51 +102,6 @@ def _dask_trim(
return array[tuple(slices)]


def _compute_rescaling_shape(
array_shape: tuple[int, ...],
actions: list[Action],
target_shape: tuple[int, ...],
) -> tuple[int, ...]:
rescaling_shape = []
factor = []
for act, s, ts in zip(actions, array_shape, target_shape, strict=True):
if act == Action.RESCALING:
rescaling_shape.append(ts)
factor.append(ts / s)
else:
rescaling_shape.append(s)
factor.append(1.0)

logger.warning(
f"Images have a different shape ({array_shape} vs {target_shape}). "
f"Resolving by scaling with factors {factor}."
)
return tuple(rescaling_shape)


def _numpy_rescaling(
array: np.ndarray, actions: list[Action], target_shape: tuple[int, ...]
) -> np.ndarray:
if all(act != Action.RESCALING for act in actions):
return array
from ngio.common._zoom import numpy_zoom

rescaling_shape = _compute_rescaling_shape(array.shape, actions, target_shape)
return numpy_zoom(source_array=array, target_shape=rescaling_shape, order="nearest")


def _dask_rescaling(
array: da.Array, actions: list[Action], target_shape: tuple[int, ...]
) -> da.Array:
if all(act != Action.RESCALING for act in actions):
return array
from ngio.common._zoom import dask_zoom

shape = tuple(int(s) for s in array.shape)
rescaling_shape = _compute_rescaling_shape(shape, actions, target_shape)
return dask_zoom(source_array=array, target_shape=rescaling_shape, order="nearest")


def _check_axes(array_shape, reference_shape, array_axes, reference_axes):
if len(array_shape) != len(array_axes):
raise NgioValueError(
@@ -183,7 +137,6 @@ def _compute_reshape_and_actions(
array_axes: list[str],
reference_axes: list[str],
tolerance: int = 1,
allow_rescaling: bool = True,
) -> tuple[tuple[int, ...], list[Action]]:
# Reshape array to match reference shape
# And determine actions to be taken
@@ -206,24 +159,20 @@
elif s2 < ref_shape:
if (ref_shape - s2) <= tolerance:
actions.append(Action.PAD)
elif allow_rescaling:
actions.append(Action.RESCALING)
else:
errors.append(
f"Cannot pad axis={ref_ax}:{s2}->{ref_shape} "
f"Cannot pad axis={ref_ax} from {s2} to {ref_shape} "
"because shape difference is outside tolerance "
f"{tolerance}."
f"of {tolerance} pixels."
)
elif s2 > ref_shape:
if (s2 - ref_shape) <= tolerance:
actions.append(Action.TRIM)
elif allow_rescaling:
actions.append(Action.RESCALING)
else:
errors.append(
f"Cannot trim axis={ref_ax}:{s2}->{ref_shape} "
f"Cannot trim axis={ref_ax} from {s2} to {ref_shape} "
"because shape difference is outside tolerance "
f"{tolerance}."
f"of {tolerance} pixels."
)
else:
raise RuntimeError("Unreachable code reached.")
@@ -233,8 +182,9 @@ def _compute_reshape_and_actions(
"Cannot match shapes if the order is different."
)
if errors:
error_msg = "\n - ".join(errors)
raise NgioValueError(
"Array shape cannot be matched to reference shape:\n\n".join(errors)
f"Array shape cannot be matched to reference shape:\n - {error_msg}"
)
return tuple(reshape_tuple), actions
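The error-reporting fix just above is subtle: the old code passed the header text to `str.join()`, which treats it as a separator between error strings, so with a single error the header never appeared at all. A minimal reproduction of the two behaviours:

```python
errors = ["Cannot pad axis=y from 10 to 20 because shape difference "
          "is outside tolerance of 1 pixels."]

# Old: the header is used as a *separator*, so a lone error swallows it.
old = "Array shape cannot be matched to reference shape:\n\n".join(errors)
print(old)  # prints only the error text, no header

# New: join the errors, then prefix the header once.
error_msg = "\n - ".join(errors)
new = f"Array shape cannot be matched to reference shape:\n - {error_msg}"
print(new)  # header followed by the bulleted error
```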

@@ -247,7 +197,6 @@ def numpy_match_shape(
tolerance: int = 1,
pad_mode: str = "constant",
pad_values: int | float = 0,
allow_rescaling: bool = True,
):
"""Match the shape of a numpy array to a reference shape.

@@ -268,9 +217,6 @@
pad_mode (str): The mode to use for padding. See numpy.pad for options.
pad_values (int | float): The constant value to use for padding if
pad_mode is 'constant'.
allow_rescaling (bool): If True, when the array differs more than the
tolerance, it will be rescaled to the reference shape. If False,
an error will be raised.
"""
_check_axes(
array_shape=array.shape,
@@ -291,10 +237,8 @@
array_axes=array_axes,
reference_axes=reference_axes,
tolerance=tolerance,
allow_rescaling=allow_rescaling,
)
array = array.reshape(reshape_tuple)
array = _numpy_rescaling(array=array, actions=actions, target_shape=reference_shape)
array = _numpy_pad(
array=array,
actions=actions,
@@ -314,7 +258,6 @@ def dask_match_shape(
tolerance: int = 1,
pad_mode: str = "constant",
pad_values: int | float = 0,
allow_rescaling: bool = True,
) -> da.Array:
"""Match the shape of a dask array to a reference shape.

@@ -335,9 +278,6 @@
pad_mode (str): The mode to use for padding. See numpy.pad for options.
pad_values (int | float): The constant value to use for padding if
pad_mode is 'constant'.
allow_rescaling (bool): If True, when the array differs more than the
tolerance, it will be rescalingd to the reference shape. If False,
an error will be raised.
"""
array_shape = tuple(int(s) for s in array.shape)
_check_axes(
@@ -358,10 +298,8 @@
array_axes=array_axes,
reference_axes=reference_axes,
tolerance=tolerance,
allow_rescaling=allow_rescaling,
)
array = da.reshape(array, reshape_tuple)
array = _dask_rescaling(array=array, actions=actions, target_shape=reference_shape)
array = _dask_pad(
array=array,
actions=actions,
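With the `RESCALING` action and the `allow_rescaling` flag removed, `numpy_match_shape` and `dask_match_shape` only pad or trim shape differences within the pixel `tolerance` and raise on anything larger. A generic numpy sketch of that policy for a single axis, not the functions' actual implementation:

```python
import numpy as np


def match_last_axis(array: np.ndarray, target: int, tolerance: int = 1,
                    pad_value: float = 0) -> np.ndarray:
    """Pad or trim the last axis to `target`, but only within `tolerance` pixels."""
    current = array.shape[-1]
    diff = target - current
    if abs(diff) > tolerance:
        raise ValueError(
            f"Cannot match {current} -> {target}: difference exceeds "
            f"tolerance of {tolerance} pixels."
        )
    if diff > 0:  # pad at the end of the axis
        pad = [(0, 0)] * (array.ndim - 1) + [(0, diff)]
        return np.pad(array, pad, mode="constant", constant_values=pad_value)
    if diff < 0:  # trim from the end of the axis
        return array[..., :target]
    return array


a = np.zeros((5, 10))
print(match_last_axis(a, 11).shape)  # (5, 11), padded by one pixel
print(match_last_axis(a, 9).shape)   # (5, 9), trimmed by one pixel
```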
src/ngio/io_pipes/_zoom_transform.py (4 changes: 2 additions & 2 deletions)
@@ -28,7 +28,7 @@ def __init__(
self._order: InterpolationOrder = order

def _normalize_shape(
self, slice_: slice | int | tuple, scale: float, max_dim: int
self, slice_: slice | int | list[int], scale: float, max_dim: int
) -> int:
if isinstance(slice_, slice):
_start = slice_.start or 0
@@ -43,7 +43,7 @@

elif isinstance(slice_, int):
target_shape = 1
elif isinstance(slice_, tuple):
elif isinstance(slice_, list):
target_shape = len(slice_) * scale
else:
raise ValueError(f"Unsupported slice type: {type(slice_)}")
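The `_zoom_transform.py` change tracks how per-axis selections are represented: a list of indices (rather than a tuple) now contributes `len(selection) * scale` to the zoom target shape, an int collapses the axis to 1, and a slice contributes its length times the scale. A simplified stand-alone version of that logic; the rounding behaviour here is an assumption, not taken from the source:

```python
import math


def normalize_shape(selection, scale: float, max_dim: int) -> int:
    """Simplified sketch of deriving one axis of a zoom target shape."""
    if isinstance(selection, slice):
        start = selection.start or 0
        stop = selection.stop if selection.stop is not None else max_dim
        return max(1, math.ceil((stop - start) * scale))
    if isinstance(selection, int):
        return 1  # a single index collapses the axis
    if isinstance(selection, list):  # list of indices, not a tuple
        return max(1, math.ceil(len(selection) * scale))
    raise ValueError(f"Unsupported slice type: {type(selection)}")


print(normalize_shape(slice(0, 10), 0.5, 64))  # 5
print(normalize_shape(3, 0.5, 64))             # 1
print(normalize_shape([0, 2, 4, 6], 0.5, 64))  # 2
```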