Skip to content
93 changes: 93 additions & 0 deletions examples/dilate_segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Example script to dilate segmentation masks in image files."""

import sys
from argparse import ArgumentParser
from glob import glob
from itertools import chain
from pathlib import Path

import numpy as np

from PartSegCore.image_operations import dilate, to_binary_image
from PartSegCore.mask.io_functions import LoadROIImage, MaskProjectTuple, SaveROI, SaveROIOptions
from PartSegCore.roi_info import ROIInfo
from PartSegCore.segmentation.watershed import calculate_distances_array, get_neigh
from PartSegCore_compiled_backend.sprawl_utils.find_split import euclidean_sprawl


def convert_mask(file_path: Path, radius: float, suffix: str, only_selected: bool):
if radius <= 0:
raise ValueError("Radius must be positive")
print(f"Converting {file_path} to {suffix} with radius {radius}")

project = LoadROIImage.load([str(file_path)])

roi_ = project.roi_info.roi.squeeze()
selected_components = project.selected_components
if only_selected and selected_components is not None:
mask = np.isin(roi_, selected_components)
roi_ = roi_ * mask

unique_values = np.unique(roi_)
mapping = np.zeros(np.max(unique_values) + 1, dtype=roi_.dtype)
for new_val, old_val in enumerate(unique_values):
mapping[old_val] = new_val
roi_ = mapping[roi_]

selected_components = list(range(1, len(unique_values)))

bin_roi = to_binary_image(roi_)
sprawl_area = dilate(bin_roi, [radius, radius], True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify 2D-only dilation is intentional.

The dilation uses [radius, radius] which appears to dilate only in X and Y dimensions, not Z. For 3D volumetric images, this might not produce the expected result. If Z-axis dilation is needed, consider using radius directly or [radius, radius, radius].

Run the following script to check if this project typically uses 3D images:


🏁 Script executed:

#!/bin/bash
# Check for 3D image handling in the codebase and verify dilation usage patterns
rg -nP -C3 'dilate\(' --type=py | head -50

Length of output: 3798


🏁 Script executed:

#!/bin/bash
# Inspect dilate implementation
sed -n '90,150p' package/PartSegCore/image_operations.py

# Inspect example usage and components_num context
sed -n '1,200p' examples/dilate_segmentation_mask.py

Length of output: 4620


Ensure full 3D dilation and correct component count

  • Replace 2D‐only dilation with full 3D, e.g.:
  • sprawl_area = dilate(bin_roi, [radius, radius], True)
  • sprawl_area = dilate(bin_roi, [radius, radius, radius], layer=False)
  • Compute component count from the filtered ROI, not the original, e.g.:
  • components_num = np.max(project.roi_info.roi)
  • components_num = np.max(roi_)
    
    
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
sprawl_area = dilate(bin_roi, [radius, radius], True)
# … earlier in the function …
sprawl_area = dilate(bin_roi, [radius, radius, radius], layer=False)
components_num = np.max(roi_)
# … later in the function …

components_num = np.max(roi_)
neigh, dist = calculate_distances_array(project.image.spacing, get_neigh(True))
roi = project.image.fit_array_to_image(
euclidean_sprawl(
sprawl_area,
roi_,
components_num,
neigh,
dist,
)
)
new_file_path = file_path.with_name(file_path.stem + suffix + file_path.suffix)
print("Saving to ", new_file_path)
SaveROI.save(
str(new_file_path),
MaskProjectTuple(
file_path=str(new_file_path),
image=project.image,
roi_info=ROIInfo(roi),
spacing=project.spacing,
frame_thickness=project.frame_thickness,
selected_components=selected_components,
),
SaveROIOptions(
relative_path=True,
mask_data=True,
frame_thickness=project.frame_thickness,
spacing=project.spacing,
),
)


def main():
parser = ArgumentParser()
parser.add_argument("project_files", nargs="+", type=str)
parser.add_argument("--dilate", type=int, default=1)
parser.add_argument("--suffix", type=str, default="_dilated")
parser.add_argument("--only-selected", action="store_true")

args = parser.parse_args()

files = list(chain.from_iterable(glob(x) for x in args.project_files))
if not files:
print("No files found")
return -1

for file_path in files:
convert_mask(Path(file_path).absolute(), args.dilate, args.suffix, args.only_selected)
return 0


if __name__ == "__main__":
sys.exit(main())
32 changes: 32 additions & 0 deletions examples/extract_components_from_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from argparse import ArgumentParser
from glob import glob
from itertools import chain
from pathlib import Path

from PartSegCore.mask.io_functions import LoadROIImage, SaveComponents, SaveComponentsOptions


def cut_components(project_file: Path):
project = LoadROIImage.load([str(project_file)])
SaveComponents.save(
str(project_file.parent / (project_file.stem + "_components")),
project,
SaveComponentsOptions(
frame=0,
mask_data=True,
),
)


def main():
parser = ArgumentParser()
parser.add_argument("project_files", nargs="+", type=str)
args = parser.parse_args()
files = list(chain.from_iterable(glob(f) for f in args.project_files))
for file_path in files:
print(f"Processing {file_path}")
cut_components(Path(file_path))


if __name__ == "__main__":
main()
49 changes: 49 additions & 0 deletions examples/max_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
For collections of tiff files save a max projection of each file.
"""

from argparse import ArgumentParser
from glob import glob
from itertools import chain
from pathlib import Path

from PartSegImage import Image, ImageWriter, TiffImageReader


def max_projection(file_path: Path, suffix: str = "_max", with_mask: bool = False):
if with_mask:
mask_path = str(file_path.parent / (file_path.stem + "_mask" + file_path.suffix))
else:
mask_path = None
image = TiffImageReader.read_image(str(file_path), mask_path)
if "Z" not in image.axis_order:
raise ValueError(f"Image {file_path} does not have Z axis")
max_proj = image.get_data().max(axis=image.axis_order.index("Z"))
if with_mask:
mask_projection = image.mask.max(axis=image.array_axis_order.index("Z"))
else:
mask_projection = None
image2 = Image(
max_proj, spacing=image.spacing[1:], axes_order=image.axis_order.replace("Z", ""), mask=mask_projection
)
ImageWriter.save(image2, str(file_path.with_name(file_path.stem + suffix + file_path.suffix)))
if with_mask:
ImageWriter.save_mask(image2, str(file_path.with_name(file_path.stem + suffix + "_mask" + file_path.suffix)))


def main():
parser = ArgumentParser()
parser.add_argument("image_files", nargs="+", type=str)
parser.add_argument("--suffix", type=str, default="_max")
parser.add_argument("--with-mask", action="store_true")
args = parser.parse_args()
files = list(chain.from_iterable(glob(f) for f in args.image_files))
for file_path in files:
if args.with_mask and Path(file_path).stem.endswith("_mask"):
continue
print(f"Processing {file_path}")
max_projection(Path(file_path), args.suffix, args.with_mask)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion package/PartSegImage/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def read_image(
read image file with optional mask file

:param image_path: path or opened file contains image
:param mask_path:
:param mask_path: path or opened file contains mask
:param callback_function: function for provide information about progress in reading file (for progressbar)
:param default_spacing: used if file do not contains information about spacing
(or metadata format is not supported)
Expand Down
Loading