From a8710fe9aec9270ebfb01a32c03d31a3bd24de75 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Tue, 12 Nov 2024 12:03:35 +0100 Subject: [PATCH 01/33] add long training option --- BANIS.py | 7 +++++-- README.md | 8 +++++++- aff_train.sh | 32 ++++++++++++++++++++++++++++++-- slurm_long_job.py | 22 ++++++++++++++++++++++ 4 files changed, 64 insertions(+), 5 deletions(-) create mode 100644 slurm_long_job.py diff --git a/BANIS.py b/BANIS.py index 2813310..551c8b7 100644 --- a/BANIS.py +++ b/BANIS.py @@ -201,7 +201,7 @@ def main(): f"lr{args.learning_rate}_wd{args.weight_decay}_sch{args.scheduler}_syn_{args.synthetic}" f"_drop{args.drop_slice_prob}_shift{args.shift_slice_prob}_int{args.intensity_aug}_noise{args.noise_scale}" f"_affine{args.affine}_ns{args.n_steps}_ss{args.small_size}" - ) + ) if not args.exp_name else args.exp_name save_dir = os.path.join(args.save_path, exp_name) os.makedirs(save_dir, exist_ok=True) @@ -243,7 +243,8 @@ def main(): model=model, train_dataloaders=DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, drop_last=True), - val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers) + val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers), + ckpt_path="last" if args.resume_from_checkpoint else None ) print("Training complete") @@ -297,6 +298,8 @@ def parse_args(): parser.add_argument("--log_every_n_steps", type=int, default=100, help="Log every n steps.") parser.add_argument("--val_check_interval", type=int, default=5000, help="Validation check interval.") parser.add_argument("--small_size", type=int, default=128, help="Size of the patches.") + parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="Resume training from the last checkpoint.") + parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") return parser.parse_args() diff --git a/README.md 
b/README.md index ca6ac7e..e598122 100644 --- a/README.md +++ b/README.md @@ -30,12 +30,18 @@ python BANIS.py --seed 0 --batch_size 8 --n_steps 50000 --data_setting base --ba ``` Results are logged to TensorBoard. For GPUs with less than 48 GB memory, reduce `batch_size` (and adjust `n_steps` / `learning_rate`). For BANIS-L(arge) add `--model_id L --kernel_size 5`. Additional options are in `parse_args` of `BANIS.py`. -To run multiple jobs on Slurm, adjust `config.yaml` and `start_run.sh`, then: +To run multiple jobs on Slurm, adjust `config.yaml` and `aff_train.sh`, then: ```bash python slurm_job_scheduler.py ``` +To run training that restarts from the last checkpoint once the Slurm limit is reached, adjust `aff_train.sh`, then: + +```bash +python slurm_long_job.py --save_path /local/logging/dir/ --exp_name experiment_name [--other_arguments] +``` + ## Evaluation To evaluate a predicted segmentation (`.zarr` or `.npy`): diff --git a/aff_train.sh b/aff_train.sh index 8edc9da..65d1b78 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -6,11 +6,39 @@ #SBATCH --time=7-00 #SBATCH --cpus-per-task=32 #SBATCH --mem=500000 +#SBATCH --signal=B:USR1@300 +#SBATCH --open-mode=append mamba activate nisb # Set TMPDIR for torch compile to avoid race conditions when runs are started in parallel -tmp_dir="/tmp/banis/${SLURM_JOB_ID}/" +tmp_dir="/tmp/banis/${SLURM_JOBID}/" mkdir -p $tmp_dir export TMPDIR=$tmp_dir -srun python3 BANIS.py "$@" \ No newline at end of file + +resubmit_job() { + echo "Job is being resubmitted..." + sbatch --dependency=afterany:${SLURM_JOBID} \ + --export=ALL,RESUME=TRUE,LONG_JOB=TRUE,SAVE_DIR=${SAVE_DIR},LONG_JOB_ARGS="${LONG_JOB_ARGS}" \ + --output=${SAVE_DIR}/slurm-log.txt \ + "$0" "${@}" + exit 0 +} +trap 'resubmit_job' USR1 + +if ! 
[ -n "$LONG_JOB" ]; then + echo "Starting a normal job" + srun python3 -u BANIS.py "${@}" + exit 0 +fi + +if [ -n "$RESUME" ]; then + echo "Resuming from the last checkpoint" + echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" + srun python3 -u BANIS.py --resume_from_checkpoint True ${LONG_JOB_ARGS} & +else + echo "Starting long training from scratch." + export LONG_JOB_ARGS="${@}" + srun python3 -u BANIS.py "${@}" & +fi +wait diff --git a/slurm_long_job.py b/slurm_long_job.py new file mode 100644 index 0000000..3bf626b --- /dev/null +++ b/slurm_long_job.py @@ -0,0 +1,22 @@ +import argparse +import os + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Submit a job with custom save_dir and pass other arguments.") + parser.add_argument("--save_path", type=str, required=True, help="Path to save the model and logs") + parser.add_argument("--exp_name", type=str, required=True, help="Experiment name") + args, unknown_args = parser.parse_known_args() + + save_path = args.save_path + exp_name = args.exp_name + try: + save_dir = os.path.join(save_path, exp_name) + os.makedirs(f"{save_path}/{exp_name}", exist_ok=False) + except FileExistsError as error: + print(f"Error: Experiment already exists: {save_path}/{exp_name}") + exit(1) + + command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --save_path {save_path} --exp_name {exp_name}" + + # Execute the command + os.system(command) From ecc27ed46c05a47a4d9fb84cbca473b444bee10e Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Mon, 9 Dec 2024 11:07:14 +0100 Subject: [PATCH 02/33] validate in another process in parallel, fix all bugs --- BANIS.py | 55 ++++++++++++++++++-------------- aff_train.sh | 19 ++++++----- environment.yaml | 1 + slurm_long_job.py | 2 ++ validation_watcher.py | 74 +++++++++++++++++++++++++++++++++++++++++++ validation_watcher.sh | 44 +++++++++++++++++++++++++ 6 files changed, 164 
insertions(+), 31 deletions(-) create mode 100644 validation_watcher.py create mode 100644 validation_watcher.sh diff --git a/BANIS.py b/BANIS.py index 551c8b7..586c3d7 100644 --- a/BANIS.py +++ b/BANIS.py @@ -28,9 +28,9 @@ class BANIS(LightningModule): PyTorch Lightning module for BANIS: Baseline for Affinity-based Neuron Instance Segmentation """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, **kwargs: Any): super().__init__() - self.save_hyperparameters(*args, **kwargs) + self.save_hyperparameters() print(f"hparams: \n{self.hparams}") self.model = create_mednext_v1( @@ -102,7 +102,8 @@ def _add_image(self, tag: str, img: torch.Tensor) -> None: global_step=self.global_step) def on_validation_epoch_end(self): - self.full_cube_inference("val") + pass + #self.full_cube_inference("val") def on_train_end(self): assert self.best_nerl_so_far["val"] > 0, "No best NERL found in validation" @@ -115,7 +116,7 @@ def on_train_end(self): self.full_cube_inference("train") @torch.no_grad() - def full_cube_inference(self, mode: str): + def full_cube_inference(self, mode: str, global_step=None): """Perform full cube inference. Expensive! 
Args: @@ -137,9 +138,9 @@ def full_cube_inference(self, mode: str): aff_pred = zarr.array(aff_pred, dtype=np.float16, store=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", chunks=(3, 512, 512, 512), overwrite=True) - self._evaluate_thresholds(aff_pred, os.path.join(seed_path, "skeleton.pkl"), mode) + self._evaluate_thresholds(aff_pred, os.path.join(seed_path, "skeleton.pkl"), mode, global_step) - def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str): + def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, global_step=None): best_voi = best_voi_no_merge = 1e100 best_nerl = best_nerl_no_merge = -1 best_nerl_metrics = None @@ -156,7 +157,7 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str): for k, v in metrics.items(): if isinstance(v, (int, float)): - self.safe_add_scalar(f"{mode}_{k}_thr_{thr}", v) + self.safe_add_scalar(f"{mode}_{k}_thr_{thr}", v, global_step) if metrics["n_non0_mergers"] == 0: best_nerl_no_merge = max(best_nerl_no_merge, metrics["nerl"]) @@ -172,18 +173,18 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str): np.save(f"{self.hparams.save_dir}/pred_seg_best_nerl_{mode}.npy", pred_seg) best_voi = min(best_voi, metrics["voi_sum"]) - self.safe_add_scalar(f"{mode}_best_nerl", best_nerl) - self.safe_add_scalar(f"{mode}_best_voi", best_voi) - self.safe_add_scalar(f"{mode}_best_nerl_no_merge", best_nerl_no_merge) - self.safe_add_scalar(f"{mode}_best_voi_no_merge", best_voi_no_merge) + self.safe_add_scalar(f"{mode}_best_nerl", best_nerl, global_step) + self.safe_add_scalar(f"{mode}_best_voi", best_voi, global_step) + self.safe_add_scalar(f"{mode}_best_nerl_no_merge", best_nerl_no_merge, global_step) + self.safe_add_scalar(f"{mode}_best_voi_no_merge", best_voi_no_merge, global_step) for k, v in best_nerl_metrics.items(): if isinstance(v, (int, float)): self.safe_add_scalar(f"{mode}_best_nerl_{k}", v) - def safe_add_scalar(self, name: str, 
value: float) -> None: + def safe_add_scalar(self, name: str, value: float, global_step=None) -> None: try: # s.t. full_cube_inference can be called outside of .fit() without error - self.logger.experiment.add_scalar(name, value, self.global_step) + self.logger.experiment.add_scalar(name, value, self.global_step if global_step is None else global_step) except Exception as e: print(f"Error logging {name}: {e}") @@ -206,18 +207,25 @@ def main(): save_dir = os.path.join(args.save_path, exp_name) os.makedirs(save_dir, exist_ok=True) print(f"save dir: {save_dir}") - tb_logger = TensorBoardLogger(save_dir=save_dir) + tb_logger = TensorBoardLogger( + save_dir=args.save_path, + name=exp_name, + version="default", + ) tb_logger.experiment.add_text("save dir", save_dir) + model_checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + save_last=True, + mode="min", + save_top_k=100, + verbose=True, + save_on_train_epoch_end=False # automatically runs at the end of the validation + ) trainer = pl.Trainer( callbacks=[ DeviceStatsMonitor(), - ModelCheckpoint( - monitor="val_loss", - save_last=True, - mode="min", - save_top_k=100, - ), + model_checkpoint_callback ], logger=tb_logger, max_steps=args.n_steps, @@ -232,19 +240,20 @@ def main(): check_val_every_n_epoch=None, num_sanity_val_steps=args.n_debug_steps, ) + print(f"Checkpoints will be saved in: {trainer.default_root_dir}/checkpoints") train_data, val_data, n_channels = load_data(args) args.save_dir = save_dir args.num_input_channels = n_channels - model = BANIS(args) + model = BANIS(**vars(args)) trainer.fit( model=model, train_dataloaders=DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, drop_last=True), val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers), - ckpt_path="last" if args.resume_from_checkpoint else None + ckpt_path="last" if args.resume_from_last_checkpoint else None ) print("Training complete") @@ -298,7 +307,7 @@ def 
parse_args(): parser.add_argument("--log_every_n_steps", type=int, default=100, help="Log every n steps.") parser.add_argument("--val_check_interval", type=int, default=5000, help="Validation check interval.") parser.add_argument("--small_size", type=int, default=128, help="Size of the patches.") - parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="Resume training from the last checkpoint.") + parser.add_argument("--resume_from_last_checkpoint", action=argparse.BooleanOptionalAction, default=False, help="Resume training from the last checkpoint.") parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") return parser.parse_args() diff --git a/aff_train.sh b/aff_train.sh index 65d1b78..9e31cd3 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -18,27 +18,30 @@ export TMPDIR=$tmp_dir resubmit_job() { echo "Job is being resubmitted..." - sbatch --dependency=afterany:${SLURM_JOBID} \ - --export=ALL,RESUME=TRUE,LONG_JOB=TRUE,SAVE_DIR=${SAVE_DIR},LONG_JOB_ARGS="${LONG_JOB_ARGS}" \ - --output=${SAVE_DIR}/slurm-log.txt \ - "$0" "${@}" - exit 0 + EXP_NAME=$(echo "${LONG_JOB_ARGS}" | grep -oP '(?<=--exp_name )\S+') + sbatch --dependency=afterany:${SLURM_JOBID} \ + --export=ALL,RESUME=TRUE,LONG_JOB=TRUE,SAVE_DIR=${SAVE_DIR},LONG_JOB_ARGS="${LONG_JOB_ARGS}" \ + --output=${SAVE_DIR}/slurm-log.txt \ + --job-name ${EXP_NAME} \ + "$0" "${@}" + exit 0 } trap 'resubmit_job' USR1 if ! 
[ -n "$LONG_JOB" ]; then echo "Starting a normal job" - srun python3 -u BANIS.py "${@}" + srun mamba run -n nisb --no-capture-output python3 -u BANIS.py "${@}" exit 0 fi if [ -n "$RESUME" ]; then echo "Resuming from the last checkpoint" echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" - srun python3 -u BANIS.py --resume_from_checkpoint True ${LONG_JOB_ARGS} & + srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u BANIS.py --resume_from_last_checkpoint ${LONG_JOB_ARGS} & else echo "Starting long training from scratch." export LONG_JOB_ARGS="${@}" - srun python3 -u BANIS.py "${@}" & + echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" + srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u BANIS.py ${LONG_JOB_ARGS} & fi wait diff --git a/environment.yaml b/environment.yaml index c1289de..520faa1 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,6 +1,7 @@ name: nisb channels: - conda-forge + - nodefaults dependencies: - _libgcc_mutex=0.1 - _openmp_mutex=4.5 diff --git a/slurm_long_job.py b/slurm_long_job.py index 3bf626b..1b1e245 100644 --- a/slurm_long_job.py +++ b/slurm_long_job.py @@ -17,6 +17,8 @@ exit(1) command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --save_path {save_path} --exp_name {exp_name}" + command_validation = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name}_val --output {save_dir}/slurm-validation-log.txt validation_watcher.sh {' '.join(unknown_args)} --save_path {save_path} --exp_name {exp_name}" # Execute the command os.system(command) + os.system(command_validation) diff --git a/validation_watcher.py b/validation_watcher.py new file mode 100644 index 0000000..c26018e --- /dev/null +++ b/validation_watcher.py @@ -0,0 +1,74 @@ +import time +import os +from pathlib import Path +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning import 
seed_everything +import argparse +import torch +from datetime import datetime + +from BANIS import BANIS + + +def main(): + args = parse_args() + seed_everything(args.seed, workers=True) + + ckpt_path = os.path.join(args.save_path, args.exp_name, "default", "checkpoints", "last.ckpt") + #ckpt_path = os.path.join(args.save_path, args.exp_name, "lightning_logs", "version_0", "checkpoints", "last.ckpt") + print(f"ckpt path: {ckpt_path}") + last_mod_time = 0 if args.start_from_earlier_ckpt else time.time() + + tb_logger = TensorBoardLogger( + save_dir=args.save_path, + name=args.exp_name, + version="default", + ) + + while True: + if os.path.exists(ckpt_path): + current_mod_time = os.path.getmtime(ckpt_path) + if current_mod_time > last_mod_time: + print(f"New checkpoint detected at {datetime.fromtimestamp(current_mod_time)}") + last_mod_time = current_mod_time + + model = BANIS.load_from_checkpoint(ckpt_path) + model = model.to("cuda") + trainer = pl.Trainer(logger=tb_logger, accelerator="gpu", devices=-1) + model.trainer = trainer # for the logger + + # global step of model loaded from checkpoint is 0 by default, until trainer is started + # see https://github.com/Lightning-AI/pytorch-lightning/issues/12819 + checkpoint = torch.load(ckpt_path) + model.full_cube_inference("val", checkpoint["global_step"]) + + elif time.time() > last_mod_time + 20*60*60: + print("No new checkpoint detected for 24 hours, terminating.") + break + + elif time.time() > last_mod_time + 20 * 60 * 60: + print("No checkpoint detected, terminating.") + break + + time.sleep(60) + + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility.") + parser.add_argument("--base_data_path", type=str, default="/cajal/nvmescratch/projects/NISB/", help="Base path for the dataset.") + parser.add_argument("--data_setting", type=str, default="base", help="Data setting identifier.") + parser.add_argument("--eval_ranges", 
type=float, nargs="+", default=torch.sigmoid(torch.tensor(list(range(-1, 12))).double() * 0.2).numpy().round(4).tolist(), help="List of evaluation thresholds.") + parser.add_argument("--save_path", type=str, default="/cajal/scratch/projects/misc/riegerfr/aff_nis/", help="Path to save the model and logs.") + parser.add_argument("--small_size", type=int, default=128, help="Size of the patches.") + parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") + parser.add_argument("--start_from_earlier_ckpt", type=argparse.BooleanOptionalAction, default=False, help="Set True if the first checkpoint was made before starting this script.") + + args, _ = parser.parse_known_args() + return args + + +if __name__ == "__main__": + main() diff --git a/validation_watcher.sh b/validation_watcher.sh new file mode 100644 index 0000000..9cf0c1e --- /dev/null +++ b/validation_watcher.sh @@ -0,0 +1,44 @@ +#!/bin/bash -l + +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=0-5 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=500000 +#SBATCH --signal=B:USR1@300 +#SBATCH --open-mode=append + +# Set TMPDIR for torch compile to avoid race conditions when runs are started in parallel +tmp_dir="/tmp/banis/${SLURM_JOBID}/" +mkdir -p $tmp_dir +export TMPDIR=$tmp_dir + +resubmit_job() { + echo "Job is being resubmitted..." + EXP_NAME=$(echo "${LONG_JOB_ARGS}" | grep -oP '(?<=--exp_name )\S+') + sbatch --dependency=afterany:${SLURM_JOBID} \ + --export=ALL,RESUME=TRUE,LONG_JOB=TRUE,SAVE_DIR=${SAVE_DIR},LONG_JOB_ARGS="${LONG_JOB_ARGS}" \ + --output=${SAVE_DIR}/slurm-validation-log.txt \ + --job-name "${EXP_NAME}_val" \ + "$0" "${@}" + exit 0 +} +trap 'resubmit_job' USR1 + +if ! [ -n "$LONG_JOB" ]; then + echo "Validation watcher only needed for a long job." 
+ exit 1 +fi + +if [ -n "$RESUME" ]; then + echo "Resuming validation from the last checkpoint" + echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" + srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${LONG_JOB_ARGS} & +else + echo "Starting validation from scratch." + export LONG_JOB_ARGS="${@}" + echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" + srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${LONG_JOB_ARGS} & +fi +wait From 421abb153345f1e4a28420e6c564b2a0f3cce491 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Mon, 9 Dec 2024 12:07:06 +0100 Subject: [PATCH 03/33] forgotten global_step --- BANIS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/BANIS.py b/BANIS.py index 586c3d7..ef5296a 100644 --- a/BANIS.py +++ b/BANIS.py @@ -180,7 +180,7 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, for k, v in best_nerl_metrics.items(): if isinstance(v, (int, float)): - self.safe_add_scalar(f"{mode}_best_nerl_{k}", v) + self.safe_add_scalar(f"{mode}_best_nerl_{k}", v, global_step) def safe_add_scalar(self, name: str, value: float, global_step=None) -> None: try: # s.t. 
full_cube_inference can be called outside of .fit() without error From 10b6eabaffb1db5c467f9c8cd3070c33b09a0bf4 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 11 Dec 2024 13:44:43 +0100 Subject: [PATCH 04/33] fix time for validation --- validation_watcher.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validation_watcher.sh b/validation_watcher.sh index 9cf0c1e..0d1aff0 100644 --- a/validation_watcher.sh +++ b/validation_watcher.sh @@ -3,7 +3,7 @@ #SBATCH --nodes=1 #SBATCH --gres=gpu:1 #SBATCH --ntasks-per-node=1 -#SBATCH --time=0-5 +#SBATCH --time=7-0 #SBATCH --cpus-per-task=32 #SBATCH --mem=500000 #SBATCH --signal=B:USR1@300 From 99f1cd15e05a4bd5e69b131affd89536ad65016f Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 11 Dec 2024 14:02:01 +0100 Subject: [PATCH 05/33] separate validation processes for each checkpoint --- BANIS.py | 9 +++++++-- slurm_long_job.py | 4 +--- validation_watcher.py | 44 +++++++++---------------------------------- validation_watcher.sh | 34 +-------------------------------- 4 files changed, 18 insertions(+), 73 deletions(-) diff --git a/BANIS.py b/BANIS.py index ef5296a..8d0456f 100644 --- a/BANIS.py +++ b/BANIS.py @@ -102,8 +102,12 @@ def _add_image(self, tag: str, img: torch.Tensor) -> None: global_step=self.global_step) def on_validation_epoch_end(self): - pass - #self.full_cube_inference("val") + if self.args.long_training: + args = ' '.join([f"--{key} {value}" for key, value in vars(self.hparams).items()]) + command = f"sbatch --job-name {self.hparams.exp_name}_val --output {self.hparams.save_dir}/slurm-validation-log.txt validation_watcher.sh {args}" + os.system(command) + else: + self.full_cube_inference("val") def on_train_end(self): assert self.best_nerl_so_far["val"] > 0, "No best NERL found in validation" @@ -309,6 +313,7 @@ def parse_args(): parser.add_argument("--small_size", type=int, default=128, help="Size of the patches.") parser.add_argument("--resume_from_last_checkpoint", 
action=argparse.BooleanOptionalAction, default=False, help="Resume training from the last checkpoint.") parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") + parser.add_argument("--long_training", action=argparse.BooleanOptionalAction, default=False, help="Long training with a separate validation process.") return parser.parse_args() diff --git a/slurm_long_job.py b/slurm_long_job.py index 1b1e245..75eabd2 100644 --- a/slurm_long_job.py +++ b/slurm_long_job.py @@ -16,9 +16,7 @@ print(f"Error: Experiment already exists: {save_path}/{exp_name}") exit(1) - command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --save_path {save_path} --exp_name {exp_name}" - command_validation = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name}_val --output {save_dir}/slurm-validation-log.txt validation_watcher.sh {' '.join(unknown_args)} --save_path {save_path} --exp_name {exp_name}" + command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}" # Execute the command os.system(command) - os.system(command_validation) diff --git a/validation_watcher.py b/validation_watcher.py index c26018e..49c220a 100644 --- a/validation_watcher.py +++ b/validation_watcher.py @@ -16,42 +16,17 @@ def main(): seed_everything(args.seed, workers=True) ckpt_path = os.path.join(args.save_path, args.exp_name, "default", "checkpoints", "last.ckpt") - #ckpt_path = os.path.join(args.save_path, args.exp_name, "lightning_logs", "version_0", "checkpoints", "last.ckpt") print(f"ckpt path: {ckpt_path}") - last_mod_time = 0 if args.start_from_earlier_ckpt else time.time() + if os.path.exists(ckpt_path): + model = BANIS.load_from_checkpoint(ckpt_path) + model 
= model.to("cuda") + trainer = pl.Trainer(logger=tb_logger, accelerator="gpu", devices=-1) + model.trainer = trainer # for the logger - tb_logger = TensorBoardLogger( - save_dir=args.save_path, - name=args.exp_name, - version="default", - ) - - while True: - if os.path.exists(ckpt_path): - current_mod_time = os.path.getmtime(ckpt_path) - if current_mod_time > last_mod_time: - print(f"New checkpoint detected at {datetime.fromtimestamp(current_mod_time)}") - last_mod_time = current_mod_time - - model = BANIS.load_from_checkpoint(ckpt_path) - model = model.to("cuda") - trainer = pl.Trainer(logger=tb_logger, accelerator="gpu", devices=-1) - model.trainer = trainer # for the logger - - # global step of model loaded from checkpoint is 0 by default, until trainer is started - # see https://github.com/Lightning-AI/pytorch-lightning/issues/12819 - checkpoint = torch.load(ckpt_path) - model.full_cube_inference("val", checkpoint["global_step"]) - - elif time.time() > last_mod_time + 20*60*60: - print("No new checkpoint detected for 24 hours, terminating.") - break - - elif time.time() > last_mod_time + 20 * 60 * 60: - print("No checkpoint detected, terminating.") - break - - time.sleep(60) + # global step of model loaded from checkpoint is 0 by default, until trainer is started + # see https://github.com/Lightning-AI/pytorch-lightning/issues/12819 + checkpoint = torch.load(ckpt_path) + model.full_cube_inference("val", checkpoint["global_step"]) @@ -64,7 +39,6 @@ def parse_args(): parser.add_argument("--save_path", type=str, default="/cajal/scratch/projects/misc/riegerfr/aff_nis/", help="Path to save the model and logs.") parser.add_argument("--small_size", type=int, default=128, help="Size of the patches.") parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") - parser.add_argument("--start_from_earlier_ckpt", type=argparse.BooleanOptionalAction, default=False, help="Set True if the first checkpoint was made 
before starting this script.") args, _ = parser.parse_known_args() return args diff --git a/validation_watcher.sh b/validation_watcher.sh index 0d1aff0..b50ece7 100644 --- a/validation_watcher.sh +++ b/validation_watcher.sh @@ -9,36 +9,4 @@ #SBATCH --signal=B:USR1@300 #SBATCH --open-mode=append -# Set TMPDIR for torch compile to avoid race conditions when runs are started in parallel -tmp_dir="/tmp/banis/${SLURM_JOBID}/" -mkdir -p $tmp_dir -export TMPDIR=$tmp_dir - -resubmit_job() { - echo "Job is being resubmitted..." - EXP_NAME=$(echo "${LONG_JOB_ARGS}" | grep -oP '(?<=--exp_name )\S+') - sbatch --dependency=afterany:${SLURM_JOBID} \ - --export=ALL,RESUME=TRUE,LONG_JOB=TRUE,SAVE_DIR=${SAVE_DIR},LONG_JOB_ARGS="${LONG_JOB_ARGS}" \ - --output=${SAVE_DIR}/slurm-validation-log.txt \ - --job-name "${EXP_NAME}_val" \ - "$0" "${@}" - exit 0 -} -trap 'resubmit_job' USR1 - -if ! [ -n "$LONG_JOB" ]; then - echo "Validation watcher only needed for a long job." - exit 1 -fi - -if [ -n "$RESUME" ]; then - echo "Resuming validation from the last checkpoint" - echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" - srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${LONG_JOB_ARGS} & -else - echo "Starting validation from scratch." 
- export LONG_JOB_ARGS="${@}" - echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" - srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${LONG_JOB_ARGS} & -fi -wait +srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${@} & \ No newline at end of file From a24c2c1450c2d40ee5abe85581cba94fa965f7b3 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 18 Dec 2024 12:11:53 +0100 Subject: [PATCH 06/33] fix validation run --- BANIS.py | 20 +++++++++++++++++--- validation_watcher.py | 6 ++++++ validation_watcher.sh | 3 ++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/BANIS.py b/BANIS.py index 8d0456f..d0fc2ad 100644 --- a/BANIS.py +++ b/BANIS.py @@ -102,10 +102,24 @@ def _add_image(self, tag: str, img: torch.Tensor) -> None: global_step=self.global_step) def on_validation_epoch_end(self): - if self.args.long_training: - args = ' '.join([f"--{key} {value}" for key, value in vars(self.hparams).items()]) + if self.hparams.long_training: + def format_value(value): + if isinstance(value, bool): + return str(value).lower() # Convert booleans to lowercase strings (true/false) + elif isinstance(value, list): + return ' '.join(map(str, value)) # Convert list to a space-separated string + elif value is None: + return '' # Skip None values + else: + return str(value) # Convert other types to string + + args_list = [f"--{key} {format_value(value)}" for key, value in self.hparams.items()] + args = ' '.join(args_list) + command = f"sbatch --job-name {self.hparams.exp_name}_val --output {self.hparams.save_dir}/slurm-validation-log.txt validation_watcher.sh {args}" os.system(command) + print(f"running validation: {command}") + else: self.full_cube_inference("val") @@ -130,7 +144,7 @@ def full_cube_inference(self, mode: str, global_step=None): print(f"Full cube inference for {mode}") base_path_mode = os.path.join(self.hparams.base_data_path, self.hparams.data_setting, mode) - seeds_path_mode = sorted([f 
for f in os.listdir(base_path_mode) if "seed" in f]) + seeds_path_mode = sorted([f for f in os.listdir(base_path_mode) if f.startswith("seed")]) assert len(seeds_path_mode) >= 1, f"No seeds found in {base_path_mode}" seed_path = os.path.join(base_path_mode, seeds_path_mode[0]) diff --git a/validation_watcher.py b/validation_watcher.py index 49c220a..2367843 100644 --- a/validation_watcher.py +++ b/validation_watcher.py @@ -20,6 +20,11 @@ def main(): if os.path.exists(ckpt_path): model = BANIS.load_from_checkpoint(ckpt_path) model = model.to("cuda") + tb_logger = TensorBoardLogger( + save_dir=args.save_path, + name=args.exp_name, + version="default", + ) trainer = pl.Trainer(logger=tb_logger, accelerator="gpu", devices=-1) model.trainer = trainer # for the logger @@ -41,6 +46,7 @@ def parse_args(): parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") args, _ = parser.parse_known_args() + print(f"args: {args}") return args diff --git a/validation_watcher.sh b/validation_watcher.sh index b50ece7..342963f 100644 --- a/validation_watcher.sh +++ b/validation_watcher.sh @@ -9,4 +9,5 @@ #SBATCH --signal=B:USR1@300 #SBATCH --open-mode=append -srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${@} & \ No newline at end of file +sleep 1m # waits for the checkpoint creation +srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${@} From b1d78c0e851dc6248ecf7c584c70df2d986db444 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 18 Dec 2024 15:23:39 +0100 Subject: [PATCH 07/33] problems with child job --- validation_watcher.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validation_watcher.sh b/validation_watcher.sh index 342963f..d93a4d1 100644 --- a/validation_watcher.sh +++ b/validation_watcher.sh @@ -10,4 +10,4 @@ #SBATCH --open-mode=append sleep 1m # waits for the checkpoint creation -srun --gres=gpu:1 
mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${@} +mamba run -n nisb --no-capture-output python3 -u validation_watcher.py ${@} From 18ba7255709e27d8af2b63ccb5299a1565c6a785 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 5 Feb 2025 13:39:10 +0100 Subject: [PATCH 08/33] debug OOM, better resubmit --- BANIS.py | 4 +++- aff_train.sh | 1 + slurm_long_job.py | 20 +++++++++++++------- validation_watcher.py | 1 + 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/BANIS.py b/BANIS.py index d0fc2ad..b3effd2 100644 --- a/BANIS.py +++ b/BANIS.py @@ -67,7 +67,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) def training_step(self, data: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - return self._step(data, "train") + result = self._step(data, "train") + torch.cuda.empty_cache() # sometimes OOM error without this + return result def validation_step(self, data: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: return self._step(data, "val") diff --git a/aff_train.sh b/aff_train.sh index 9e31cd3..6f67aa9 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -37,6 +37,7 @@ fi if [ -n "$RESUME" ]; then echo "Resuming from the last checkpoint" echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" + echo "ALL ARGS: ${@}" srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u BANIS.py --resume_from_last_checkpoint ${LONG_JOB_ARGS} & else echo "Starting long training from scratch." 
diff --git a/slurm_long_job.py b/slurm_long_job.py index 75eabd2..6765162 100644 --- a/slurm_long_job.py +++ b/slurm_long_job.py @@ -5,18 +5,24 @@ parser = argparse.ArgumentParser(description="Submit a job with custom save_dir and pass other arguments.") parser.add_argument("--save_path", type=str, required=True, help="Path to save the model and logs") parser.add_argument("--exp_name", type=str, required=True, help="Experiment name") + parser.add_argument("--resubmit", action=argparse.BooleanOptionalAction, default=False, help="Continue already existing training") args, unknown_args = parser.parse_known_args() save_path = args.save_path exp_name = args.exp_name - try: - save_dir = os.path.join(save_path, exp_name) - os.makedirs(f"{save_path}/{exp_name}", exist_ok=False) - except FileExistsError as error: - print(f"Error: Experiment already exists: {save_path}/{exp_name}") - exit(1) + save_dir = os.path.join(save_path, exp_name) + if not args.resubmit: + try: + os.makedirs(f"{save_path}/{exp_name}", exist_ok=False) + except FileExistsError as error: + print(f"Error: Experiment already exists: {save_path}/{exp_name}") + exit(1) - command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}" + command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}" + + else: + command = f"sbatch --export=ALL,RESUME=TRUE,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,LONG_JOB_ARGS='{' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}' --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --long_training 
--save_path {save_path} --exp_name {exp_name}" # Execute the command + print(command) os.system(command) diff --git a/validation_watcher.py b/validation_watcher.py index 2367843..669b2c4 100644 --- a/validation_watcher.py +++ b/validation_watcher.py @@ -31,6 +31,7 @@ def main(): # global step of model loaded from checkpoint is 0 by default, until trainer is started # see https://github.com/Lightning-AI/pytorch-lightning/issues/12819 checkpoint = torch.load(ckpt_path) + print(f"global step: {checkpoint['global_step']}") model.full_cube_inference("val", checkpoint["global_step"]) From 7f51dc470a964fa8213ae62488e867806bfe9c45 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Fri, 23 May 2025 17:03:03 +0200 Subject: [PATCH 09/33] sort BANIS arguments --- BANIS.py | 132 ++++++++++++++++++++++++---------------------- inference.py | 4 +- slurm_long_job.py | 4 +- 3 files changed, 75 insertions(+), 65 deletions(-) diff --git a/BANIS.py b/BANIS.py index b3effd2..c793576 100644 --- a/BANIS.py +++ b/BANIS.py @@ -1,4 +1,5 @@ import argparse +import gc import os from collections import defaultdict from datetime import datetime @@ -11,7 +12,7 @@ import zarr from nnunet_mednext import create_mednext_v1 from pytorch_lightning import LightningModule, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint, DeviceStatsMonitor +from pytorch_lightning.callbacks import ModelCheckpoint, DeviceStatsMonitor, LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger from torch.nn.functional import binary_cross_entropy_with_logits from torch.optim import AdamW @@ -47,6 +48,14 @@ def __init__(self, **kwargs: Any): self.best_nerl_so_far = defaultdict(float) # for train/val/test self.best_thr_so_far = defaultdict(float) + def on_save_checkpoint(self, checkpoint): + checkpoint["best_thr_so_far"] = self.best_thr_so_far + checkpoint["best_nerl_so_far"] = self.best_nerl_so_far + + def on_load_checkpoint(self, checkpoint): + self.best_thr_so_far = 
checkpoint.get("best_thr_so_far", defaultdict(float)) + self.best_nerl_so_far = checkpoint.get("best_nerl_so_far", defaultdict(float)) + def on_fit_start(self): self.logger.experiment.add_text("hparams", str(self.hparams)) @@ -67,9 +76,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) def training_step(self, data: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - result = self._step(data, "train") - torch.cuda.empty_cache() # sometimes OOM error without this - return result + return self._step(data, "train") def validation_step(self, data: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: return self._step(data, "val") @@ -104,23 +111,24 @@ def _add_image(self, tag: str, img: torch.Tensor) -> None: global_step=self.global_step) def on_validation_epoch_end(self): - if self.hparams.long_training: - def format_value(value): - if isinstance(value, bool): - return str(value).lower() # Convert booleans to lowercase strings (true/false) - elif isinstance(value, list): - return ' '.join(map(str, value)) # Convert list to a space-separated string - elif value is None: - return '' # Skip None values - else: - return str(value) # Convert other types to string - - args_list = [f"--{key} {format_value(value)}" for key, value in self.hparams.items()] - args = ' '.join(args_list) - - command = f"sbatch --job-name {self.hparams.exp_name}_val --output {self.hparams.save_dir}/slurm-validation-log.txt validation_watcher.sh {args}" - os.system(command) - print(f"running validation: {command}") + if self.hparams.validate_extern: + if self.trainer.is_global_zero: + def format_value(value): + if isinstance(value, bool): + return str(value).lower() # Convert booleans to lowercase strings (true/false) + elif isinstance(value, list): + return ' '.join(map(str, value)) # Convert list to a space-separated string + elif value is None: + return '' # Skip None values + else: + return str(value) # Convert other types to string + + args_list = 
[f"--{key} {format_value(value)}" for key, value in self.hparams.items()] + args = ' '.join(args_list) + + command = f"sbatch --job-name {self.hparams.exp_name}_val --output {self.hparams.save_dir}/slurm-validation-log.txt validation_watcher.sh {args}" + os.system(command) + print(f"running validation: {command}") else: self.full_cube_inference("val") @@ -130,8 +138,6 @@ def on_train_end(self): self.eval() print(f"device {next(self.parameters()).device}") self.cuda() - # self.full_cube_inference("val") - assert self.best_nerl_so_far["val"] > 0, "No best NERL found in validation" self.full_cube_inference("test") self.full_cube_inference("train") @@ -167,6 +173,8 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, thresholds = self.hparams.eval_ranges if mode != "test" else [self.best_thr_so_far["val"]] for thr in tqdm(thresholds): + gc.collect() + torch.cuda.empty_cache() print(f"threshold {thr}") pred_seg = compute_connected_component_segmentation( @@ -189,8 +197,11 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, if self.best_nerl_so_far[mode] < best_nerl: self.best_nerl_so_far[mode] = best_nerl self.best_thr_so_far[mode] = thr - np.save(f"{self.hparams.save_dir}/pred_aff_best_nerl_{mode}.npy", aff_pred) - np.save(f"{self.hparams.save_dir}/pred_seg_best_nerl_{mode}.npy", pred_seg) + with open(f"{self.hparams.save_dir}/best_thr_{mode}.txt", "w") as f: + f.write(str(self.best_thr_so_far[mode])) + seg_pred = zarr.array(pred_seg, dtype=np.uint32, + store=f"{self.hparams.save_dir}/pred_seg_{mode}.zarr", + chunks=(512, 512, 512), overwrite=True) best_voi = min(best_voi, metrics["voi_sum"]) self.safe_add_scalar(f"{mode}_best_nerl", best_nerl, global_step) @@ -245,7 +256,10 @@ def main(): trainer = pl.Trainer( callbacks=[ DeviceStatsMonitor(), - model_checkpoint_callback + model_checkpoint_callback, + LearningRateMonitor( + logging_interval='step' + ), ], logger=tb_logger, max_steps=args.n_steps, @@ -281,55 
+295,49 @@ def main(): def parse_args(): parser = argparse.ArgumentParser() + + # General arguments + parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") + parser.add_argument("--save_path", type=str, default="/cajal/scratch/projects/misc/riegerfr/aff_nis/", help="Path to save the model and logs.") + + # Training arguments parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility.") - parser.add_argument("--long_range", type=int, default=10, help="Long range affinities (voxels).") parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training.") - parser.add_argument("--model_id", type=str, default="S", help="Identifier for the mednext model architecture.") - parser.add_argument("--kernel_size", type=int, default=3, help="Kernel size for the convolutional layers.") + parser.add_argument("--n_steps", type=int, default=20_000, help="Number of training steps.") parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimizer.") parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay for the optimizer.") - parser.add_argument("--drop_slice_prob", type=float, default=0.05, - help="Probability of dropping a slice during augmentation.") - parser.add_argument("--shift_slice_prob", type=float, default=0.05, - help="Probability of shifting a slice during augmentation.") - parser.add_argument("--intensity_aug", action=argparse.BooleanOptionalAction, default=True, - help="Whether to apply intensity augmentation.") - parser.add_argument("--noise_scale", type=float, default=0.5, - help="Scale of the noise to be added during augmentation.") - parser.add_argument("--affine", type=float, default=0.5, help="Affine transformation probability.") - parser.add_argument("--n_steps", type=int, default=20_000, help="Number of training steps.") parser.add_argument("--workers", type=int, default=8, 
help="Number of workers for data loading.") - parser.add_argument("--base_data_path", type=str, - default="/cajal/nvmescratch/projects/NISB/", - help="Base path for the dataset.") + parser.add_argument("--scheduler", action=argparse.BooleanOptionalAction, default=True, help="Whether to use a learning rate scheduler.") + parser.add_argument("--devices", type=int, default=-1, help="Number GPU devices to use (-1: all).") + parser.add_argument("--n_debug_steps", type=int, default=0, help="Number of debug steps.") + parser.add_argument("--log_every_n_steps", type=int, default=100, help="Log every n steps.") + parser.add_argument("--val_check_interval", type=int, default=5000, help="Validation check interval.") + parser.add_argument("--resume_from_last_checkpoint", action=argparse.BooleanOptionalAction, default=False, help="Resume training from the last checkpoint.") + parser.add_argument("--validate_extern", action=argparse.BooleanOptionalAction, default=False, help="Long training with a separate validation process.") + + # Data arguments + parser.add_argument("--base_data_path", type=str, default="/cajal/nvmescratch/projects/NISB/", help="Base path for the dataset.") parser.add_argument("--data_setting", type=str, default="base", help="Data setting identifier.") - parser.add_argument("--scheduler", action=argparse.BooleanOptionalAction, default=True, - help="Whether to use a learning rate scheduler.") + parser.add_argument("--real_data_path", type=str, default="/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/", help="Path to the real dataset. 
See https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb ") parser.add_argument("--synthetic", type=float, default=1.0, help="Ratio of synthetic data to real data.") - parser.add_argument("--compile", action=argparse.BooleanOptionalAction, default=True, - help="Whether to compile the model using torch.compile.") - parser.add_argument("--eval_ranges", type=float, nargs="+", - default=torch.sigmoid(torch.tensor(list(range(-1, 12))).double() * 0.2).numpy().round( - 4).tolist(), - help="List of evaluation thresholds.") - parser.add_argument("--save_path", type=str, default="/cajal/scratch/projects/misc/riegerfr/aff_nis/", - help="Path to save the model and logs.") - parser.add_argument("--real_data_path", type=str, - default="/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/", - help="Path to the real dataset. See https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb ") + parser.add_argument("--drop_slice_prob", type=float, default=0.05, help="Probability of dropping a slice during augmentation.") + parser.add_argument("--shift_slice_prob", type=float, default=0.05, help="Probability of shifting a slice during augmentation.") + parser.add_argument("--intensity_aug", action=argparse.BooleanOptionalAction, default=True, help="Whether to apply intensity augmentation.") + parser.add_argument("--noise_scale", type=float, default=0.5, help="Scale of the noise to be added during augmentation.") + parser.add_argument("--affine", type=float, default=0.5, help="Affine transformation probability.") parser.add_argument("--affine_scale", type=float, default=0.2, help="Scale for affine augmentation.") parser.add_argument("--affine_shear", type=float, default=0.5, help="Shear for affine augmentation.") parser.add_argument("--shift_magnitude", type=int, default=10, help="Shift augmentation magnitude (voxels).") parser.add_argument("--mul_int", 
type=float, default=0.1, help="Multiplicative augmentation intensity.") parser.add_argument("--add_int", type=float, default=0.1, help="Additive augmentation intensity.") - parser.add_argument("--devices", type=int, default=-1, help="Number GPU devices to use (-1: all).") - parser.add_argument("--n_debug_steps", type=int, default=0, help="Number of debug steps.") - parser.add_argument("--log_every_n_steps", type=int, default=100, help="Log every n steps.") - parser.add_argument("--val_check_interval", type=int, default=5000, help="Validation check interval.") + + # Model arguments + parser.add_argument("--long_range", type=int, default=10, help="Long range affinities (voxels).") + parser.add_argument("--model_id", type=str, default="S", help="Identifier for the mednext model architecture.") + parser.add_argument("--kernel_size", type=int, default=3, help="Kernel size for the convolutional layers.") + parser.add_argument("--compile", action=argparse.BooleanOptionalAction, default=True, help="Whether to compile the model using torch.compile.") + parser.add_argument("--eval_ranges", type=float, nargs="+", default=torch.sigmoid(torch.tensor(list(range(-1, 12))).double() * 0.2).numpy().round(4).tolist(), help="List of evaluation thresholds.") parser.add_argument("--small_size", type=int, default=128, help="Size of the patches.") - parser.add_argument("--resume_from_last_checkpoint", action=argparse.BooleanOptionalAction, default=False, help="Resume training from the last checkpoint.") - parser.add_argument("--exp_name", type=str, default="", help="Experiment name (if empty, will be filled automatically).") - parser.add_argument("--long_training", action=argparse.BooleanOptionalAction, default=False, help="Long training with a separate validation process.") return parser.parse_args() diff --git a/inference.py b/inference.py index be2231b..8a939bd 100644 --- a/inference.py +++ b/inference.py @@ -1,7 +1,9 @@ +import gc from typing import Union, List, Tuple import numba 
import numpy as np +import psutil import torch import torch.utils import zarr @@ -28,7 +30,7 @@ def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray Returns: The segmentation. Shape: (x, y, z). """ - visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) + visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint8) seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) cur_id = 1 for i in range(visited.shape[0]): diff --git a/slurm_long_job.py b/slurm_long_job.py index 6765162..35dd047 100644 --- a/slurm_long_job.py +++ b/slurm_long_job.py @@ -18,10 +18,10 @@ print(f"Error: Experiment already exists: {save_path}/{exp_name}") exit(1) - command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}" + command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --validate_extern --save_path {save_path} --exp_name {exp_name}" else: - command = f"sbatch --export=ALL,RESUME=TRUE,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,LONG_JOB_ARGS='{' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}' --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --long_training --save_path {save_path} --exp_name {exp_name}" + command = f"sbatch --export=ALL,RESUME=TRUE,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,LONG_JOB_ARGS='{' '.join(unknown_args)} --validate_extern --save_path {save_path} --exp_name {exp_name}' --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --validate_extern --save_path 
{save_path} --exp_name {exp_name}" # Execute the command print(command) From 6bf836d0c5a6b46a826ea227b69ffaf82bf34857 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Fri, 23 May 2025 17:05:37 +0200 Subject: [PATCH 10/33] ignore slurm logs --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 7b6caf3..93d2b8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +slurm-*.out + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From 7d78676cf874a83e8e14100ea9603d05359aa58f Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Fri, 23 May 2025 18:28:11 +0200 Subject: [PATCH 11/33] one script to rule them all (long training also possible from slurm_job_scheduler) --- README.md | 6 +----- aff_train.sh | 3 +-- slurm_job_scheduler.py | 25 +++++++++++++++++++++---- slurm_long_job.py | 28 ---------------------------- 4 files changed, 23 insertions(+), 39 deletions(-) delete mode 100644 slurm_long_job.py diff --git a/README.md b/README.md index e598122..81d39f0 100644 --- a/README.md +++ b/README.md @@ -36,11 +36,7 @@ To run multiple jobs on Slurm, adjust `config.yaml` and `aff_train.sh`, then: python slurm_job_scheduler.py ``` -To run training that restarts from the last checkpoint once the Slurm limit is reached, adjust `aff_train.sh`, then: - -```bash -python slurm_long_job.py --save_path /local/logging/dir/ --exp_name experiment_name [--other_arguments] -``` +Adding an `auto_resubmit` argument to `config.yaml` allows Slurm to automatically resubmit jobs that reach the Slurm time limit (see `aff_train.sh`). 
## Evaluation diff --git a/aff_train.sh b/aff_train.sh index 6f67aa9..6e1999b 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -21,8 +21,7 @@ resubmit_job() { EXP_NAME=$(echo "${LONG_JOB_ARGS}" | grep -oP '(?<=--exp_name )\S+') sbatch --dependency=afterany:${SLURM_JOBID} \ --export=ALL,RESUME=TRUE,LONG_JOB=TRUE,SAVE_DIR=${SAVE_DIR},LONG_JOB_ARGS="${LONG_JOB_ARGS}" \ - --output=${SAVE_DIR}/slurm-log.txt \ - --job-name ${EXP_NAME} \ + --output=${SAVE_DIR}/slurm.out \ "$0" "${@}" exit 0 } diff --git a/slurm_job_scheduler.py b/slurm_job_scheduler.py index cfd2909..a439de0 100644 --- a/slurm_job_scheduler.py +++ b/slurm_job_scheduler.py @@ -10,23 +10,40 @@ def load_config(filename): return yaml.safe_load(file) -def construct_args(params, combination): +def construct_args(params, combination, variable_keys): args = [] + long = False + save_dir = "" + exp_name_parts = [params["exp_name"][0]] if "exp_name" in params else [] for key, value in zip(params.keys(), combination): - if key in ["scheduler", "intensity_aug"]: + if key in ["scheduler", "intensity_aug", "validate_extern"]: if not value: args.append(f"--no-{key}") + elif key == "auto_resubmit": + long = value else: args.append(f"--{key} {value}") - return " ".join(args) + if key == "save_path": + save_dir = value + if key in variable_keys: + exp_name_parts.append(f"{key}{value}") + exp_name = "-".join(exp_name_parts) or "experiment" + save_dir = os.path.join(save_dir, exp_name) + os.makedirs(save_dir, exist_ok=False) + return long, save_dir, " ".join(args), exp_name if __name__ == "__main__": config = load_config("config.yaml") params = config['params'] + variable_keys = [k for k, v in params.items() if len(v) > 1] for combination in product(*params.values()): - command = f"sbatch aff_train.sh {construct_args(params, combination)}" + long, save_dir, args, exp_name = construct_args(params, combination, variable_keys) + if not long: + command = f"sbatch --output {save_dir}/slurm.out aff_train.sh {args}" + else: + command 
= f"sbatch --output {save_dir}/slurm.out --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True aff_train.sh {args}" print(f"Executing command: {command}") os.system(command) time.sleep(1) diff --git a/slurm_long_job.py b/slurm_long_job.py deleted file mode 100644 index 35dd047..0000000 --- a/slurm_long_job.py +++ /dev/null @@ -1,28 +0,0 @@ -import argparse -import os - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Submit a job with custom save_dir and pass other arguments.") - parser.add_argument("--save_path", type=str, required=True, help="Path to save the model and logs") - parser.add_argument("--exp_name", type=str, required=True, help="Experiment name") - parser.add_argument("--resubmit", action=argparse.BooleanOptionalAction, default=False, help="Continue already existing training") - args, unknown_args = parser.parse_known_args() - - save_path = args.save_path - exp_name = args.exp_name - save_dir = os.path.join(save_path, exp_name) - if not args.resubmit: - try: - os.makedirs(f"{save_path}/{exp_name}", exist_ok=False) - except FileExistsError as error: - print(f"Error: Experiment already exists: {save_path}/{exp_name}") - exit(1) - - command = f"sbatch --export=ALL,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --validate_extern --save_path {save_path} --exp_name {exp_name}" - - else: - command = f"sbatch --export=ALL,RESUME=TRUE,SAVE_DIR={save_dir},LONG_JOB=TRUE,PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,LONG_JOB_ARGS='{' '.join(unknown_args)} --validate_extern --save_path {save_path} --exp_name {exp_name}' --job-name {exp_name} --output {save_dir}/slurm-log.txt aff_train.sh {' '.join(unknown_args)} --validate_extern --save_path {save_path} --exp_name {exp_name}" - - # Execute the command - print(command) - os.system(command) From 
45f2ecf3f74ec74f70b3c0551a40a2cfa8920b21 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Sat, 24 May 2025 09:57:40 +0200 Subject: [PATCH 12/33] fix exp_name --- slurm_job_scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/slurm_job_scheduler.py b/slurm_job_scheduler.py index a439de0..6480fa4 100644 --- a/slurm_job_scheduler.py +++ b/slurm_job_scheduler.py @@ -19,8 +19,12 @@ def construct_args(params, combination, variable_keys): if key in ["scheduler", "intensity_aug", "validate_extern"]: if not value: args.append(f"--no-{key}") + else: + args.append(f"--{key}") elif key == "auto_resubmit": long = value + elif key == "exp_name": + pass else: args.append(f"--{key} {value}") if key == "save_path": @@ -28,8 +32,8 @@ def construct_args(params, combination, variable_keys): if key in variable_keys: exp_name_parts.append(f"{key}{value}") exp_name = "-".join(exp_name_parts) or "experiment" + args.append(f"--exp_name {exp_name}") save_dir = os.path.join(save_dir, exp_name) - os.makedirs(save_dir, exist_ok=False) return long, save_dir, " ".join(args), exp_name @@ -40,6 +44,7 @@ def construct_args(params, combination, variable_keys): for combination in product(*params.values()): long, save_dir, args, exp_name = construct_args(params, combination, variable_keys) + os.makedirs(save_dir, exist_ok=False) if not long: command = f"sbatch --output {save_dir}/slurm.out aff_train.sh {args}" else: From 488ac20793b713a7fe86f1407e6f3ed9d431d175 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Sat, 24 May 2025 10:07:34 +0200 Subject: [PATCH 13/33] monai determinism --- data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/data.py b/data.py index 4af024d..a307921 100644 --- a/data.py +++ b/data.py @@ -9,6 +9,7 @@ import torch.utils import zarr from monai.transforms import RandAffined +from monai.utils import set_determinism from torch.utils.data import Dataset, ConcatDataset from tqdm import tqdm @@ -110,6 +111,7 @@ def 
__init__( ), divide: Union[int, float] = 1, ): + set_determinism(seed=np.random.randint(0, 2**32)) self.size_divisor = size_divisor self.img = img self.divide = divide @@ -121,7 +123,7 @@ def __init__( self.offset = tuple((img.shape[i] - seg.shape[i]) // 2 for i in range(3)) - print(f"seg shape {seg.shape}, img shape {img.shape}") + # print(f"seg shape {seg.shape}, img shape {img.shape}") if img.shape[:3] != seg.shape: # Shapes don't match, pad seg (and load it into memory) From 083e49e676e96fdc3ed564816a6f34cf9925ae84 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Sat, 24 May 2025 10:15:52 +0200 Subject: [PATCH 14/33] worker init fn --- BANIS.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/BANIS.py b/BANIS.py index c793576..9ee59e6 100644 --- a/BANIS.py +++ b/BANIS.py @@ -220,6 +220,16 @@ def safe_add_scalar(self, name: str, value: float, global_step=None) -> None: print(f"Error logging {name}: {e}") +def worker_init_fn(worker_id): + """ Ensures different seeds for each worker. 
""" + # torch.initial_seed() is derived from the initial seed state but advanced for each worker + seed = torch.initial_seed() % (2**32) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + print(f"[Worker {worker_id}] Seed: {seed}") + + def main(): args = parse_args() seed_everything(args.seed, workers=True) @@ -285,8 +295,8 @@ def main(): trainer.fit( model=model, train_dataloaders=DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, - drop_last=True), - val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers), + drop_last=True, worker_init_fn=worker_init_fn), + val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers, worker_init_fn=worker_init_fn), ckpt_path="last" if args.resume_from_last_checkpoint else None ) From 6954b515ccaf8f9b4c9d3d19b357a54dd45043c6 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Sat, 24 May 2025 10:27:58 +0200 Subject: [PATCH 15/33] xl long distributed training --- BANIS.py | 3 +++ aff_train.sh | 11 ++++++----- config.yaml | 45 +++++++++++++++++++++++++----------------- slurm_job_scheduler.py | 2 +- 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/BANIS.py b/BANIS.py index 9ee59e6..ac15dd3 100644 --- a/BANIS.py +++ b/BANIS.py @@ -275,6 +275,8 @@ def main(): max_steps=args.n_steps, accelerator="gpu", devices=args.devices, + strategy=DDPStrategy(find_unused_parameters=False) if args.distributed else "auto", + num_nodes=int(os.environ["SLURM_NNODES"]) if args.distributed else 1, log_every_n_steps=args.log_every_n_steps, limit_val_batches=100, precision="16-mixed", @@ -324,6 +326,7 @@ def parse_args(): parser.add_argument("--val_check_interval", type=int, default=5000, help="Validation check interval.") parser.add_argument("--resume_from_last_checkpoint", action=argparse.BooleanOptionalAction, default=False, help="Resume training from the last checkpoint.") parser.add_argument("--validate_extern", 
action=argparse.BooleanOptionalAction, default=False, help="Long training with a separate validation process.") + parser.add_argument("--distributed", action=argparse.BooleanOptionalAction, default=False, help="Use distributed training.") # Data arguments parser.add_argument("--base_data_path", type=str, default="/cajal/nvmescratch/projects/NISB/", help="Base path for the dataset.") diff --git a/aff_train.sh b/aff_train.sh index 6e1999b..567b9b7 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -1,13 +1,14 @@ #!/bin/bash -l -#SBATCH --nodes=1 -#SBATCH --gres=gpu:1 -#SBATCH --ntasks-per-node=1 +#SBATCH --nodes=2 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 #SBATCH --time=7-00 -#SBATCH --cpus-per-task=32 -#SBATCH --mem=500000 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=1000G #SBATCH --signal=B:USR1@300 #SBATCH --open-mode=append +#SBATCH --partition=p.large mamba activate nisb diff --git a/config.yaml b/config.yaml index 346ae76..032e4a3 100644 --- a/config.yaml +++ b/config.yaml @@ -5,20 +5,20 @@ params: - 1e-2 seed: - 0 - - 1 - - 2 - - 3 - - 4 + #- 1 + #- 2 + #- 3 + #- 4 long_range: - 10 batch_size: - - 8 + - 1 scheduler: - true model_id: - - "S" + - "L" kernel_size: - - 3 + - 5 synthetic: - 1.0 drop_slice_prob: @@ -32,22 +32,31 @@ params: affine: - 0.5 n_steps: - - 50000 + - 1_000_000 small_size: - - 128 + - 256 data_setting: - - "base" - - "liconn" - - "multichannel" - - "neg_guidance" - - "no_touch_thick" - - "pos_guidance" - - "slice_perturbed" - - "touching_thin" + #- "base" + #- "liconn" + #- "multichannel" + #- "neg_guidance" + #- "no_touch_thick" + #- "pos_guidance" + #- "slice_perturbed" + #- "touching_thin" - "train_100" base_data_path: - "/cajal/nvmescratch/projects/NISB/" save_path: - - "/cajal/scratch/projects/misc/riegerfr/aff_nis/" + #- "/cajal/scratch/projects/misc/riegerfr/aff_nis/" + - "/cajal/scratch/projects/misc/zuzur/xl_banis" + exp_name: + - "xl_test" real_data_path: 
#https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb - "/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/" + auto_resubmit: + - True + distributed: + - True + validate_extern: + - True \ No newline at end of file diff --git a/slurm_job_scheduler.py b/slurm_job_scheduler.py index 6480fa4..2ba4d6c 100644 --- a/slurm_job_scheduler.py +++ b/slurm_job_scheduler.py @@ -16,7 +16,7 @@ def construct_args(params, combination, variable_keys): save_dir = "" exp_name_parts = [params["exp_name"][0]] if "exp_name" in params else [] for key, value in zip(params.keys(), combination): - if key in ["scheduler", "intensity_aug", "validate_extern"]: + if key in ["scheduler", "intensity_aug", "validate_extern", "distributed"]: if not value: args.append(f"--no-{key}") else: From edf4f0bd464d74f50d17bb7a98c4ac6fba6adfaa Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Sat, 24 May 2025 11:03:10 +0200 Subject: [PATCH 16/33] fix bugs --- BANIS.py | 3 +++ aff_train.sh | 4 ++-- config.yaml | 2 ++ slurm_job_scheduler.py | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/BANIS.py b/BANIS.py index ac15dd3..69d63cf 100644 --- a/BANIS.py +++ b/BANIS.py @@ -1,9 +1,11 @@ import argparse import gc import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" from collections import defaultdict from datetime import datetime from typing import Any, Dict +import random import numpy as np import pytorch_lightning as pl @@ -14,6 +16,7 @@ from pytorch_lightning import LightningModule, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, DeviceStatsMonitor, LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.strategies import DDPStrategy from torch.nn.functional import binary_cross_entropy_with_logits from torch.optim import AdamW from torch.utils.data import DataLoader diff --git a/aff_train.sh b/aff_train.sh index 
567b9b7..074bcfb 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -38,11 +38,11 @@ if [ -n "$RESUME" ]; then echo "Resuming from the last checkpoint" echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" echo "ALL ARGS: ${@}" - srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u BANIS.py --resume_from_last_checkpoint ${LONG_JOB_ARGS} & + srun mamba run -n nisb --no-capture-output python3 -u BANIS.py --resume_from_last_checkpoint ${LONG_JOB_ARGS} & else echo "Starting long training from scratch." export LONG_JOB_ARGS="${@}" echo "LONG_JOB_ARGS: ${LONG_JOB_ARGS}" - srun --gres=gpu:1 mamba run -n nisb --no-capture-output python3 -u BANIS.py ${LONG_JOB_ARGS} & + srun mamba run -n nisb --no-capture-output python3 -u BANIS.py ${LONG_JOB_ARGS} & fi wait diff --git a/config.yaml b/config.yaml index 032e4a3..a0677bc 100644 --- a/config.yaml +++ b/config.yaml @@ -58,5 +58,7 @@ params: - True distributed: - True + compile: + - False validate_extern: - True \ No newline at end of file diff --git a/slurm_job_scheduler.py b/slurm_job_scheduler.py index 2ba4d6c..268d8ec 100644 --- a/slurm_job_scheduler.py +++ b/slurm_job_scheduler.py @@ -16,7 +16,7 @@ def construct_args(params, combination, variable_keys): save_dir = "" exp_name_parts = [params["exp_name"][0]] if "exp_name" in params else [] for key, value in zip(params.keys(), combination): - if key in ["scheduler", "intensity_aug", "validate_extern", "distributed"]: + if key in ["scheduler", "intensity_aug", "validate_extern", "distributed", "compile", "resume_from_last_checkpoint"]: if not value: args.append(f"--no-{key}") else: From c34142fe6de5effb7ff25989525f784f6eae79c8 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Fri, 13 Jun 2025 14:40:55 +0200 Subject: [PATCH 17/33] gradient clippinng, logging activations etc, deactivate augmentations --- BANIS.py | 74 +++++++++++++++++++++++++++++++++++++++++- data.py | 6 ++-- inference.py | 1 - slurm_job_scheduler.py | 2 +- 4 files changed, 78 insertions(+), 5 deletions(-) diff 
--git a/BANIS.py b/BANIS.py index 69d63cf..2b942fe 100644 --- a/BANIS.py +++ b/BANIS.py @@ -13,6 +13,7 @@ import torchvision import zarr from nnunet_mednext import create_mednext_v1 +from nnunet_mednext import MedNeXtBlock, MedNeXtUpBlock, MedNeXtDownBlock from pytorch_lightning import LightningModule, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint, DeviceStatsMonitor, LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger @@ -85,6 +86,8 @@ def validation_step(self, data: Dict[str, torch.Tensor], batch_idx: int) -> torc return self._step(data, "val") def _step(self, data: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: + self.log_input(data["img"]) + self.log_weight_stats() pred = self(data["img"]) target = data["aff"].half() loss_mask = data["aff"] >= 0 @@ -222,6 +225,65 @@ def safe_add_scalar(self, name: str, value: float, global_step=None) -> None: except Exception as e: print(f"Error logging {name}: {e}") + def log_input(self, input): + self.log_dict({ + f"input/min": input.min(), + f"input/max": input.max(), + f"input/mean": input.mean(), + f"input/std": input.std(), + }) + + def register_activation_hooks(self): + for name, module in self.named_modules(): + if isinstance(module, (MedNeXtUpBlock, MedNeXtDownBlock)): + def hook_fn(module, input, output, block_name=name): # capture name in default arg + if not self.training: # don't log during validation + return + self.log_dict({ + f"activations/{block_name}_min": output.min(), + f"activations/{block_name}_max": output.max(), + f"activations/{block_name}_mean": output.mean(), + f"activations/{block_name}_std": output.std(), + }) + if torch.isnan(output).any(): + print(f"NaN in output of {block_name}") + module.register_forward_hook(hook_fn) + + def setup(self, stage: str): + if stage == 'fit': + self.register_activation_hooks() + + def log_weight_stats(self): + for name, param in self.named_parameters(): + self.log_dict({ + f"weights/{name}_min": 
param.data.min(), + f"weights/{name}_max": param.data.max(), + f"weights/{name}_mean": param.data.mean(), + f"weights/{name}_std": param.data.std(), + }) + + def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): + total_norm_before = torch.norm(torch.stack([p.grad.norm(2) for p in self.parameters() if p.grad is not None])) + self.log("gradients/total_norm", total_norm_before.item()) + max_grad_before = max([p.grad.abs().max().item() for p in self.parameters() if p.grad is not None]) + self.log("gradients/max_grad", max_grad_before) + + for p in self.parameters(): + if p.grad is not None: + p.grad.data = torch.nan_to_num(p.grad.data, nan=0.0, posinf=1e4, neginf=-1e4) + total_norm2 = torch.norm(torch.stack([p.grad.norm(2) for p in self.parameters() if p.grad is not None])) + self.log("gradients/clamped_total_norm", total_norm2.item()) + max_grad2 = max([p.grad.abs().max().item() for p in self.parameters() if p.grad is not None]) + self.log("gradients/clamped_max_grad", max_grad2) + + self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm) + + total_norm_after = torch.norm(torch.stack([p.grad.norm(2) for p in self.parameters() if p.grad is not None])) + self.log("clipped_gradients/total_norm", total_norm_after.item(), on_step=True) + max_grad_after = max([p.grad.abs().max().item() for p in self.parameters() if p.grad is not None]) + self.log("clipped_gradients/max_grad", max_grad_after) + + def worker_init_fn(worker_id): """ Ensures different seeds for each worker. 
""" @@ -288,6 +350,7 @@ def main(): val_check_interval=args.val_check_interval, # validation full cube inference expensive so less frequent check_val_every_n_epoch=None, num_sanity_val_steps=args.n_debug_steps, + gradient_clip_val=1.0, ) print(f"Checkpoints will be saved in: {trainer.default_root_dir}/checkpoints") @@ -295,7 +358,14 @@ def main(): args.save_dir = save_dir args.num_input_channels = n_channels - model = BANIS(**vars(args)) + if os.path.exists(args.model_from_checkpoint): + print(f"Loading model from checkpoint: {args.model_from_checkpoint}") + model = BANIS(**vars(args)) + checkpoint = torch.load(args.model_from_checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["state_dict"]) + model.hparams.update(vars(args)) + else: + model = BANIS(**vars(args)) trainer.fit( model=model, @@ -328,6 +398,7 @@ def parse_args(): parser.add_argument("--log_every_n_steps", type=int, default=100, help="Log every n steps.") parser.add_argument("--val_check_interval", type=int, default=5000, help="Validation check interval.") parser.add_argument("--resume_from_last_checkpoint", action=argparse.BooleanOptionalAction, default=False, help="Resume training from the last checkpoint.") + parser.add_argument("--model_from_checkpoint", type=str, default="", help="Load model from defined checkpoint.") parser.add_argument("--validate_extern", action=argparse.BooleanOptionalAction, default=False, help="Long training with a separate validation process.") parser.add_argument("--distributed", action=argparse.BooleanOptionalAction, default=False, help="Use distributed training.") @@ -336,6 +407,7 @@ def parse_args(): parser.add_argument("--data_setting", type=str, default="base", help="Data setting identifier.") parser.add_argument("--real_data_path", type=str, default="/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/", help="Path to the real dataset. 
See https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb ") parser.add_argument("--synthetic", type=float, default=1.0, help="Ratio of synthetic data to real data.") + parser.add_argument("--augment", action=argparse.BooleanOptionalAction, default=True, help="Use augmentations") parser.add_argument("--drop_slice_prob", type=float, default=0.05, help="Probability of dropping a slice during augmentation.") parser.add_argument("--shift_slice_prob", type=float, default=0.05, help="Probability of shifting a slice during augmentation.") parser.add_argument("--intensity_aug", action=argparse.BooleanOptionalAction, default=True, help="Whether to apply intensity augmentation.") diff --git a/data.py b/data.py index a307921..e0b7dc9 100644 --- a/data.py +++ b/data.py @@ -309,6 +309,7 @@ def get_seg_dataset( data_path: str, len_multiplier: int = 10, small_size: int = 128, + augment = False, augment_args: Namespace = Namespace( drop_slice_prob=0, shift_slice_prob=0, @@ -343,7 +344,7 @@ def get_seg_dataset( seg=seg.astype(np.int64), img=(img / 255).astype(np.float16), # divide = 255 - augment=True, + augment=augment, len_multiplier=len_multiplier, augment_args=augment_args, long_range=augment_args.long_range, @@ -371,6 +372,7 @@ def get_train_data(args: argparse.Namespace): args.real_data_path, small_size=args.small_size, len_multiplier=100, + augment=args.augment, augment_args=args, ) if args.synthetic > 0: @@ -412,7 +414,7 @@ def get_syn_train_data(args: argparse.Namespace): seg=img_seg["seg"], img=img_seg["img"], long_range=args.long_range, - augment=True, + augment=args.augment, augment_args=args, divide=255.0, small_size=args.small_size, diff --git a/inference.py b/inference.py index 8a939bd..1106f1b 100644 --- a/inference.py +++ b/inference.py @@ -3,7 +3,6 @@ import numba import numpy as np -import psutil import torch import torch.utils import zarr diff --git a/slurm_job_scheduler.py b/slurm_job_scheduler.py index 
268d8ec..99482ab 100644 --- a/slurm_job_scheduler.py +++ b/slurm_job_scheduler.py @@ -16,7 +16,7 @@ def construct_args(params, combination, variable_keys): save_dir = "" exp_name_parts = [params["exp_name"][0]] if "exp_name" in params else [] for key, value in zip(params.keys(), combination): - if key in ["scheduler", "intensity_aug", "validate_extern", "distributed", "compile", "resume_from_last_checkpoint"]: + if key in ["scheduler", "intensity_aug", "validate_extern", "distributed", "compile", "resume_from_last_checkpoint", "augment"]: if not value: args.append(f"--no-{key}") else: From 7cc87fd6ca607303072b78bf07b842d9aa3d1174 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Mon, 30 Jun 2025 17:25:05 +0200 Subject: [PATCH 18/33] correct worker init function (default) --- BANIS.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/BANIS.py b/BANIS.py index 2b942fe..be09526 100644 --- a/BANIS.py +++ b/BANIS.py @@ -268,14 +268,6 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli max_grad_before = max([p.grad.abs().max().item() for p in self.parameters() if p.grad is not None]) self.log("gradients/max_grad", max_grad_before) - for p in self.parameters(): - if p.grad is not None: - p.grad.data = torch.nan_to_num(p.grad.data, nan=0.0, posinf=1e4, neginf=-1e4) - total_norm2 = torch.norm(torch.stack([p.grad.norm(2) for p in self.parameters() if p.grad is not None])) - self.log("gradients/clamped_total_norm", total_norm2.item()) - max_grad2 = max([p.grad.abs().max().item() for p in self.parameters() if p.grad is not None]) - self.log("gradients/clamped_max_grad", max_grad2) - self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm) total_norm_after = torch.norm(torch.stack([p.grad.norm(2) for p in self.parameters() if p.grad is not None])) @@ -284,17 +276,6 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, 
gradient_cli self.log("clipped_gradients/max_grad", max_grad_after) - -def worker_init_fn(worker_id): - """ Ensures different seeds for each worker. """ - # torch.initial_seed() is derived from the initial seed state but advanced for each worker - seed = torch.initial_seed() % (2**32) - np.random.seed(seed) - random.seed(seed) - torch.manual_seed(seed) - print(f"[Worker {worker_id}] Seed: {seed}") - - def main(): args = parse_args() seed_everything(args.seed, workers=True) @@ -369,9 +350,8 @@ def main(): trainer.fit( model=model, - train_dataloaders=DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, - drop_last=True, worker_init_fn=worker_init_fn), - val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers, worker_init_fn=worker_init_fn), + train_dataloaders=DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, drop_last=True), + val_dataloaders=DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers), ckpt_path="last" if args.resume_from_last_checkpoint else None ) From 5abdc759424f64d31c9419add287f53c9d8a6475 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 2 Jul 2025 14:09:54 +0200 Subject: [PATCH 19/33] config --- config.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index a0677bc..33cd0d2 100644 --- a/config.yaml +++ b/config.yaml @@ -34,7 +34,7 @@ params: n_steps: - 1_000_000 small_size: - - 256 + - 128 data_setting: #- "base" #- "liconn" @@ -57,8 +57,10 @@ params: auto_resubmit: - True distributed: - - True + - False compile: - False validate_extern: - - True \ No newline at end of file + - True + augment: + - False \ No newline at end of file From 585b42b6bf2af1fa5c46955fa08319f96eec1597 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Mon, 14 Jul 2025 17:59:50 +0200 Subject: [PATCH 20/33] wip: LocalCluster --- debug_visualilze.py | 132 
+++++++++++++++++++++++++++++++ inference.py | 186 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 316 insertions(+), 2 deletions(-) create mode 100644 debug_visualilze.py diff --git a/debug_visualilze.py b/debug_visualilze.py new file mode 100644 index 0000000..4b8e79a --- /dev/null +++ b/debug_visualilze.py @@ -0,0 +1,132 @@ +import argparse +import os +import pickle +from typing import Tuple +from collections import defaultdict, Counter, deque + +import dask.array as da +import neuroglancer +import numpy as np +from dask.array import clip +from dask_image.ndfilters import gaussian +from neuroglancer import CoordinateSpace, LocalVolume, Viewer, SegmentationLayer +import zarr +from tqdm import tqdm +import networkx as nx +from networkx import connected_components, subgraph, convert_node_labels_to_integers + +from data import comp_affinities + +""" +Visualizes where the errors in the prediction are. +""" + +class SkeletonSource(neuroglancer.skeleton.SkeletonSource): + def __init__(self, dimensions, skel): + super().__init__(dimensions) + self.skel = skel + + def get_skeleton(self, i): + print(f"Getting skeleton for {i}") + cv_s = self.skel[i] + try: + s = neuroglancer.skeleton.Skeleton(vertex_positions=(cv_s.vertices / [9,9,20]), edges=cv_s.edges) + except Exception as e: + print(e) + return s + + + +# Coordinate spaces +COORDS = { + "standard": CoordinateSpace(names=['x', 'y', 'z'], units=['nm', 'nm', 'nm'], scales=[9, 9, 20]), + "standard_c": CoordinateSpace(names=["x", "y", "z", "c^"], units=["nm", "nm", "nm", ""], scales=[9, 9, 20, 1]), + "liconn": CoordinateSpace(names=['x', 'y', 'z'], units=['nm', 'nm', 'nm'], scales=[9, 9, 12]), + "liconn_c": CoordinateSpace(names=["x", "y", "z", "c^"], units=["nm", "nm", "nm", ""], scales=[9, 9, 12, 1]), + "aff": CoordinateSpace(names=[ "c^", "x", "y", "z"], units=["", "nm", "nm", "nm"], scales=[1, 9, 9, 20]), +} + + +def load_data(data_path: str): + """Load image, segmentation, and skeleton data.""" + seg = 
da.from_zarr(os.path.join(data_path, "data.zarr", "seg")).astype(np.uint32)[500:1000, 500:1000, 500:1000] + img = da.from_zarr(os.path.join(data_path, "data.zarr", "img"))[500:1000, 500:1000, 500:1000] + skel = da.from_zarr(os.path.join(data_path, "data.zarr", "skel")).astype(np.uint32) + with open(os.path.join(data_path, "skeleton_dense.pkl"), 'rb') as f: + skel_pkl = pickle.load(f) + return img, seg, skel, skel_pkl + + +def add_image_layer(s, name: str, img: da.Array, c_res: CoordinateSpace): + """Add an image layer to the viewer.""" + layer = LocalVolume(img, dimensions=c_res) + s.layers.append(name=f'img_{name}', layer=layer) + + +def add_segmentation_layer(s, name: str, seg: da.Array, skel: dict, res: CoordinateSpace): + """Add a segmentation layer to the viewer.""" + layer = SegmentationLayer( + source=[LocalVolume(seg, dimensions=res, volume_type="segmentation"), SkeletonSource(res, skel)], + skeleton_shader='void main() { emitRGB(vec3(.3, .8, .76)); }', + mesh_silhouette_rendering=2.0 + ) + layer.skeleton_rendering.mode3d = "lines" #"lines_and_points" + s.layers.append(name=f'seg_{name}', layer=layer) + + +def create_viewer(args) -> Viewer: + """Create and configure the Neuroglancer viewer.""" + neuroglancer.set_server_bind_address('localhost', args.port) + viewer = Viewer() + + with viewer.txn() as s: + img, seg, skel, skel_pkl = load_data(args.data_path) + + coord_space = COORDS["standard_c"] + add_image_layer(s, "gt", img, coord_space) + + seg_space = COORDS["standard"] + add_segmentation_layer(s, "gt", seg, skel_pkl, seg_space) + + if True: + aff, _ = comp_affinities(seg) + aff = da.from_array(aff).astype(np.float32) + s.layers["gt_aff"] = neuroglancer.ImageLayer( + source=neuroglancer.LocalVolume( + aff[:3], dimensions=COORDS["aff"], voxel_offset=[0, 0, 0, 0] + ), + shader="""void main() { + emitRGB(vec3(toNormalized(getDataValue(0)), + toNormalized(getDataValue(1)), + toNormalized(getDataValue(2)))); + }""", + ) + + pred_aff = 
da.from_zarr(args.pred_path).astype(np.float32) + s.layers["pred_aff"] = neuroglancer.ImageLayer( + source=neuroglancer.LocalVolume( + pred_aff[:3], dimensions=COORDS["aff"], voxel_offset=[0, 0, 0, 0] + ), + shader="""void main() { + emitRGB(vec3(toNormalized(getDataValue(0)), + toNormalized(getDataValue(1)), + toNormalized(getDataValue(2)))); + }""", + ) + + + print("If on a remote server, remember port forwarding. Meshes may take time to load.") + + return viewer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Neuroglancer Viewer for NISB project") + parser.add_argument("--data_path", type=str, default="/cajal/scratch/users/zuzur/NISB_corrected/base/val/seed100", help="Directory which contains data.zarr with segmentation + EM image + skeleton, and skeleton.pkl") + parser.add_argument("--pred_path", type=str, default="/cajal/scratch/projects/misc/zuzur/test.zarr/aff") + parser.add_argument("--port", type=int, default=8589, help="Port to run the viewer") + args = parser.parse_args() + + viewer = create_viewer(args) + print(viewer.get_viewer_url()) + input("Press Enter to quit") diff --git a/inference.py b/inference.py index 1106f1b..2218c96 100644 --- a/inference.py +++ b/inference.py @@ -1,4 +1,9 @@ import gc +import os +import shutil +import time +from collections import defaultdict +from datetime import timedelta from typing import Union, List, Tuple import numba @@ -6,17 +11,33 @@ import torch import torch.utils import zarr +import dask +from dask.distributed import (Client, LocalCluster) +import dask.array as da +from filelock import FileLock from numba import jit from scipy.ndimage import distance_transform_cdt from torch import autocast from torch.nn.functional import sigmoid from tqdm import tqdm +from tqdm.dask import TqdmCallback def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: """Scale sigmoid to avoid numerical issues in high confidence fp16.""" return sigmoid(0.2 * x) +def timing(func): + def wrapper(*args, **kwargs): 
+ print(f"Starting '{func.__name__}'...") + start = time.time() + result = func(*args, **kwargs) + end = time.time() + elapsed = timedelta(seconds=end - start) + print(f"Finished '{func.__name__}' in {elapsed}") + return result + return wrapper + @jit(nopython=True) def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: @@ -67,6 +88,7 @@ def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray @torch.no_grad() @autocast(device_type="cuda") +@timing def patched_inference( img: Union[np.ndarray, zarr.Array], model: torch.nn.Module, @@ -91,8 +113,7 @@ def patched_inference( Returns: The full prediction. Shape: (channel, x, y, z). """ - print( - f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + print(f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") img = img[:] # load into memory (expensive!) patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) @@ -197,3 +218,164 @@ def get_single_pred_weight(do_overlap: bool, small_size: int) -> Union[np.ndarra return distance_transform_cdt(pred_weight_helper).astype(np.float32)[1:-1, 1:-1, 1:-1] else: return None + + +@torch.no_grad() +@autocast(device_type="cuda") +@timing +def predict_aff( + img: Union[np.ndarray, zarr.Array], + model: torch.nn.Module, + zarr_path: str = "data.zarr", + small_size: int = 128, + do_overlap: bool = True, + prediction_channels: int = 6, + divide: int = 1, + chunk_cube_size: int = 1024 +): + """ + Perform patched affinity prediction with a model on an image. + + Args: + img: The input image. Shape: (x, y, z, channel). + model: The model to use for predictions. + small_size: The size of the patches. Defaults to 128. + do_overlap: Whether to perform overlapping predictions. Defaults to True: + half of patch size for all 3 axes. 
+ prediction_channels: The number of channels in the output (additional model output + dimensions are discarded). Defaults to 6 (3 short + 3 long range affinities). + divide: The divisor for the image. Typically, 1 or 255 if img in [0, 255] + chunk_cube_size: The maximal side length of a cube held in memory. + + Returns: + The full prediction. Shape: (channel, x, y, z). + """ + print(f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + + all_patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) + chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) + + z = zarr.open_group(zarr_path, mode='w') + z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), chunks=(1, chunk_cube_size), dtype='f4') + z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), chunks=(1, chunk_cube_size), dtype='f4') + + # TODO: parallelize this!!!!!!!!!!!!! + cluster = LocalCluster(n_workers=4, threads_per_worker=1) + client = Client(cluster) + print("Dask Client Dashboard:", client.dashboard_link) + + tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide) + for chunk in chunked_patch_coordinates + ] + with TqdmCallback(desc="Overall Dask Progress (chunks)", unit="chunk", total=len(tasks)) as pbar: + dask.compute(*tasks) + #for chunk in tqdm(chunked_patch_coordinates): + # predict_aff_patches_chunked(chunk, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide) + + tmp_sum_pred = da.from_zarr(f"{zarr_path}/sum_pred") + tmp_sum_weight = da.from_zarr(f"{zarr_path}/sum_weight") + aff = tmp_sum_pred / tmp_sum_weight + aff.to_zarr(f"{zarr_path}/aff", overwrite=True) + + for key in ['sum_pred', 'sum_weight']: + path = os.path.join(zarr_path, key) + if os.path.exists(path): + shutil.rmtree(path) + + return zarr.open(f"{zarr_path}/aff", mode="r") + + +def chunk_xyzs(xyzs, 
chunk_cube_size=1024): + """ + Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. + Args: + xyzs: list of all coordinates + chunk_cube_size: side length of each chunk + Returns: + chunked coordinates + """ + chunks = defaultdict(list) + for x, y, z in xyzs: + chunks[(x // chunk_cube_size, y // chunk_cube_size, z // chunk_cube_size)].append((x, y, z)) + return list(chunks.values()) + + +@torch.no_grad() +@autocast(device_type="cuda") +def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide): + """ + Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. + Args: + patch_coordinates: List of patch coordinates. The extension of the coordinates must fit in memory (use adequate chunk size). + Returns: + Affinity prediction of the input chunk. + """ + max_x = max(x for x, y, z in patch_coordinates) + max_y = max(y for x, y, z in patch_coordinates) + max_z = max(z for x, y, z in patch_coordinates) + min_x = min(x for x, y, z in patch_coordinates) + min_y = min(y for x, y, z in patch_coordinates) + min_z = min(z for x, y, z in patch_coordinates) + + img_tmp = img[ + min_x: max_x + small_size, + min_y: max_y + small_size, + min_z: max_z + small_size, + ] + pred_tmp = np.zeros((prediction_channels, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) + weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) + single_pred_weight = get_single_pred_weight(do_overlap, small_size) + + for x_global, y_global, z_global in patch_coordinates: + x = x_global - min_x + y = y_global - min_y + z = z_global - min_z + img_patch = torch.tensor(np.moveaxis( + img_tmp[x: x + small_size, y: y + small_size, z: z + small_size], + -1, 0)[None]).to(model.device) / divide + pred = scale_sigmoid(model(img_patch))[0, :prediction_channels] + + weight_tmp[:, x: x + small_size, y: y + 
small_size, z: z + small_size] += single_pred_weight if do_overlap else 1 + pred_tmp[:, x: x + small_size, y: y + small_size, z: z + small_size] += pred.detach().cpu().numpy() * (single_pred_weight[None] if do_overlap else 1) + + + z = zarr.open_group(zarr_path, mode='a') + weight_mask = z['sum_weight'] + full_pred = z['sum_pred'] + + with FileLock(f"{zarr_path}/sum_weight.lock"): + weight_mask[ + :, + min_x: max_x + small_size, + min_y: max_y + small_size, + min_z: max_z + small_size, + ] += weight_tmp + + with FileLock(f"{zarr_path}/sum_pred.lock"): + full_pred[ + :, + min_x: max_x + small_size, + min_y: max_y + small_size, + min_z: max_z + small_size, + ] += pred_tmp + + + +if __name__ == "__main__": + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + + model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + from BANIS import BANIS + model = BANIS.load_from_checkpoint(model_path) + + #aff_pred = patched_inference(img_data, model=model, do_overlap=True, prediction_channels=3, divide=255,small_size=model.hparams.small_size) + #store = zarr.DirectoryStore('/cajal/scratch/projects/misc/zuzur/test0.zarr') + #store.rmdir('') + #z = zarr.array(aff_pred, store=store, override=True) + + #aff_pred2 = predict_aff(img_data, model, zarr_path="/cajal/scratch/projects/misc/zuzur/test.zarr", do_overlap=True, prediction_channels=3, divide=255,small_size=model.hparams.small_size) + # ValueError('Codec does not support buffers of > 2147483647 bytes') + + aff_pred2 = predict_aff(img_data, model, chunk_cube_size=250, zarr_path="/cajal/scratch/projects/misc/zuzur/test2.zarr", do_overlap=True, prediction_channels=3, divide=255,small_size=model.hparams.small_size) + From 82a832fce2d559f0aacad00e70c5f4a606b61f1e Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Fri, 1 Aug 2025 11:32:10 +0200 Subject: [PATCH 21/33] add 
adapted nerl --- metrics.py | 207 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 206 insertions(+), 1 deletion(-) diff --git a/metrics.py b/metrics.py index 5f1c955..d166376 100644 --- a/metrics.py +++ b/metrics.py @@ -4,7 +4,8 @@ import numpy as np import zarr -from funlib.evaluate import rand_voi, expected_run_length +from funlib.evaluate import rand_voi, expected_run_length, get_skeleton_lengths +from funlib.evaluate.run_length import SkeletonScores from networkx import get_node_attributes from tqdm import tqdm @@ -78,10 +79,214 @@ def compute_metrics(pred_seg: Union[np.ndarray, zarr.Array], skel_path: str) -> "voi_report": voi_report, "n_non0_mergers": n_non0_mergers, } + + adapted_nerls = {} + for ignored_merger_size in [5, 20, 100, np.inf]: + adapted_erl_report = adapted_erl( + skel, + "id", + "edge_length", + get_node_attributes(skel, "pred_id"), + skeleton_position_attributes=["nm_position"], + return_merge_split_stats=True, + ignored_merger_size=ignored_merger_size + ) + adapted_nerls[ignored_merger_size] = adapted_erl_report[0] / max_erl + for ignored_mergers, adapted_nerl in adapted_nerls.items(): + metrics[f"nerl_{ignored_mergers}"] = adapted_nerl + print(f"metrics: {metrics}") return metrics +def adapted_erl( + skeletons, + skeleton_id_attribute, + edge_length_attribute, + node_segment_lut, + skeleton_lengths=None, + skeleton_position_attributes=None, + return_merge_split_stats=False, + ignored_merger_size=0 + ): + """ + Adapted from `funlib.evaluate.expected_run_length`. + Args: + ignored_merger_size: + Maximum number of nodes in a segment with the same predicted and ground truth id, + so that the segment is never considered to be part of a merger. + ignored_merger_size=0 is the same as the original ERL. 
+ """ + + if skeleton_position_attributes is not None: + + if skeleton_lengths is not None: + raise ValueError( + "If skeleton_position_attributes is given, skeleton_lengths" + "should not be given") + + skeleton_lengths = get_skeleton_lengths( + skeletons, + skeleton_position_attributes, + skeleton_id_attribute, + store_edge_length=edge_length_attribute) + + total_skeletons_length = np.sum([l for _, l in skeleton_lengths.items()]) + + res = evaluate_skeletons( + skeletons, + skeleton_id_attribute, + node_segment_lut, + return_merge_split_stats=return_merge_split_stats, + ignored_merger_size=ignored_merger_size + ) + + if return_merge_split_stats: + skeleton_scores, merge_split_stats = res + else: + skeleton_scores = res + + skeletons_erl = 0 + + for skeleton_id, scores in skeleton_scores.items(): + + skeleton_length = skeleton_lengths[skeleton_id] + skeleton_erl = 0 + + for segment_id, correct_edges in scores.correct_edges.items(): + correct_edges_length = np.sum([ + skeletons.edges[e][edge_length_attribute] + for e in correct_edges]) + + skeleton_erl += ( + correct_edges_length * + (correct_edges_length / skeleton_length) + ) + + skeletons_erl += ( + (skeleton_length / total_skeletons_length) * + skeleton_erl + ) + + if return_merge_split_stats: + return skeletons_erl, merge_split_stats + else: + return skeletons_erl + + +def evaluate_skeletons( + skeletons, + skeleton_id_attribute, + node_segment_lut, + return_merge_split_stats=False, + ignored_merger_size=1 +): + """ + Adapted from `funlib.evaluate.expected_run_length`. + Args: + ignored_merger_size: + Maximum number of nodes in a segment with the same predicted and ground truth id, + so that the segment is never considered to be part of a merger. + ignored_merger_size=0 is the same as the original ERL. 
+ """ + + # find all merging segments (skeleton edges on merging segments will be + # counted as wrong) + + # pairs of (skeleton, segment), one for each node + skeleton_segment = np.array([ + [data[skeleton_id_attribute], node_segment_lut[n]] + for n, data in skeletons.nodes(data=True) + ]) + + # unique pairs of (skeleton, segment) + # THIS IS THE CHANGE - only consider segments of bigger size than ignored_merger_size + skeleton_segment, counts = np.unique(skeleton_segment, axis=0, return_counts=True) + skeleton_segment = skeleton_segment[counts > ignored_merger_size] + + # number of times that a segment was mapped to a skeleton + segments, num_segment_skeletons = np.unique( + skeleton_segment[:, 1], + return_counts=True) + + # all segments that merge at least two skeletons + merging_segments = segments[num_segment_skeletons > 1] + + merging_segments_mask = np.isin(skeleton_segment[:, 1], merging_segments) + merged_skeletons = skeleton_segment[:, 0][merging_segments_mask] + merging_segments = set(merging_segments) + + merges = {} + splits = {} + + if return_merge_split_stats: + + merged_segments = skeleton_segment[:, 1][merging_segments_mask] + + for segment, skeleton in zip(merged_segments, merged_skeletons): + if segment not in merges: + merges[segment] = [] + merges[segment].append(skeleton) + + merged_skeletons = set(np.unique(merged_skeletons)) + + skeleton_scores = {} + + for u, v in skeletons.edges(): + + skeleton_id = skeletons.nodes[u][skeleton_id_attribute] + segment_u = node_segment_lut[u] + segment_v = node_segment_lut[v] + + if skeleton_id not in skeleton_scores: + scores = SkeletonScores() + skeleton_scores[skeleton_id] = scores + else: + scores = skeleton_scores[skeleton_id] + + if segment_u == 0 or segment_v == 0: + scores.ommitted += 1 + continue + + if segment_u != segment_v: + scores.split += 1 + + if return_merge_split_stats: + if skeleton_id not in splits: + splits[skeleton_id] = [] + splits[skeleton_id].append((segment_u, segment_v)) + 
continue + + # segment_u == segment_v != 0 + segment = segment_u + + # potentially merged edge? + if skeleton_id in merged_skeletons: + if segment in merging_segments: + scores.merged += 1 + continue + + scores.correct += 1 + + if segment not in scores.correct_edges: + scores.correct_edges[segment] = [] + scores.correct_edges[segment].append((u, v)) + + if return_merge_split_stats: + + merge_split_stats = { + 'merge_stats': merges, + 'split_stats': splits + } + + return skeleton_scores, merge_split_stats + + else: + + return skeleton_scores + + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Compute segmentation quality metrics.") parser.add_argument("--pred_seg", type=str, required=True, help="Path to predicted segmentation (Zarr format)") From 27324018faeafd12d73f4eda5947bf141d45b594 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Sat, 9 Aug 2025 16:02:03 +0200 Subject: [PATCH 22/33] inference --- inference.py | 199 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 148 insertions(+), 51 deletions(-) diff --git a/inference.py b/inference.py index 2218c96..e3b6c64 100644 --- a/inference.py +++ b/inference.py @@ -8,34 +8,70 @@ import numba import numpy as np +import psutil import torch import torch.utils import zarr import dask -from dask.distributed import (Client, LocalCluster) +from dask import compute, persist, delayed +from dask.distributed import Client, LocalCluster +from dask_cuda import LocalCUDACluster +from dask.diagnostics import ProgressBar import dask.array as da +from dask_jobqueue import SLURMCluster +from distributed import progress from filelock import FileLock from numba import jit from scipy.ndimage import distance_transform_cdt from torch import autocast from torch.nn.functional import sigmoid from tqdm import tqdm -from tqdm.dask import TqdmCallback +import tracemalloc +import threading def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: """Scale sigmoid to avoid numerical issues in high 
confidence fp16.""" return sigmoid(0.2 * x) -def timing(func): + +def measure_stats(func): + def monitor_memory(interval=0.1, result=None): + proc = psutil.Process(os.getpid()) + peak = 0 + while not getattr(monitor_memory, "stop", False): + rss = proc.memory_info().rss + peak = max(peak, rss) + time.sleep(interval) + if result is not None: + result["peak"] = peak # Save peak memory to shared dict + def wrapper(*args, **kwargs): - print(f"Starting '{func.__name__}'...") + memory_stats = {} + thread = threading.Thread(target=monitor_memory, kwargs={"interval": 0.1, "result": memory_stats}) + thread.start() + torch.cuda.reset_peak_memory_stats() + tracemalloc.start() start = time.time() + result = func(*args, **kwargs) + end = time.time() elapsed = timedelta(seconds=end - start) - print(f"Finished '{func.__name__}' in {elapsed}") - return result + current, peak = tracemalloc.get_traced_memory() + max_mem = torch.cuda.max_memory_reserved() + monitor_memory.stop = True + thread.join() + + stats = { + "time": f"{elapsed}", + "peak_python_mem": f"{peak / 1024**2:.2f} MB", + "max_cuda_mem": f"{max_mem / 1024 ** 2:.2f} MB", + "rss_mem": f"{memory_stats['peak'] / 1024 ** 2:.2f} MB" + } + + return result, stats + return wrapper @@ -88,7 +124,6 @@ def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray @torch.no_grad() @autocast(device_type="cuda") -@timing def patched_inference( img: Union[np.ndarray, zarr.Array], model: torch.nn.Module, @@ -113,7 +148,8 @@ def patched_inference( Returns: The full prediction. Shape: (channel, x, y, z). """ - print(f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + print( + f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") img = img[:] # load into memory (expensive!) 
patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) @@ -222,7 +258,6 @@ def get_single_pred_weight(do_overlap: bool, small_size: int) -> Union[np.ndarra @torch.no_grad() @autocast(device_type="cuda") -@timing def predict_aff( img: Union[np.ndarray, zarr.Array], model: torch.nn.Module, @@ -231,7 +266,8 @@ def predict_aff( do_overlap: bool = True, prediction_channels: int = 6, divide: int = 1, - chunk_cube_size: int = 1024 + chunk_cube_size: int = 1024, + compute_backend: str = "local" ): """ Perform patched affinity prediction with a model on an image. @@ -239,6 +275,7 @@ def predict_aff( Args: img: The input image. Shape: (x, y, z, channel). model: The model to use for predictions. + zarr_path: Output path to save the prediction in zarr format. small_size: The size of the patches. Defaults to 128. do_overlap: Whether to perform overlapping predictions. Defaults to True: half of patch size for all 3 axes. @@ -246,31 +283,61 @@ def predict_aff( dimensions are discarded). Defaults to 6 (3 short + 3 long range affinities). divide: The divisor for the image. Typically, 1 or 255 if img in [0, 255] chunk_cube_size: The maximal side length of a cube held in memory. + compute_backend: Type of computation / dask backend. One of: + + - "local": uses a cycle on the local machine (default) + - "local_cluster": uses a localGPUcluster to utilize all local GPUs without SLURM + - "slurm": uses a slurm cluster with all available nodes Returns: The full prediction. Shape: (channel, x, y, z). 
""" - print(f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + print( + f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + print(f"Parameters: cube size {chunk_cube_size}, compute backend {compute_backend}.") all_patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) z = zarr.open_group(zarr_path, mode='w') - z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), chunks=(1, chunk_cube_size), dtype='f4') - z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), chunks=(1, chunk_cube_size), dtype='f4') - - # TODO: parallelize this!!!!!!!!!!!!! - cluster = LocalCluster(n_workers=4, threads_per_worker=1) - client = Client(cluster) - print("Dask Client Dashboard:", client.dashboard_link) - - tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide) - for chunk in chunked_patch_coordinates - ] - with TqdmCallback(desc="Overall Dask Progress (chunks)", unit="chunk", total=len(tasks)) as pbar: - dask.compute(*tasks) - #for chunk in tqdm(chunked_patch_coordinates): - # predict_aff_patches_chunked(chunk, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide) + zarr_chunk_size = min(chunk_cube_size, 512) + z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + + if compute_backend == "local": + for chunk in tqdm(chunked_patch_coordinates): + predict_aff_patches_chunked(chunk, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide) + torch.cuda.empty_cache() # TODO: does this help? 
+ else: + if compute_backend == "local_cluster": + cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU + elif compute_backend == "slurm": + cluster = SLURMCluster( + cores=8, + memory="400GB", + processes=1, + worker_extra_args=["--resources processes=1", "--nthreads=1"], + job_extra_directives=["--gres=gpu:1"], + walltime="1-00:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) + + else: + raise NotImplementedError(f"Compute backend {compute_backend} not available.") + + client = Client(cluster) + print(f"Waiting for workers...") + client.wait_for_workers(n_workers=1) + print("Dask Client Dashboard:", client.dashboard_link) + tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model, zarr_path, small_size, do_overlap, + prediction_channels, divide) + for chunk in chunked_patch_coordinates + ] + futures = persist(tasks) + progress(futures) # progress bar + compute(futures) tmp_sum_pred = da.from_zarr(f"{zarr_path}/sum_pred") tmp_sum_weight = da.from_zarr(f"{zarr_path}/sum_weight") @@ -302,7 +369,8 @@ def chunk_xyzs(xyzs, chunk_cube_size=1024): @torch.no_grad() @autocast(device_type="cuda") -def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide): +def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_size, do_overlap, prediction_channels, + divide): """ Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. 
Args: @@ -318,14 +386,16 @@ def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_ min_z = min(z for x, y, z in patch_coordinates) img_tmp = img[ - min_x: max_x + small_size, - min_y: max_y + small_size, - min_z: max_z + small_size, - ] + min_x: max_x + small_size, + min_y: max_y + small_size, + min_z: max_z + small_size, + ] pred_tmp = np.zeros((prediction_channels, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) single_pred_weight = get_single_pred_weight(do_overlap, small_size) + # TODO: load model here from path to save time / properly parallelize? (problematic to use during training where model is in memory only) + for x_global, y_global, z_global in patch_coordinates: x = x_global - min_x y = y_global - min_y @@ -335,9 +405,10 @@ def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_ -1, 0)[None]).to(model.device) / divide pred = scale_sigmoid(model(img_patch))[0, :prediction_channels] - weight_tmp[:, x: x + small_size, y: y + small_size, z: z + small_size] += single_pred_weight if do_overlap else 1 - pred_tmp[:, x: x + small_size, y: y + small_size, z: z + small_size] += pred.detach().cpu().numpy() * (single_pred_weight[None] if do_overlap else 1) - + weight_tmp[:, x: x + small_size, y: y + small_size, + z: z + small_size] += single_pred_weight if do_overlap else 1 + pred_tmp[:, x: x + small_size, y: y + small_size, z: z + small_size] += pred.detach().cpu().numpy() * ( + single_pred_weight[None] if do_overlap else 1) z = zarr.open_group(zarr_path, mode='a') weight_mask = z['sum_weight'] @@ -345,37 +416,63 @@ def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_ with FileLock(f"{zarr_path}/sum_weight.lock"): weight_mask[ - :, - min_x: max_x + small_size, - min_y: max_y + small_size, - min_z: max_z + small_size, + :, + min_x: max_x + small_size, + min_y: max_y 
+ small_size, + min_z: max_z + small_size, ] += weight_tmp with FileLock(f"{zarr_path}/sum_pred.lock"): full_pred[ - :, - min_x: max_x + small_size, - min_y: max_y + small_size, - min_z: max_z + small_size, + :, + min_x: max_x + small_size, + min_y: max_y + small_size, + min_z: max_z + small_size, ] += pred_tmp - -if __name__ == "__main__": +def test_local_prediction(): input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" img_data = zarr.open(input_path, mode="r")["img"] model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" from BANIS import BANIS + model = BANIS.load_from_checkpoint(model_path) - #aff_pred = patched_inference(img_data, model=model, do_overlap=True, prediction_channels=3, divide=255,small_size=model.hparams.small_size) - #store = zarr.DirectoryStore('/cajal/scratch/projects/misc/zuzur/test0.zarr') - #store.rmdir('') - #z = zarr.array(aff_pred, store=store, override=True) + all_stats = {} + + for chunk_cube_size in [200, 400, 512, 750, 1024, 1500, 3000]: + measured_predict_aff = measure_stats(predict_aff) + + result, stats = measured_predict_aff(img_data, model, chunk_cube_size=chunk_cube_size, compute_backend="local", + zarr_path=f"/cajal/scratch/projects/misc/zuzur/test{chunk_cube_size}.zarr", do_overlap=True, + prediction_channels=3, divide=255, small_size=model.hparams.small_size) + + all_stats[chunk_cube_size] = stats + print(f"chunk size {chunk_cube_size}: {stats}") - #aff_pred2 = predict_aff(img_data, model, zarr_path="/cajal/scratch/projects/misc/zuzur/test.zarr", do_overlap=True, prediction_channels=3, divide=255,small_size=model.hparams.small_size) - # ValueError('Codec does not support buffers of > 2147483647 bytes') + print(all_stats) + for (value, stat) in all_stats.items(): + print(f"{value}: {stat}") - aff_pred2 = predict_aff(img_data, model, chunk_cube_size=250, zarr_path="/cajal/scratch/projects/misc/zuzur/test2.zarr", 
do_overlap=True, prediction_channels=3, divide=255,small_size=model.hparams.small_size) +def test_slurm_prediction(): + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + + model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + from BANIS import BANIS + + model = BANIS.load_from_checkpoint(model_path) + + measured_predict_aff = measure_stats(predict_aff) + # only one run - runtime dependent on number of available slurm nodes + result, stats = measured_predict_aff(img_data, model, chunk_cube_size=512, compute_backend="slurm", + zarr_path=f"/cajal/scratch/projects/misc/zuzur/test_slurm.zarr", do_overlap=True, + prediction_channels=3, divide=255, small_size=model.hparams.small_size) + + print(stats) + +if __name__ == "__main__": + test_slurm_prediction() \ No newline at end of file From bebf90dbb4fe401d71dc35df7846f2a354e06d15 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Thu, 14 Aug 2025 22:48:19 +0200 Subject: [PATCH 23/33] inference: local or slurm computation, memory management via cube size --- BANIS.py | 9 +- inference.py | 293 ++++++++++++++++++++------------------------------- 2 files changed, 118 insertions(+), 184 deletions(-) diff --git a/BANIS.py b/BANIS.py index be09526..fc9f15e 100644 --- a/BANIS.py +++ b/BANIS.py @@ -24,7 +24,7 @@ from tqdm import tqdm from data import load_data -from inference import scale_sigmoid, patched_inference, compute_connected_component_segmentation +from inference import scale_sigmoid, compute_connected_component_segmentation, predict_aff from metrics import compute_metrics @@ -164,11 +164,8 @@ def full_cube_inference(self, mode: str, global_step=None): img_data = zarr.open(os.path.join(seed_path, "data.zarr"), mode="r")["img"] - aff_pred = patched_inference(img_data, model=self, do_overlap=True, prediction_channels=3, divide=255, - 
small_size=self.hparams.small_size) - - aff_pred = zarr.array(aff_pred, dtype=np.float16, store=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", - chunks=(3, 512, 512, 512), overwrite=True) + aff_pred = predict_aff(img_data, model=self, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", do_overlap=True, prediction_channels=3, divide=255, + small_size=self.hparams.small_size, compute_backend="local") self._evaluate_thresholds(aff_pred, os.path.join(seed_path, "skeleton.pkl"), mode, global_step) diff --git a/inference.py b/inference.py index e3b6c64..4ffecbf 100644 --- a/inference.py +++ b/inference.py @@ -1,24 +1,17 @@ -import gc -import os import shutil -import time from collections import defaultdict -from datetime import timedelta from typing import Union, List, Tuple import numba import numpy as np -import psutil import torch import torch.utils import zarr import dask from dask import compute, persist, delayed from dask.distributed import Client, LocalCluster -from dask_cuda import LocalCUDACluster from dask.diagnostics import ProgressBar import dask.array as da -from dask_jobqueue import SLURMCluster from distributed import progress from filelock import FileLock from numba import jit @@ -26,8 +19,6 @@ from torch import autocast from torch.nn.functional import sigmoid from tqdm import tqdm -import tracemalloc -import threading def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: @@ -36,6 +27,13 @@ def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: def measure_stats(func): + import os + import time + from datetime import timedelta + import tracemalloc + import threading + import psutil + def monitor_memory(interval=0.1, result=None): proc = psutil.Process(os.getpid()) peak = 0 @@ -86,7 +84,7 @@ def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray Returns: The segmentation. Shape: (x, y, z). 
""" - visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint8) + visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) cur_id = 1 for i in range(visited.shape[0]): @@ -124,59 +122,101 @@ def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray @torch.no_grad() @autocast(device_type="cuda") -def patched_inference( +def predict_aff( img: Union[np.ndarray, zarr.Array], - model: torch.nn.Module, + model: torch.nn.Module = None, + model_path: str = None, + zarr_path: str = "aff_prediction.zarr", small_size: int = 128, do_overlap: bool = True, prediction_channels: int = 6, divide: int = 1, -) -> np.ndarray: + chunk_cube_size: int = 1024, + compute_backend: str = "local" +): """ - Perform patched inference with a model on an image. + Perform patched affinity prediction with a model on an image. Args: img: The input image. Shape: (x, y, z, channel). - model: The model to use for predictions. + model: The model to use for predictions (only for local prediction). + model_path: Path to the model checkpoint to use for predictions (if model not specified). + zarr_path: Output path to save the prediction in zarr format. small_size: The size of the patches. Defaults to 128. do_overlap: Whether to perform overlapping predictions. Defaults to True: half of patch size for all 3 axes. prediction_channels: The number of channels in the output (additional model output dimensions are discarded). Defaults to 6 (3 short + 3 long range affinities). divide: The divisor for the image. Typically, 1 or 255 if img in [0, 255] + chunk_cube_size: The maximal side length of a cube held in memory. + compute_backend: Type of computation / dask backend. One of: + + - "local": uses a cycle on the local machine (default) + - "local_cluster": uses a localGPUcluster to utilize all local GPUs without SLURM + - "slurm": uses a slurm cluster with all available nodes Returns: The full prediction. 
Shape: (channel, x, y, z). """ print( f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") - img = img[:] # load into memory (expensive!) + print(f"Parameters: cube size {chunk_cube_size}, compute backend {compute_backend}.") - patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) - single_pred_weight = get_single_pred_weight(do_overlap, small_size) - # to weight overlapping predictions lower close to the boundaries + all_patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) + chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) - weight_sum = np.zeros((1, *img.shape[:3]), dtype=np.float32) - weighted_pred = np.zeros((prediction_channels, *img.shape[:3]), dtype=np.float32) + z = zarr.open_group(zarr_path + "_tmp", mode='w') + zarr_chunk_size = min(chunk_cube_size, 512) + z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - device = next(model.parameters()).device - assert device.type != 'cpu' + if compute_backend == "local": + if not model: + from BANIS import BANIS + model = BANIS.load_from_checkpoint(model_path) + for chunk in tqdm(chunked_patch_coordinates): + predict_aff_patches_chunked(chunk, img, model, zarr_path + "_tmp", small_size, do_overlap, prediction_channels, divide) + torch.cuda.empty_cache() # TODO: does this help? 
+ else: + if compute_backend == "local_cluster": + from dask_cuda import LocalCUDACluster + cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU + elif compute_backend == "slurm": + from dask_jobqueue import SLURMCluster + cluster = SLURMCluster( + cores=8, + memory="400GB", + processes=1, + worker_extra_args=["--resources processes=1", "--nthreads=1"], + job_extra_directives=["--gres=gpu:1"], + walltime="1-00:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) - for x, y, z in tqdm(patch_coordinates): - img_patch = torch.tensor( - np.moveaxis(img[x: x + small_size, y: y + small_size, z: z + small_size], -1, 0)[None]).half().to( - device) / divide - pred = scale_sigmoid(model(img_patch))[0, :prediction_channels] + else: + raise NotImplementedError(f"Compute backend {compute_backend} not available.") - weight_sum[:, x: x + small_size, y: y + small_size, - z: z + small_size] += single_pred_weight if do_overlap else 1 - weighted_pred[:, x: x + small_size, y: y + small_size, z: z + small_size] += pred.cpu().numpy() * ( - single_pred_weight[None] if do_overlap else 1) - del img # to save memory before division - # assert np.all(weight_sum > 0) - np.divide(weighted_pred, weight_sum, out=weighted_pred) + client = Client(cluster) + print(f"Waiting for workers...") + client.wait_for_workers(n_workers=1) + print("Dask Client Dashboard:", client.dashboard_link) + tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model_path, zarr_path + "_tmp", small_size, do_overlap, prediction_channels, divide) + for chunk in chunked_patch_coordinates + ] + futures = persist(tasks) + progress(futures) # progress bar + compute(futures) + + tmp_sum_pred = da.from_zarr(f"{zarr_path}_tmp/sum_pred") + tmp_sum_weight = da.from_zarr(f"{zarr_path}_tmp/sum_weight") + aff = tmp_sum_pred / tmp_sum_weight + aff.to_zarr(zarr_path, overwrite=True) - return weighted_pred + shutil.rmtree(zarr_path + "_tmp") + + return zarr.open(zarr_path, mode="r") def get_coordinates( @@ 
-256,102 +296,6 @@ def get_single_pred_weight(do_overlap: bool, small_size: int) -> Union[np.ndarra return None -@torch.no_grad() -@autocast(device_type="cuda") -def predict_aff( - img: Union[np.ndarray, zarr.Array], - model: torch.nn.Module, - zarr_path: str = "data.zarr", - small_size: int = 128, - do_overlap: bool = True, - prediction_channels: int = 6, - divide: int = 1, - chunk_cube_size: int = 1024, - compute_backend: str = "local" -): - """ - Perform patched affinity prediction with a model on an image. - - Args: - img: The input image. Shape: (x, y, z, channel). - model: The model to use for predictions. - zarr_path: Output path to save the prediction in zarr format. - small_size: The size of the patches. Defaults to 128. - do_overlap: Whether to perform overlapping predictions. Defaults to True: - half of patch size for all 3 axes. - prediction_channels: The number of channels in the output (additional model output - dimensions are discarded). Defaults to 6 (3 short + 3 long range affinities). - divide: The divisor for the image. Typically, 1 or 255 if img in [0, 255] - chunk_cube_size: The maximal side length of a cube held in memory. - compute_backend: Type of computation / dask backend. One of: - - - "local": uses a cycle on the local machine (default) - - "local_cluster": uses a localGPUcluster to utilize all local GPUs without SLURM - - "slurm": uses a slurm cluster with all available nodes - - Returns: - The full prediction. Shape: (channel, x, y, z). 
- """ - print( - f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") - print(f"Parameters: cube size {chunk_cube_size}, compute backend {compute_backend}.") - - all_patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) - chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) - - z = zarr.open_group(zarr_path, mode='w') - zarr_chunk_size = min(chunk_cube_size, 512) - z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - - if compute_backend == "local": - for chunk in tqdm(chunked_patch_coordinates): - predict_aff_patches_chunked(chunk, img, model, zarr_path, small_size, do_overlap, prediction_channels, divide) - torch.cuda.empty_cache() # TODO: does this help? 
- else: - if compute_backend == "local_cluster": - cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU - elif compute_backend == "slurm": - cluster = SLURMCluster( - cores=8, - memory="400GB", - processes=1, - worker_extra_args=["--resources processes=1", "--nthreads=1"], - job_extra_directives=["--gres=gpu:1"], - walltime="1-00:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - - else: - raise NotImplementedError(f"Compute backend {compute_backend} not available.") - - client = Client(cluster) - print(f"Waiting for workers...") - client.wait_for_workers(n_workers=1) - print("Dask Client Dashboard:", client.dashboard_link) - tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model, zarr_path, small_size, do_overlap, - prediction_channels, divide) - for chunk in chunked_patch_coordinates - ] - futures = persist(tasks) - progress(futures) # progress bar - compute(futures) - - tmp_sum_pred = da.from_zarr(f"{zarr_path}/sum_pred") - tmp_sum_weight = da.from_zarr(f"{zarr_path}/sum_weight") - aff = tmp_sum_pred / tmp_sum_weight - aff.to_zarr(f"{zarr_path}/aff", overwrite=True) - - for key in ['sum_pred', 'sum_weight']: - path = os.path.join(zarr_path, key) - if os.path.exists(path): - shutil.rmtree(path) - - return zarr.open(f"{zarr_path}/aff", mode="r") - - def chunk_xyzs(xyzs, chunk_cube_size=1024): """ Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. @@ -369,8 +313,7 @@ def chunk_xyzs(xyzs, chunk_cube_size=1024): @torch.no_grad() @autocast(device_type="cuda") -def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_size, do_overlap, prediction_channels, - divide): +def predict_aff_patches_chunked(patch_coordinates, img, model_path, zarr_path, small_size, do_overlap, prediction_channels, divide): """ Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. 
Args: @@ -394,7 +337,9 @@ def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_ weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) single_pred_weight = get_single_pred_weight(do_overlap, small_size) - # TODO: load model here from path to save time / properly parallelize? (problematic to use during training where model is in memory only) + from BANIS import BANIS + print(model_path, flush=True) + model = BANIS.load_from_checkpoint(model_path) for x_global, y_global, z_global in patch_coordinates: x = x_global - min_x @@ -431,48 +376,40 @@ def predict_aff_patches_chunked(patch_coordinates, img, model, zarr_path, small_ ] += pred_tmp -def test_local_prediction(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - - model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - from BANIS import BANIS - - model = BANIS.load_from_checkpoint(model_path) - - all_stats = {} - - for chunk_cube_size in [200, 400, 512, 750, 1024, 1500, 3000]: - measured_predict_aff = measure_stats(predict_aff) - - result, stats = measured_predict_aff(img_data, model, chunk_cube_size=chunk_cube_size, compute_backend="local", - zarr_path=f"/cajal/scratch/projects/misc/zuzur/test{chunk_cube_size}.zarr", do_overlap=True, - prediction_channels=3, divide=255, small_size=model.hparams.small_size) - - all_stats[chunk_cube_size] = stats - print(f"chunk size {chunk_cube_size}: {stats}") - - print(all_stats) - for (value, stat) in all_stats.items(): - print(f"{value}: {stat}") - - -def test_slurm_prediction(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - - model_path = 
"/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - from BANIS import BANIS - - model = BANIS.load_from_checkpoint(model_path) - - measured_predict_aff = measure_stats(predict_aff) - # only one run - runtime dependent on number of available slurm nodes - result, stats = measured_predict_aff(img_data, model, chunk_cube_size=512, compute_backend="slurm", - zarr_path=f"/cajal/scratch/projects/misc/zuzur/test_slurm.zarr", do_overlap=True, - prediction_channels=3, divide=255, small_size=model.hparams.small_size) +def full_inference( + # AFFINITY PREDICTION ARGUMENTS: + img: Union[np.ndarray, zarr.Array], + model_path: str, + aff_zarr_path: str = "aff_prediction.zarr", + small_size: int = 128, + do_overlap: bool = True, + prediction_channels: int = 6, + divide: int = 1, + chunk_cube_size: int = 1024, + compute_backend: str = "local", + # POSTPROCESSING ARGUMENTS: + postprocessing_type: str = "thresholding", + thr: float = 0.5, + seg_zarr_path: str = "seg_prediction.zarr" +): - print(stats) + aff = predict_aff( + img, + model_path=model_path, + zarr_path=aff_zarr_path, + small_size=small_size, + do_overlap=do_overlap, + prediction_channels=prediction_channels, + divide=divide, + chunk_cube_size=chunk_cube_size, + compute_backend=compute_backend + ) + + if postprocessing_type == "thresholding": + seg = compute_connected_component_segmentation(aff[:3] > thr) + zarr.array(seg, store=seg_zarr_path) + elif postprocessing_type == "mws": + raise NotImplementedError(f"Mutex Watershed is not implemented") + else: + raise NotImplementedError(f"Postprocessing type {postprocessing_type} is not implemented") -if __name__ == "__main__": - test_slurm_prediction() \ No newline at end of file From d9f26706c5cb63df262060df8252ad0a87625eae Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Fri, 29 Aug 2025 14:39:33 +0200 Subject: [PATCH 24/33] testing memory-efficient segmentation --- .gitignore | 2 + 
def chunk_list(list_to_chunk, chunk_size):
    """Split *list_to_chunk* into consecutive sublists of at most *chunk_size* items."""
    return [
        list_to_chunk[i: i + chunk_size]
        for i in range(0, len(list_to_chunk), chunk_size)
    ]


def patched_thresholding(aff, conf):
    """
    Create a segmentation from an affinity map.

    The segmentation is computed patch-wise (thresholding or mutex watershed),
    fragments at the borders of adjacent patches are agglomerated, all patches
    are relabeled with globally unique ids, and segments that are too small are
    filtered out.

    Args:
        aff: affinity array of shape (channels, x, y, z) — assumed zarr/ndarray-like; TODO confirm.
        conf: configuration object with patching, debug and backend options.

    Returns:
        The final relabeled segmentation (zarr array).
    """
    ijk_to_idx, patch_to_coords = get_mappings(aff, conf.patch_size)

    # Predict segments for all patches (or reuse a precomputed debug artifact)
    if conf.debug_patched_seg_path:
        print(f"Using patched segmentation from {conf.debug_patched_seg_path}")
        patched_seg = zarr.open(conf.debug_patched_seg_path, mode="r")
    else:
        patched_seg = segment_patches(
            aff,
            conf,
            patch_to_coords,
            ijk_to_idx
        )

    # Agglomerate segments at the edges of neighboring patches
    if conf.debug_fragment_agglomeration_path:
        print(f"Using fragment agglomeration from {conf.debug_fragment_agglomeration_path}")
        if not os.path.normpath(conf.debug_fragment_agglomeration_path) == os.path.normpath(f"{conf.path_root}/agglo_pkl_chunks"):
            print(f"CONF PATH {conf.debug_fragment_agglomeration_path} NOT EQUAL TO {conf.path_root}/agglo_pkl_chunks - MIGHT CAUSE PROBLEMS")
        else:
            print(f"NOT LOADING PATH {conf.debug_fragment_agglomeration_path} - WILL DO IT IN THE WORKERS")
        # fix: always bind the name so relabel_globally never sees it unbound;
        # the pickled chunks are loaded lazily inside the relabeling workers
        fragment_agglomeration_flattened = None
    else:
        fragment_agglomeration_flattened = compute_fragment_agglomeration(
            patched_seg,
            aff,
            conf,
            ijk_to_idx,
            patch_to_coords,
        )

    # Unify indexing of all patches, including the merged agglomerations at the border
    if conf.debug_relabeled_seg_path:
        print(f"Using relabeled agglomerated segmentation from {conf.debug_relabeled_seg_path}")
        agglomerated_seg = zarr.open(conf.debug_relabeled_seg_path, mode="r")
    else:
        agglomerated_seg = relabel_globally(
            fragment_agglomeration_flattened,
            patched_seg,
            aff,
            conf,
            patch_to_coords,
            ijk_to_idx,
        )

    # Filter out segments that are too small
    filtered_seg = size_filter_relabel(agglomerated_seg, conf)

    # Delete intermediary files that are not needed anymore
    if conf.delete_files:
        try:
            zarr.DirectoryStore(f"{conf.path_root}/patched_seg.zarr").rmdir()
            os.rmdir(f"{conf.path_root}/agglo_pkl_chunks")
            zarr.DirectoryStore(f"{conf.path_root}/agglomerated_seg.zarr").rmdir()
            os.remove(f"{conf.path_root}/id_mapping.csv")
            shutil.rmtree(conf.path_root + "/dask-worker-space/", ignore_errors=True)
        except Exception as e:  # fix: was a bare except; report what actually failed
            print(f"Exception while deleting files: {e}")

    return filtered_seg


def get_mappings(aff, patch_size):
    """
    Return the patch index maps for an affinity volume.

    Coordinate conventions:
        x, y, z: voxel coordinates
        i, j, k: patch indices (i.e. i * patch_size <= x < (i + 1) * patch_size)
        idx: flat patch index (i * len(ys) * len(zs) + j * len(zs) + k)

    Returns:
        ijk_to_idx: dict (i, j, k) -> flat patch index
        patch_to_coords: dict (i, j, k) -> ((x_start, x_end), (y_start, y_end), (z_start, z_end)),
            where an end of None means "until the volume border".
    """
    xs = list(range(0, aff.shape[1], patch_size))
    ys = list(range(0, aff.shape[2], patch_size))
    zs = list(range(0, aff.shape[3], patch_size))

    ijk_to_idx = {
        (i, j, k): i * len(ys) * len(zs) + j * len(zs) + k
        for i in range(len(xs))
        for j in range(len(ys))
        for k in range(len(zs))
    }

    patch_to_coords = {
        (i, j, k): (
            (xs[i], xs[i + 1] if i + 1 < len(xs) else None),
            (ys[j], ys[j + 1] if j + 1 < len(ys) else None),
            (zs[k], zs[k + 1] if k + 1 < len(zs) else None),
        )
        for i in range(len(xs))
        for j in range(len(ys))
        for k in range(len(zs))
    }

    return ijk_to_idx, patch_to_coords


def segment_patches(aff, conf, patch_to_coords, ijk_to_idx):
    """
    Predict the segmentation for each patch independently.

    Returns:
        zarr array of shape (n_patches, patch_size + overlap, patch_size + overlap, patch_size + overlap)
    """
    print("Computing patch segmentation...")
    print(f"Store: {conf.path_root}/patched_seg.zarr")
    patched_seg = zarr.zeros(
        (
            len(patch_to_coords),
            conf.patch_size + conf.overlap,
            conf.patch_size + conf.overlap,
            conf.patch_size + conf.overlap,
        ),
        chunks=(1, conf.patch_size + conf.overlap, conf.patch_size + conf.overlap, conf.patch_size + conf.overlap),
        dtype=np.uint32,
        store=f"{conf.path_root}/patched_seg.zarr",
    )
    print(patched_seg.shape)

    ijks_chunked = chunk_list(
        list(patch_to_coords.keys()),
        chunk_size=1  # faster than bigger chunks
    )

    def chunked_thresholding(ijks):
        # writes results directly into the patched_seg zarr array
        for ijk in ijks:
            threshold_ijk(ijk, aff, patched_seg, patch_to_coords, ijk_to_idx, conf)

    if not conf.use_parallelization:
        result = list(map(chunked_thresholding, tqdm(ijks_chunked)))
        # we don't need the result, the iteration writes the segmentation directly into the zarr array
    else:
        if conf.use_slurm:
            cluster = SLURMCluster(
                cores=8,
                memory="500GB",
                processes=1,
                worker_extra_args=["--resources processes=1"],
                log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/segment/",
                walltime="3:00:00"  # default is 30mins and then worker gets killed, chunked ijks can take more time
            )
            cluster.adapt(minimum_jobs=1, maximum_jobs=32)
        else:
            cluster = LocalCluster(
                n_workers=min(os.cpu_count(), 16),
                threads_per_worker=1,
                local_directory=conf.path_root + "/dask-worker-space/"
                # to avoid independent runs deleting each other's directories
            )

        with Client(cluster) as client:
            print("Dask threshold Client Dashboard:", client.dashboard_link)

            start_time = time.time()
            futures = client.map(
                chunked_thresholding,
                ijks_chunked,
                batch_size=1,
                resources={"processes": 1},
            )
            for future in tqdm(as_completed(futures), total=len(ijks_chunked), smoothing=0):
                future.release()
                pass  # tqdm progress bar is nicer and shows remaining time
            print(f"Computing patch segmentations took {timedelta(seconds=int(time.time() - start_time))}")

    return patched_seg


def threshold_ijk(ijk, aff, patched_seg, patch_to_coords, ijk_to_idx, conf):
    """
    Predict the segmentation from the affinities for one patch and write the result into patched_seg.
    """
    i, j, k = ijk
    ((x_start, x_end), (y_start, y_end), (z_start, z_end)) = patch_to_coords[(i, j, k)]
    # crop the patch plus overlap + surrounding context (clamped at the volume border)
    cur_aff = aff[
        :,
        max(0, x_start - conf.overlap - conf.surrounding): (
            x_end + conf.surrounding if x_end is not None else None),
        max(0, y_start - conf.overlap - conf.surrounding): (
            y_end + conf.surrounding if y_end is not None else None),
        max(0, z_start - conf.overlap - conf.surrounding): (
            z_end + conf.surrounding if z_end is not None else None),
    ]

    cur_aff[np.isnan(cur_aff)] = 0.0
    cur_aff = np.clip(cur_aff, 0.0, 1.0)  # todo: enforce clip + not nan + not inf in aff inference

    # zero-pad on all sides that were cut off at the volume border
    cur_aff_tmp = cur_aff
    cur_aff = np.zeros(
        (
            aff.shape[0],
            conf.patch_size + (conf.overlap + 2 * conf.surrounding),
            conf.patch_size + (conf.overlap + 2 * conf.surrounding),
            conf.patch_size + (conf.overlap + 2 * conf.surrounding),
        )
    )

    x_start_tmp = (conf.overlap + conf.surrounding) if x_start == 0 else 0
    y_start_tmp = (conf.overlap + conf.surrounding) if y_start == 0 else 0
    z_start_tmp = (conf.overlap + conf.surrounding) if z_start == 0 else 0
    cur_aff[
        :,
        x_start_tmp: x_start_tmp + cur_aff_tmp.shape[1],
        y_start_tmp: y_start_tmp + cur_aff_tmp.shape[2],
        z_start_tmp: z_start_tmp + cur_aff_tmp.shape[3],
    ] = cur_aff_tmp

    if conf.mws:
        # fix: astype already copies, the previous deepcopy was redundant
        cur_aff = cur_aff.astype(np.float64)
        cur_aff[:3] += conf.mws_bias_short
        cur_aff[3:] += conf.mws_bias_long if conf.mws_bias_long is not None else 0.0

        cur_aff[:3] = np.clip(cur_aff[:3], 0, 1)
        cur_aff[3:] = np.clip(cur_aff[3:], -1, 0)

        mws_pred = mwatershed.agglom(
            affinities=cur_aff if conf.mws_bias_long is not None else cur_aff[:3],
            offsets=(
                [
                    [1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1],
                    [conf.long_range, 0, 0],
                    [0, conf.long_range, 0],
                    [0, 0, conf.long_range],
                ]
                if conf.mws_bias_long is not None
                else [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
            ),
        )

        # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again
        # size filter. single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube,
        # causing OOM when creating fragment_agglomeration
        dusted = cc3d.dust(  # does a cc first (reducing false mergers in add_to_agglomeration)
            mws_pred,
            threshold=2,
            connectivity=6,
            in_place=False,
        )
        # relabeling to save IDs
        pred_relabeled, N = cc3d.connected_components(
            dusted, return_N=True, connectivity=6
        )

        assert (pred_relabeled[mws_pred == 0] == 0).all()  # 0 stays 0
        assert N <= np.iinfo(np.uint32).max

        pred = pred_relabeled.astype(np.uint32)
    else:
        pred = conn_comps(cur_aff >= conf.thr)

    # drop the surrounding context again before writing the patch result
    pred_no_surrounding = (
        pred[
            conf.surrounding:-conf.surrounding,
            conf.surrounding:-conf.surrounding,
            conf.surrounding:-conf.surrounding,
        ]
        if conf.surrounding > 0
        else pred
    )
    patched_seg[ijk_to_idx[i, j, k]] = pred_no_surrounding
    print(f"processed: {ijk_to_idx[i, j, k]}")
    return
@jit(nopython=True)
def conn_comps(hard_aff):
    """
    Connected components over hard (boolean) short-range affinities.

    Args:
        hard_aff: boolean array of shape (3, x, y, z); channel c links voxel v
            to its positive neighbor along axis c.

    Returns:
        uint32 label volume; 0 is background, components are labeled from 1.
    """
    visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.bool_)
    seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32)
    cur_id = 1
    cur_id_used = False
    for i in range(visited.shape[0]):
        for j in range(visited.shape[1]):
            for k in range(visited.shape[2]):
                if hard_aff[
                    :, i, j, k
                ].any() and not visited[i, j, k]:  # if foreground
                    # iterative flood fill; voxels are marked visited when pushed
                    cur_to_visit = [(i, j, k)]  # todo: use 3 array.array instead? or np.array and append?
                    visited[i, j, k] = True
                    while len(cur_to_visit) > 0:
                        x, y, z = cur_to_visit.pop()
                        seg[x, y, z] = cur_id
                        cur_id_used = True
                        if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]:
                            cur_to_visit.append((x + 1, y, z))
                            visited[x + 1, y, z] = True
                        if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]:
                            cur_to_visit.append((x, y + 1, z))
                            visited[x, y + 1, z] = True
                        if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]:
                            cur_to_visit.append((x, y, z + 1))
                            visited[x, y, z + 1] = True
                        if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]:
                            cur_to_visit.append((x - 1, y, z))
                            visited[x - 1, y, z] = True
                        if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]:
                            cur_to_visit.append((x, y - 1, z))
                            visited[x, y - 1, z] = True
                        if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]:
                            cur_to_visit.append((x, y, z - 1))
                            visited[x, y, z - 1] = True
                    if cur_id_used:
                        cur_id += 1
                        cur_id_used = False
    return seg


def compute_fragment_agglomeration(
    patched_seg,
    aff,
    conf,
    ijk_to_idx,
    patch_to_coords,
):
    """
    From the patched segmentation, merge fragments at the borders of adjacent patches.

    Returns:
        Flattened agglomeration: dict keyed by (i, j, k, idx) — patch (i, j, k)
        and fragment id idx within that patch — mapping to a global fragment id.
    """
    print("Computing fragment agglomeration...")
    data_chunked = list(
        enumerate(chunk_list(list(patch_to_coords.items()), chunk_size=8))
    )

    if conf.use_slurm:
        cluster = SLURMCluster(
            cores=16,
            memory="500GB",
            processes=1,
            log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/agglo/",
            walltime="3:00:00"
        )
        cluster.adapt(minimum_jobs=1, maximum_jobs=32)
    else:
        cluster = LocalCluster(
            n_workers=min(os.cpu_count(), 16),
            threads_per_worker=1,
            local_directory=conf.path_root + "/dask-worker-space/"
        )

    with Client(cluster) as client:
        print("Dask Client Dashboard:", client.dashboard_link)

        start_time = time.time()
        futures = client.map(
            partial(
                compute_agglomeration_part,
                patched_seg=patched_seg,
                ijk_to_idx=ijk_to_idx,
                conf=conf,
                aff=aff
            ),
            data_chunked
        )

        fragment_agglomeration = {}
        # agglomerate all fragments from all chunks as they complete
        for future, frag_aggl in tqdm(as_completed(futures, with_results=True), total=len(data_chunked), smoothing=0):
            for k, v in frag_aggl.items():
                fragment_agglomeration.setdefault(k, set()).update(v)

        print(f"Computing fragment agglomeration in patches took {timedelta(seconds=time.time() - start_time)}")

    fragment_agglomeration_flattened = flatten_agglomeration(fragment_agglomeration, f"{conf.path_root}/agglo_pkl_chunks")

    return fragment_agglomeration_flattened
+ """ + + print("Computing fragment agglomeration...") + data_chunked = list( + enumerate(chunk_list(list(patch_to_coords.items()), chunk_size=8)) + ) + + if conf.use_slurm: + cluster = SLURMCluster( + cores=16, + memory="500GB", + processes=1, + log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/agglo/", + walltime="3:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) + else: + cluster = LocalCluster( + n_workers=min(os.cpu_count(), 16), + threads_per_worker=1, + local_directory=conf.path_root + "/dask-worker-space/" + ) + + with Client(cluster) as client: + print("Dask Client Dashboard:", client.dashboard_link) + + start_time = time.time() + futures = client.map( + partial( + compute_agglomeration_part, + patched_seg=patched_seg, + ijk_to_idx=ijk_to_idx, + conf=conf, + aff=aff + ), + data_chunked + ) + + fragment_agglomeration = {} + # agglomerate all fragments from all chunks as they complete + for future, frag_aggl in tqdm(as_completed(futures, with_results=True), total=len(data_chunked), smoothing=0): + for k, v in frag_aggl.items(): + fragment_agglomeration.setdefault(k, set()).update(v) + + print(f"Computing fragment agglomeration in patches took {timedelta(seconds=time.time() - start_time)}") + #cluster.close() + + fragment_agglomeration_flattened = flatten_agglomeration(fragment_agglomeration, f"{conf.path_root}/agglo_pkl_chunks") + + return fragment_agglomeration_flattened + + +def compute_agglomeration_part( + idx_samples, + patched_seg, + ijk_to_idx, + conf, + aff +): + """ + Merges neighboring voxels from different cubes, creates a graph with connected fragments. + Args: + idx_samples: idx of the chunk, chunked patch indices + + Returns: + fragment_agglomeration: Dictionary representing a graph between vertices (i, j, k, idx) + where (i, j, k) is a patch index, and idx a fragment id in this patch. + An edge means the fragments should be merged. 
+ """ + print("Entered compute_agglomeration_part", flush=True) + idx, samples = idx_samples + fragment_agglomeration = {} + for sample in samples: + print(f"{datetime.now()}: computing {sample}", flush=True) + (i, j, k), ((x_start, x_end), (y_start, y_end), (z_start, z_end)) = sample + + # for x,y,z get the last slice of the current cube (l, low) and the first slice of the next cube (h, high) + # they overlap, the voxels should have the same id + + if x_end is not None: + if conf.overlap > 0: + result_l = patched_seg[ijk_to_idx[i, j, k], -conf.overlap:] + result_h = patched_seg[ijk_to_idx[i + 1, j, k], :conf.overlap] + uniques = compute_uniques(conf.do_overlap_filter, result_h, result_l, min_overlap=conf.min_overlap) + else: + # merge according to short range affinities between each pair of IDs in neighboring cubes + cur_aff = ( + aff[0, x_end - 1: x_end, y_start:y_end, z_start:z_end] >= conf.merge_thr + ) + # todo: this simple thresholding can re-introduce catastrophic mergers. + # tune thr? use mean instead of max? better use overlap strategy? 
+ + result_l = patched_seg[ijk_to_idx[i, j, k], -1:][ + : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] + ][cur_aff] + result_h = patched_seg[ijk_to_idx[i + 1, j, k], :1][ + : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] + ][cur_aff] + combined = np.stack([result_l, result_h]).T + uniques = np.unique(combined, axis=0) + + for id_l, id_h in uniques: + if id_l > 0 and id_h > 0: + fragment_agglomeration.setdefault((i + 1, j, k, id_h), set()).add( + (i, j, k, id_l) + ) + fragment_agglomeration.setdefault((i, j, k, id_l), set()).add( + (i + 1, j, k, id_h) + ) + + if y_end is not None: + if conf.overlap > 0: + result_l = patched_seg[ijk_to_idx[i, j, k], :, -conf.overlap:] + result_h = patched_seg[ijk_to_idx[i, j + 1, k], :, :conf.overlap] + uniques = compute_uniques(conf.do_overlap_filter, result_h, result_l, min_overlap=conf.min_overlap) + else: + cur_aff = ( + aff[1, x_start:x_end, y_end - 1: y_end, z_start:z_end] >= conf.merge_thr + ) + result_l = patched_seg[ijk_to_idx[i, j, k], :, -1:][ + : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] + ][cur_aff] + result_h = patched_seg[ijk_to_idx[i, j + 1, k], :, :1][ + : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] + ][cur_aff] + combined = np.stack([result_l, result_h]).T + uniques = np.unique(combined, axis=0) + + for id_l, id_h in uniques: + if id_l > 0 and id_h > 0: + fragment_agglomeration.setdefault((i, j + 1, k, id_h), set()).add( + (i, j, k, id_l) + ) + fragment_agglomeration.setdefault((i, j, k, id_l), set()).add( + (i, j + 1, k, id_h) + ) + + if z_end is not None: + if conf.overlap > 0: + result_l = patched_seg[ijk_to_idx[i, j, k], :, :, -conf.overlap:] + result_h = patched_seg[ijk_to_idx[i, j, k + 1], :, :, :conf.overlap] + uniques = compute_uniques(conf.do_overlap_filter, result_h, result_l, min_overlap=conf.min_overlap) + else: + cur_aff = ( + aff[2, x_start:x_end, y_start:y_end, z_end - 1: z_end] >= conf.merge_thr + ) + result_l = patched_seg[ijk_to_idx[i, j, k], :, :, 
-1:][ + : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] + ][cur_aff] + result_h = patched_seg[ijk_to_idx[i, j, k + 1], :, :, :1][ + : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] + ][cur_aff] + combined = np.stack([result_l, result_h]).T + uniques = np.unique(combined, axis=0) + + for id_l, id_h in uniques: + if id_l > 0 and id_h > 0: + fragment_agglomeration.setdefault((i, j, k + 1, id_h), set()).add( + (i, j, k, id_l) + ) + fragment_agglomeration.setdefault((i, j, k, id_l), set()).add( + (i, j, k + 1, id_h) + ) + print(f"{datetime.now()}: done", flush=True) + return fragment_agglomeration + + +def compute_uniques(do_overlap_filter, result_h, result_l, min_overlap=0.9): + if do_overlap_filter: + result_l_ccs = cc3d.connected_components(result_l, connectivity=6) + result_h_ccs = cc3d.connected_components(result_h, connectivity=6) + + l_ccs_to_l = np.unique( + np.stack([result_l_ccs.flatten(), result_l.flatten()]), axis=1 + ) + l_ccs_to_l = {l_ccs: l for l_ccs, l in l_ccs_to_l.T} + + h_ccs_to_h = np.unique( + np.stack([result_h_ccs.flatten(), result_h.flatten()]), axis=1 + ) + h_ccs_to_h = {h_ccs: h for h_ccs, h in h_ccs_to_h.T} + + combined_ccs = np.stack([result_l_ccs.flatten(), result_h_ccs.flatten()]) + uniques_ccs, counts_ccs = np.unique(combined_ccs, axis=1, return_counts=True) + uniques_ccs = uniques_ccs.T + # uniques_ccs = exact_overlap_filter(uniques_ccs) + uniques_ccs = mutual_largest_overlap_filter( + counts_ccs, uniques_ccs, + min_overlap=min_overlap + ) + + uniques = [ + (l_ccs_to_l[l_ccs], h_ccs_to_h[h_ccs]) for l_ccs, h_ccs in uniques_ccs + ] + + else: + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques, counts = np.unique(combined, axis=0, return_counts=True) + return uniques + + +def exact_overlap_filter(uniques): + # keep only non-zero IDs with mutually exact corresponding count to reduce merge errors + + l_partners = {} + h_partners = {} + for id_l, id_h in uniques: + l_partners.setdefault(id_l, 
def mutual_largest_overlap_filter(
    counts,
    uniques,
    min_overlap=0.5,
    # 1.0: exact overlap, 0.0: any overlap, 0.5: at least half of total count
):
    """
    Keep only non-zero id pairs whose overlap is mutually the largest for both ids,
    to reduce merge errors.

    Args:
        counts: overlap voxel count per pair in uniques (same order).
        uniques: iterable of (id_l, id_h) pairs.
        min_overlap: required fraction of each id's total overlap count.

    Returns:
        np.ndarray of surviving (id_l, id_h) pairs.
    """
    # todo: Merge only perfect matches? i.e. except for 0 there are no other IDs in the uniques (ignore counts)

    highest_count_l = {}
    highest_count_h = {}

    total_count_l = {}
    total_count_h = {}

    for (id_l, id_h), count in zip(uniques, counts):
        total_count_l[id_l] = total_count_l.get(id_l, 0) + count
        total_count_h[id_h] = total_count_h.get(id_h, 0) + count

        # don't filter background here: if there is more overlap with background
        # than with another ID, the id should not be merged at all
        cur_highest_count_l, _ = highest_count_l.setdefault(id_l, (-1, -1))
        cur_highest_count_h, _ = highest_count_h.setdefault(id_h, (-1, -1))

        if count > cur_highest_count_l:
            highest_count_l[id_l] = (count, id_h)
        if count > cur_highest_count_h:
            highest_count_h[id_h] = (count, id_l)

    # a pair survives only if it is the mutually largest overlap for both ids,
    # covers at least min_overlap of each id's total, and is not background
    uniques = [
        (id_l, id_h)
        for id_l, (count, id_h) in highest_count_l.items()
        if highest_count_h[id_h][1] == id_l
        and count >= min_overlap * total_count_l[id_l]
        and count >= min_overlap * total_count_h[id_h]
        and id_l != 0
        and id_h != 0
        and count >= 2  # single voxel branches get split
    ]

    return np.array(uniques)


def flatten_agglomeration(fragment_agglomeration, output_dir):
    """
    Compute connected components in the fragment agglomeration graph and relabel
    the fragments with global ids starting from 1.

    The mapping is additionally pickled to output_dir in chunks of at most
    10M entries so workers can load it lazily.

    Args:
        fragment_agglomeration: dict with keys (i, j, k, id) and values a set of
            (i, j, k, id) vertices in other cubes that should be connected.
        output_dir: directory for the pickled mapping chunks (created if missing).

    Returns:
        dict with keys (i, j, k, id) and values the global component id.
    """
    cur_id = 1
    fragment_agglomeration_flattened = dict()
    fragment_agglomeration_final = dict()
    flattened_ids = set()
    chunk_n = 0
    os.makedirs(output_dir, exist_ok=True)
    for position_id in tqdm(fragment_agglomeration):  # (i, j, k, idx) = position_id
        if position_id not in flattened_ids:
            # traverse the connected component of position_id
            to_visit = {position_id}
            visited = set()
            while len(to_visit) > 0:
                current = to_visit.pop()
                if current not in visited:
                    visited.add(current)
                    for neighbor in fragment_agglomeration[current]:
                        to_visit.add(neighbor)
            for v in visited:
                assert v not in fragment_agglomeration_flattened
                fragment_agglomeration_flattened[v] = cur_id
                flattened_ids.add(v)
                if len(fragment_agglomeration_flattened) >= 10_000_000:
                    # spill the current mapping chunk to disk to bound memory
                    file_path = os.path.join(output_dir, f"chunk_{chunk_n:02}.pkl")
                    with open(file_path, "wb") as f:
                        pickle.dump(fragment_agglomeration_flattened, f)
                    print(f"Saved {len(fragment_agglomeration_flattened)} items to {file_path}")
                    fragment_agglomeration_final.update(fragment_agglomeration_flattened)
                    fragment_agglomeration_flattened = dict()
                    chunk_n += 1
            cur_id += 1

    if fragment_agglomeration_flattened:
        file_path = os.path.join(output_dir, f"chunk_{chunk_n:02}.pkl")
        with open(file_path, "wb") as f:
            pickle.dump(fragment_agglomeration_flattened, f)
        print(f"Saved final {len(fragment_agglomeration_flattened)} items to {file_path}")
        fragment_agglomeration_final.update(fragment_agglomeration_flattened)

    return fragment_agglomeration_final
def relabel_cube_batched_wrapped(kwargs):
    """Adapter so client.map can pass a single dict of keyword arguments."""
    return relabel_cube_batched(**kwargs)


def relabel_globally(
    fragment_agglomeration_flattened,
    patched_seg,
    aff,
    conf,
    patch_to_coords,
    ijk_to_idx,
):
    """
    Unite indexing within multiple cubes — if an object spans multiple cubes,
    it gets the same index everywhere.

    Returns:
        zarr array of shape aff.shape[1:] with globally unique uint64 labels.
    """
    print(f"Global relabeling...")
    print(f"Agglomerated segmentation: {conf.path_root}/agglomerated_seg.zarr")
    agglomerated_seg = zarr.zeros(
        (aff.shape[1:]),
        chunks=(conf.patch_size, conf.patch_size, conf.patch_size),
        dtype=np.uint64,  # cheap because of zarr compression
        store=f"{conf.path_root}/agglomerated_seg.zarr",
    )

    cubes = list(patch_to_coords.items())
    cubes_batched = chunk_list(cubes, 100)

    if not conf.use_parallelization:
        result = list(tqdm(map(
            partial(
                relabel_cube,
                patched_seg=patched_seg,
                fragment_agglomeration_flattened=fragment_agglomeration_flattened,
                ijk_to_idx=ijk_to_idx,
                agglomerated_seg=agglomerated_seg,
                conf=conf
            ),
            cubes), total=len(cubes)))
    else:
        if not conf.use_slurm:
            with ThreadPoolExecutor(max_workers=32) as executor:
                result = list(tqdm(executor.map(
                    partial(
                        relabel_cube,
                        patched_seg=patched_seg,
                        fragment_agglomeration_flattened=fragment_agglomeration_flattened,
                        ijk_to_idx=ijk_to_idx,
                        agglomerated_seg=agglomerated_seg,
                        conf=conf
                    ),
                    cubes,
                    chunksize=8
                ), total=len(cubes), smoothing=0))
        else:
            cluster = SLURMCluster(
                cores=32,
                memory="500GB",
                processes=1,
                worker_extra_args=["--resources", "processes=1"],
                log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/relabel/",
                walltime="24:00:00"
            )
            cluster.adapt(minimum_jobs=1, maximum_jobs=32)

            with Client(cluster) as client:
                print("Dask relabeling Client Dashboard:", client.dashboard_link)

                start_time = time.time()
                print(len(ijk_to_idx))
                # persist the mapping once; workers load it from disk instead of
                # shipping the (large) dict through the scheduler
                with open(f"{conf.path_root}/ijk_to_idx.pkl", "wb") as f:
                    pickle.dump(ijk_to_idx, f)
                configs = [
                    {
                        "cubes": cubes,
                        "patched_seg": patched_seg,
                        "fragment_agglomeration_chunks_path": f"{conf.path_root}/agglo_pkl_chunks",
                        "agglomerated_seg": agglomerated_seg,
                        "conf": conf,
                    }
                    for cubes in cubes_batched
                ]
                futures = client.map(
                    relabel_cube_batched_wrapped,
                    configs,
                    resources={'processes': 1},
                )
                for _ in tqdm(as_completed(futures), total=len(cubes_batched), smoothing=0):
                    pass  # tqdm progress bar
                print(f"Relabeling fragments took {timedelta(seconds=int(time.time() - start_time))}")

    return agglomerated_seg


def relabel_cube_batched(cubes, patched_seg, fragment_agglomeration_chunks_path, agglomerated_seg, conf):
    """
    Worker entry point: load the pickled ijk_to_idx map and agglomeration chunks
    from disk, then relabel a batch of cubes.
    """
    print(f"{datetime.now()}: start relabel_cube_batched", flush=True)
    with open(f"{conf.path_root}/ijk_to_idx.pkl", "rb") as f:
        ijk_to_idx = pickle.load(f)
    print(f"{datetime.now()}: ijk_to_idx loaded", flush=True)
    fragment_agglomeration_flattened = dict()
    chunk_files = os.listdir(fragment_agglomeration_chunks_path)
    for chunk_file in tqdm(chunk_files):
        with open(os.path.join(fragment_agglomeration_chunks_path, chunk_file), "rb") as f:
            chunk = pickle.load(f)
        fragment_agglomeration_flattened.update(chunk)
    print(f"{datetime.now()}: fragment_agglomeration_flattened loaded {len(fragment_agglomeration_flattened)}", flush=True)

    print(f"{datetime.now()}: Relabeling cube chunk", flush=True)
    for cube in cubes:
        relabel_cube(cube, patched_seg, fragment_agglomeration_flattened, ijk_to_idx, agglomerated_seg, conf)
    print(f"{datetime.now()}: End relabeling cube chunk", flush=True)


def relabel_cube(cube, patched_seg, fragment_agglomeration_flattened, ijk_to_idx, agglomerated_seg, conf):
    """
    If an object spans multiple cubes, relabel its indices to be the same everywhere.

    Objects present in the agglomeration mapping get their global id; objects local
    to this cube get a unique id built from the cube index (upper 32 bits) and the
    local id (lower 32 bits).
    """
    (i, j, k), ((x_start, x_end), (y_start, y_end), (z_start, z_end)) = cube

    cube = patched_seg[ijk_to_idx[i, j, k]]
    perm = [0]
    for idx in range(1, int(cube.max()) + 1):  # assuming cube has continuous indices from 0 to max
        if (i, j, k, idx) in fragment_agglomeration_flattened:  # object continued in neighboring cube -> already has a unique id
            perm.append(fragment_agglomeration_flattened[i, j, k, idx])
        else:  # object only in this cube
            # use upper 32 bits to indicate cube, lower 32 bits to indicate id
            perm.append((ijk_to_idx[i, j, k] + 1) * np.uint64(2 ** 32) + idx)
    perm = np.array(perm, dtype=np.uint64)

    # drop the overlap region before writing; fix: leftover debug prints removed
    relabeled = perm[cube[conf.overlap:, conf.overlap:, conf.overlap:]]
    # can't just use agglomerated_seg[x_start:x_end, ...] = relabeled
    # because chunks at the boundary can be smaller
    cur_shape = agglomerated_seg[x_start:x_end, y_start:y_end, z_start:z_end].shape
    # this is exactly 1 chunk (chunk-borders) -> no race conditions / overwriting
    agglomerated_seg[x_start:x_end, y_start:y_end, z_start:z_end] = relabeled[: cur_shape[0], : cur_shape[1],
                                                                              : cur_shape[2]]
def size_filter_relabel(seg, conf):
    """
    Filter out segments that are too small (fewer than conf.minsize voxels)
    and relabel the remaining segments contiguously from 1.

    Returns:
        zarr array with the filtered, contiguously relabeled segmentation.
    """
    start_time = time.time()

    block_indices = [(i, j, k) for i in range(seg.cdata_shape[0]) for j in range(seg.cdata_shape[1]) for k in
                     range(seg.cdata_shape[2])]
    block_indices_batched = chunk_list(block_indices, 16)
    combined_counter = Counter()

    print("Counting occurences of fragments...")
    if not conf.use_parallelization:
        result = list(tqdm(map(
            partial(batched_unique, seg=seg),
            block_indices_batched,
        ), total=len(block_indices_batched), smoothing=0))
        for counter in tqdm(result, total=len(block_indices_batched), smoothing=0):
            combined_counter.update(counter)
    else:
        if not conf.use_slurm:
            with ThreadPoolExecutor(max_workers=8) as executor:
                result = list(tqdm(executor.map(
                    partial(batched_unique, seg=seg),
                    block_indices_batched,
                ), total=len(block_indices_batched), smoothing=0))
            for counter in tqdm(result, total=len(block_indices_batched), smoothing=0):
                combined_counter.update(counter)
        else:
            cluster = SLURMCluster(
                cores=32,
                memory="800GB",
                log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/count/",
                processes=1,
                worker_extra_args=["--resources processes=1"],
                walltime="12:00:00"
            )
            cluster.adapt(minimum_jobs=1, maximum_jobs=32)
            with Client(cluster) as client:
                print("Dask counting Client Dashboard:", client.dashboard_link)
                futures = client.map(batched_unique, block_indices_batched, seg=seg, resources={'processes': 1})

                for future in (pbar := tqdm(as_completed(futures), total=len(block_indices_batched), smoothing=0)):
                    counter = future.result()
                    # NOTE(review): this path drops counts <= 10 before accumulating
                    # (presumably to bound driver memory) while the local paths keep
                    # all counts — assumes conf.minsize > 10; TODO confirm
                    filtered = {k: v for k, v in counter.items() if v > 10}
                    combined_counter.update(filtered)
                    del counter
                    del future
                    mem = psutil.Process().memory_full_info()
                    rss = mem.rss / 1e9
                    vms = mem.vms / 1e9
                    pbar.set_postfix(rss=f"{rss:.1f} GB", vms=f"{vms:.1f} GB")
                    gc.collect()

    # build the contiguous id mapping for all sufficiently large segments;
    # background (0) must survive the size filter and map to 0
    remaining_ids = {id for id, count in combined_counter.items() if count > conf.minsize}
    id_mapping_remaining = {old_id: new_id for new_id, old_id in enumerate(sorted(remaining_ids))}
    assert id_mapping_remaining[0] == 0
    with open(f"{conf.path_root}/id_mapping.csv", 'w') as f:
        for k, v in id_mapping_remaining.items():
            f.write(f"{k} {v}\n")

    print(f"Store: {conf.path_root}/relabeled_seg.zarr")
    relabeled_seg = zarr.zeros(
        seg.shape,
        chunks=seg.chunks,
        dtype=np.uint32,
        store=f"{conf.path_root}/relabeled_seg.zarr",
    )

    print("Filtering out small fragments and relabeling contiguously...")
    if not conf.use_parallelization:
        result = list(tqdm(map(partial(batched_relabel, seg=seg, relabeled_seg=relabeled_seg, conf=conf),
                               block_indices_batched)))
    else:
        if not conf.use_slurm:
            with ThreadPoolExecutor(max_workers=16) as executor:
                result = list(tqdm(executor.map(partial(batched_relabel, seg=seg, relabeled_seg=relabeled_seg, conf=conf),
                                                block_indices_batched)))
        else:
            cluster = SLURMCluster(
                cores=16,
                memory="500GB",
                log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/filter/",
            )
            cluster.adapt(minimum_jobs=1, maximum_jobs=32)
            with Client(cluster) as client:
                print("Dask relabeling Client Dashboard:", client.dashboard_link)
                futures = client.map(partial(batched_relabel, seg=seg, relabeled_seg=relabeled_seg, conf=conf),
                                     block_indices_batched)
                for _ in tqdm(as_completed(futures), total=len(block_indices_batched), smoothing=0):
                    pass
    print(f"Filtering small fragments and relabeling took {timedelta(seconds=int(time.time() - start_time))}")
    return relabeled_seg


def batched_unique(block_indices, seg):
    """Count the occurrences of every label in the given blocks of seg."""
    print(f"{datetime.now()}: start count", flush=True)
    c = Counter()
    for idx in block_indices:
        chunk_series = pd.Series(seg.blocks[idx].ravel())
        c.update(chunk_series.value_counts().to_dict())
    print(f"{datetime.now()}: end count", flush=True)
    return c


def batched_relabel(block_indices, seg, relabeled_seg, conf):
    """
    Relabel the given blocks of seg according to the id mapping on disk,
    masking out all ids that did not survive the size filter.
    """
    mapping = {}
    with open(f"{conf.path_root}/id_mapping.csv", 'r') as f:
        for line in f:
            key, value = line.split()
            mapping[int(key)] = int(value)
    keep_ids = list(mapping.keys())  # fix: hoisted out of the loop (invariant)
    for block_index in block_indices:
        block = seg.blocks[block_index]
        masked_block = fastremap.mask_except(block, keep_ids)
        relabeled_block = fastremap.remap(masked_block, mapping)
        relabeled_seg.blocks[block_index] = relabeled_block
    return None
key, value = line.split() + mapping[int(key)] = int(value) + for block_index in block_indices: + block = seg.blocks[block_index] + masked_block = fastremap.mask_except(block, list(mapping.keys())) + relabeled_block = fastremap.remap(masked_block, mapping) + relabeled_seg.blocks[block_index] = relabeled_block + return None + + +def main(conf): + print(conf) + aff = zarr.open(conf.aff_path, mode="r") + + if len(conf.thresholds) > 0: + thr = conf.thresholds[0] + + path_root = f"{conf.path_base}/{f'thr_{thr}'}/" + print(f"Root path: {path_root}") + os.makedirs(path_root, exist_ok=True) + + with open_dict(conf): + conf.path_root = path_root + conf.thr = thr + + start_time = time.time() + segmentation = patched_thresholding( + aff, + conf + ) + print(f"Patched thresholding took {timedelta(seconds=int(time.time() - start_time))}") + + elif len(conf.mws_biases_short) > 0: + biases = list(itertools.product(conf.mws_biases_short, conf.mws_biases_long)) + for (short, long) in biases: + print(f"SEGMENTATION FOR {short}, {long}") + path_root = f"{conf.path_base}/{f'mws_{short}_{long}'}/" + print(f"Root path: {path_root}") + os.makedirs(path_root, exist_ok=False) + with open_dict(conf): + conf.path_root = path_root + conf.mws_bias_short = short + conf.mws_bias_long = long + + start_time = time.time() + segmentation = patched_thresholding( + aff, + conf + ) + print(f"Patched thresholding took {timedelta(seconds=int(time.time() - start_time))}") + pass + return + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise configargparse.ArgumentTypeError('Boolean value expected.') + + +@hydra.main(config_path=".") +def main_wrapper(conf: DictConfig): + return main(conf) + +if __name__ == "__main__": + main_wrapper() \ No newline at end of file diff --git a/debug_progress.py b/debug_progress.py new file mode 100644 index 0000000..fa69dd1 
--- /dev/null +++ b/debug_progress.py @@ -0,0 +1,40 @@ +import dask +from distributed import LocalCluster, progress, as_completed +from tqdm import tqdm +from tqdm.dask import TqdmCallback +from dask import delayed, compute, persist +from dask.distributed import Client +import time + + +# Create some simulated tasks +def work(x): + time.sleep(1) + return x * x + +if __name__ == '__main__': + # Start a local distributed cluster + cluster = LocalCluster(n_workers=1, threads_per_worker=1) + client = Client(cluster) + + print("computing with persist") + tasks = [dask.delayed(work)(i) for i in range(20)] + x = persist(tasks) # start computation in the background + progress(x) + results1 = client.gather(x) + + print("computing with tqdm (doesn't work)") + tasks = [dask.delayed(work)(i) for i in range(20)] + with TqdmCallback(desc="Distributed compute", total=len(tasks), mininterval=0.5): + results = client.compute(tasks, sync=True) + + print("computing with as_completed") + tasks = [dask.delayed(work)(i) for i in range(20)] + futures = client.compute(tasks) + for future in tqdm( + as_completed(futures), + total=len(futures), + smoothing=0, + desc="Predicting chunks" + ): + pass diff --git a/debug_retrain.py b/debug_retrain.py new file mode 100644 index 0000000..a8f6fc3 --- /dev/null +++ b/debug_retrain.py @@ -0,0 +1,162 @@ +import os + +import numpy as np +import torch +from pytorch_lightning import Trainer +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm +import zarr + +from BANIS import BANIS, parse_args + +from data import comp_affinities, load_data + + +def train_model_with_samples(model_path, dataloader): + model = BANIS.load_from_checkpoint(model_path) + + trainer = Trainer( + max_steps=1000, + accelerator="gpu", + devices=1, + ) + trainer.fit(model, dataloader) + + +def compare_model_weights(path_a, path_b): + model_a = BANIS.load_from_checkpoint(path_a) + model_b = BANIS.load_from_checkpoint(path_b) + + state_a = model_a.state_dict() + 
state_b = model_b.state_dict() + + diffs = {} + total_diff = 0.0 + total_norm = 0.0 + total_cos_sim = 0.0 + biggest_diff = 0 + nonfinite_a = 0 + nonfinite_b = 0 + total_params = 0 + + n_layers = 0 + + for key in tqdm(state_a): + #print(f"getting {key}") + param_a = state_a[key].flatten() + param_b = state_b[key].flatten() + + diff = torch.abs(param_a - param_b).mean().item() + if diff > biggest_diff: + biggest_diff = diff + norm = torch.norm(param_a - param_b).item() + cos_sim = torch.nn.functional.cosine_similarity(param_a.unsqueeze(0), param_b.unsqueeze(0)).item() + + diffs[key] = { + 'mean_abs_diff': diff, + 'l2_norm': norm, + 'cosine_similarity': cos_sim, + } + + total = param_a.numel() + nonfinite_params_a = (~torch.isfinite(param_a)).sum().item() + nonfinite_params_b = (~torch.isfinite(param_b)).sum().item() + total_params += total + nonfinite_a += nonfinite_params_a + nonfinite_b += nonfinite_params_b + + total_diff += diff + total_norm += norm + total_cos_sim += cos_sim + n_layers += 1 + + avg_diff = total_diff / n_layers + avg_norm = total_norm / n_layers + avg_cos_sim = total_cos_sim / n_layers + + return { + 'avg_mean_abs_diff': avg_diff, + 'avg_l2_norm': avg_norm, + 'avg_cosine_similarity': avg_cos_sim, + 'biggest_diff': biggest_diff, + 'nonfinite_a': nonfinite_a, + 'nonfinite_b': nonfinite_b, + 'total_params': total_params, + #'layerwise': diffs + } + + #print(compare_model_weights( + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt", + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=115000.ckpt" + # )) + ## {'avg_mean_abs_diff': nan, 'avg_l2_norm': nan, 'avg_cosine_similarity': nan, 'biggest_diff': 0, 'nonfinite_a': 0, 'nonfinite_b': 62993031, 'total_params': 62993031} + #print(compare_model_weights( + # 
"/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=100000.ckpt", + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=105000.ckpt" + # )) + ## {'avg_mean_abs_diff': 0.017999131043465913, 'avg_l2_norm': 2.9829090611515583, 'avg_cosine_similarity': 0.9655602666255757, 'biggest_diff': 0.06504751741886139} + #print(compare_model_weights( + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=105000.ckpt", + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt" + # )) + ## {'avg_mean_abs_diff': 0.01891954702438203, 'avg_l2_norm': 3.084119017241339, 'avg_cosine_similarity': 0.9619980035223629, 'biggest_diff': 0.09889261424541473} + #print(compare_model_weights( + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=90000.ckpt", + # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt" + # )) + ## {'avg_mean_abs_diff': 0.03352864947696027, 'avg_l2_norm': 5.531613261509161, 'avg_cosine_similarity': 0.9207515456647084, 'biggest_diff': 0.1563815325498581} + + +def prepare_good_samples(): + args = parse_args() + train_data, val_data, n_channels = load_data(args) + return train_data + + +class SimpleDataset(Dataset): + def __init__(self, samples): + self.samples = samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + sample = self.samples[idx] + return { + "img": torch.from_numpy(sample["img"]), + "aff": torch.from_numpy(sample["aff"]), + "seg": torch.from_numpy(sample["seg"]), + } + +def prepare_bad_samples(): + samples = [] + runs_root = "/cajal/scratch/projects/misc/zuzur/ss3/" + for run in os.listdir(runs_root): + if 
run.startswith("debug1GPU-seed"): + for candidate in os.listdir(os.path.join(runs_root, run)): + if candidate.endswith("0_img.zarr"): + img = zarr.open(os.path.join(runs_root, run, candidate)) + seg_name = candidate.replace("img", "seg") + seg = zarr.open(os.path.join(runs_root, run, seg_name)) + aff, _ = comp_affinities(seg[:]) + data = { + "img": img.astype(np.float16), + "seg": seg, + "aff": aff, + } + samples.append(data) + if len(samples) >= 1000: + return SimpleDataset(samples) + return SimpleDataset(samples) + + +if __name__ == "__main__": + good_model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt" + bad_model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=115000.ckpt" + early_model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=50000.ckpt" + + bad_samples = DataLoader(prepare_bad_samples(), batch_size=1, num_workers=8, shuffle=True, drop_last=True) + good_samples = DataLoader(prepare_good_samples(), batch_size=1, num_workers=8, shuffle=True, drop_last=True) + + train_model_with_samples(good_model_path, good_samples) + train_model_with_samples(good_model_path, bad_samples) diff --git a/debug_test_inference.py b/debug_test_inference.py new file mode 100644 index 0000000..f271e10 --- /dev/null +++ b/debug_test_inference.py @@ -0,0 +1,63 @@ +import zarr + +from inference import measure_stats, predict_aff, full_inference, thresholding + + +def test_local_prediction(): + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + + model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + from BANIS import BANIS + + model = BANIS.load_from_checkpoint(model_path) + + all_stats = 
{} + + for chunk_cube_size in [200, 400, 512, 750, 1024, 1500, 3000]: + measured_predict_aff = measure_stats(predict_aff) + + result, stats = measured_predict_aff(img_data, model, chunk_cube_size=chunk_cube_size, compute_backend="local", + zarr_path=f"/cajal/scratch/projects/misc/zuzur/test{chunk_cube_size}.zarr", do_overlap=True, + prediction_channels=3, divide=255, small_size=model.hparams.small_size) + + all_stats[chunk_cube_size] = stats + print(f"chunk size {chunk_cube_size}: {stats}") + + print(all_stats) + for (value, stat) in all_stats.items(): + print(f"{value}: {stat}") + + +def test_slurm_prediction(): + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + + model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + from BANIS import BANIS + model = BANIS.load_from_checkpoint(model_path) + + measured_predict_aff = measure_stats(predict_aff) + # only one run - runtime dependent on number of available slurm nodes + result, stats = measured_predict_aff(img_data, model_path=model_path, chunk_cube_size=512, compute_backend="slurm", + zarr_path=f"/cajal/scratch/projects/misc/zuzur/test_slurm.zarr", do_overlap=True, + prediction_channels=3, divide=255, small_size=model.hparams.small_size) + + print(stats) + +def test_full_inference(): + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + + model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + from BANIS import BANIS + model = BANIS.load_from_checkpoint(model_path) + + full_inference(img_data, model_path, thr=0.7685) + +def test_thresholding(): + aff = zarr.open("/cajal/scratch/projects/misc/zuzur/skeleton_recall/rerun_base/dsbase_s0_a0_25-03-20_19-44-08-067760/pred_aff_val_6.zarr/") + 
thresholding(aff, 0.7685, "test1.zarr", 300, "local") + +if __name__ == "__main__": + test_thresholding() diff --git a/inference.py b/inference.py index 4ffecbf..7a5360e 100644 --- a/inference.py +++ b/inference.py @@ -1,7 +1,9 @@ import shutil from collections import defaultdict +from copy import deepcopy from typing import Union, List, Tuple +import cc3d import numba import numpy as np import torch @@ -15,10 +17,12 @@ from distributed import progress from filelock import FileLock from numba import jit +from numpy.f2py.crackfortran import updatevars from scipy.ndimage import distance_transform_cdt from torch import autocast from torch.nn.functional import sigmoid from tqdm import tqdm +import mwatershed def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: @@ -162,7 +166,7 @@ def predict_aff( f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") print(f"Parameters: cube size {chunk_cube_size}, compute backend {compute_backend}.") - all_patch_coordinates = get_coordinates(img.shape[:3], small_size, do_overlap) + all_patch_coordinates = get_coordinates(img.shape[:3], small_size, overlap = small_size // 2 if do_overlap else 0, last_has_smaller_overlap=True) chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) z = zarr.open_group(zarr_path + "_tmp", mode='w') @@ -219,60 +223,32 @@ def predict_aff( return zarr.open(zarr_path, mode="r") -def get_coordinates( - shape: Tuple[int, int, int], small_size: int, do_overlap: bool -) -> List[Tuple[int, int, int]]: +def get_coordinates(shape: Tuple[int, int, int], small_size: int, overlap: int = 0, last_has_smaller_overlap: bool = True) -> List[Tuple[int, int, int]]: """ - Get coordinates for cubes to be predicted. - + Get coordinates for smaller patches to process a big cube in memory. Args: - shape: The shape of the input image (x, y, z). + shape: The shape of the input (x, y, z). small_size: The size of the patches. 
- do_overlap: Whether to perform overlapping predictions. - + overlap: The overlap between patches. The default 0 means no overlap (next patch starts on the next pixel from the previous patch). For half-cube overlap set overlap=small_size//2, for 1-pixel overlap set overlap=1. + last_has_smaller_overlap: If the last patch with the specified size and overlap would exceed the big cube, move the patch so that it ends with the big cube, creating a bigger overlap in this patch. Returns: - List of (x, y, z) coordinates for prediction cubes. + List of (x, y, z) coordinates (starting voxel of a patch) for processing of smaller patches. """ - offsets = [get_offsets(s, small_size) for s in shape] + if overlap < 0 or overlap >= small_size: + raise ValueError(f"Overlap must be between 0 and {small_size}.") + offsets = [get_offsets(s, small_size, small_size-overlap, last_has_smaller_overlap) for s in shape] xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] - if do_overlap: # Add shifted cubes (half cube overlap) - offset = small_size // 2 - - xyzs_shifted = [ - set((x + offset, y, z) for x, y, z in xyzs), - set((x, y + offset, z) for x, y, z in xyzs), - set((x, y, z + offset) for x, y, z in xyzs), - set((x + offset, y + offset, z) for x, y, z in xyzs), - set((x + offset, y, z + offset) for x, y, z in xyzs), - set((x, y + offset, z + offset) for x, y, z in xyzs), - set((x + offset, y + offset, z + offset) for x, y, z in xyzs), - ] - xyzs_shifted = set( - (x, y, z) - for s in xyzs_shifted - for x, y, z in s - if x + small_size <= shape[0] - and y + small_size <= shape[1] - and z + small_size <= shape[2] - ) - xyzs = list(set.union(set(xyzs), xyzs_shifted)) return xyzs -def get_offsets(big_size: int, small_size: int) -> List[int]: - """ - Calculate offsets for image patching. - - Args: - big_size: The size of the whole image. - small_size: The size of the patches. - - Returns: - List of offsets. 
- """ - offsets = list(range(0, big_size - small_size + 1, small_size)) - if offsets[-1] != big_size - small_size: +def get_offsets(big_size, small_size, step, last_has_smaller_overlap): + offsets = list(range(0, big_size - small_size + 1, step)) + if small_size > big_size: + offsets.append(0) + elif offsets[-1] != big_size - small_size and last_has_smaller_overlap: offsets.append(big_size - small_size) + elif offsets[-1] != big_size - small_size and not last_has_smaller_overlap: + offsets.append(len(offsets) * step) return offsets @@ -376,21 +352,205 @@ def predict_aff_patches_chunked(patch_coordinates, img, model_path, zarr_path, s ] += pred_tmp +def update_fragment_agglomeration(fragment_agglomeration, matching_l, matching_h, chunk_l, chunk_h): + combined = np.stack([matching_l, matching_h]).T + uniques = np.unique(combined, axis=0) + for id_l, id_h in uniques: + if id_l > 0 and id_h > 0: + fragment_agglomeration.setdefault((chunk_h, id_h), set()).add((chunk_l, id_l)) + fragment_agglomeration.setdefault((chunk_l, id_l), set()).add((chunk_h, id_h)) + return fragment_agglomeration + + +def flatten_agglomeration(fragment_agglomeration): + """ + Computes connected components in the fragment agglomeration graph, and assigns the fragments new ids starting from 1. 
+ Args: + fragment_agglomeration: dictionary with keys (chunk_id, fragment_id), and values a set of (chunk_id, fragment_id) in another chunk (cube) that should be connected + Returns: + fragment_agglomeration_flattened: dictionary with keys (chunk_id, fragment_id) and values the global component index + """ + cur_id = 1 + fragment_agglomeration_flattened = dict() + for position_id in tqdm(fragment_agglomeration): # (chunk, idx) = position_id + if position_id not in fragment_agglomeration_flattened: + to_visit = {position_id} + visited = set() + while len(to_visit) > 0: + current = to_visit.pop() + if current not in visited: + visited.add(current) + for neighbor in fragment_agglomeration[current]: + to_visit.add(neighbor) + for v in visited: + assert v not in fragment_agglomeration_flattened + fragment_agglomeration_flattened[v] = cur_id + cur_id += 1 + + return cur_id, fragment_agglomeration_flattened + + +def add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, cur_id, chunks, zarr_path): + z = zarr.open(f"{zarr_path}_tmp/instances_patched") + for i, chunk in enumerate(tqdm(chunks)): + data = z[i, :, :, :] + for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max + if (i, idx) not in fragment_agglomeration_flattened: + fragment_agglomeration_flattened[(i, idx)] = cur_id + cur_id += 1 + return fragment_agglomeration_flattened + + +def thresholding(aff, thr, zarr_path, chunk_cube_size, compute_backend): + chunks = get_coordinates(aff.shape[1:], chunk_cube_size, overlap=0, last_has_smaller_overlap=False) + reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)} + + z_root = zarr.open_group(zarr_path + "_tmp", mode='w') + zarr_chunk_size = min(chunk_cube_size, 512) + z_root.create_dataset('instances_patched', shape=(len(chunks), chunk_cube_size, chunk_cube_size, chunk_cube_size), + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='i4') + + # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN 
MEMORY + if compute_backend == "local": + for i, chunk in enumerate(tqdm(chunks)): + x, y, z = chunk + x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) + curr_aff = aff[:3, x : x_end, y : y_end, z : z_end] + curr_seg = compute_connected_component_segmentation(curr_aff > thr) + z_root["instances_patched"][i, : x_end - x, : y_end - y, : z_end - z] = curr_seg + else: + raise NotImplementedError(f"Compute backend {compute_backend} not implemented.") + + # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS + if compute_backend == "local": + fragment_agglomeration = {} + for i, chunk in enumerate(tqdm(chunks)): + x, y, z = chunk + x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) + + # merge according to short range affinities between each pair of IDs in neighboring cubes + if x_end < aff.shape[1]: + chunk_h = reverse_chunks[x + chunk_cube_size, y, z] + border_aff = aff[0, x_end - 1 : x_end, y : y_end, z : z_end] >= thr + matching_ids_l = z_root["instances_patched"][i, -1:, :, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] + matching_ids_h = z_root["instances_patched"][chunk_h, -1:, :, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] + fragment_agglomeration = update_fragment_agglomeration(fragment_agglomeration, matching_ids_l, matching_ids_h, i, chunk_h) + + if y_end < aff.shape[2]: + chunk_h = reverse_chunks[x, y + chunk_cube_size, z] + border_aff = aff[0, x : x_end, y_end - 1 : y_end, z : z_end] >= thr + matching_ids_l = z_root["instances_patched"][i, :, -1:, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] + matching_ids_h = z_root["instances_patched"][chunk_h, :, -1:, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] + fragment_agglomeration = 
update_fragment_agglomeration(fragment_agglomeration, matching_ids_l, matching_ids_h, i, chunk_h) + + if z_end < aff.shape[3]: + chunk_h = reverse_chunks[x, y, z + chunk_cube_size] + border_aff = aff[0, x : x_end, y : y_end, z_end - 1 : z_end] >= thr + matching_ids_l = z_root["instances_patched"][i, :, :, -1:][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] + matching_ids_h = z_root["instances_patched"][chunk_h, :, :, -1:][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] + fragment_agglomeration = update_fragment_agglomeration(fragment_agglomeration, matching_ids_l, matching_ids_h, i, chunk_h) + + curr_id, fragment_agglomeration_flattened = flatten_agglomeration(fragment_agglomeration) + print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) + fragment_agglomeration_flattened = add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, zarr_path) + print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) + + else: + raise NotImplementedError(f"Compute backend {compute_backend} not implemented.") + + # MERGE AND RELABEL INSTANCES GLOBALLY + z_final = zarr.create(shape=aff.shape[1:], + chunks=(zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='i4', + store=zarr_path, overwrite=True) + + if compute_backend == "local": + for i, chunk in enumerate(tqdm(chunks)): + x, y, z = chunk + x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) + data = z_root["instances_patched"][i, :x_end, :y_end, :z_end] + perm = [0] + for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max + assert (i, idx) in fragment_agglomeration_flattened # all fragments have a new index (congiguous from 0) + perm.append(fragment_agglomeration_flattened[(i, idx)]) + perm = np.array(perm, dtype=np.uint64) + 
relabeled = perm[data] + z_final[x : x_end, y : y_end, z : z_end] = relabeled + + else: + raise NotImplementedError(f"Compute backend {compute_backend} not implemented.") + + shutil.rmtree(zarr_path + "_tmp") + + +def compute_mws_segmentation(cur_aff, mws_bias_short, mws_bias_long, long_range=10): + """ + Mutex Watershed segmentation. + Args: + cur_aff: An affinity array with 3 short-range and 3 long-range affinities (size must fit in memory). + mws_bias_short: Short-range bias + mws_bias_long: Long-range bias + Returns: + Segmentation of the affinities. + """ + cur_aff = deepcopy(cur_aff).astype(np.float64) + cur_aff[:3] += mws_bias_short + cur_aff[3:] += mws_bias_long + + cur_aff[:3] = np.clip(cur_aff[:3], 0, 1) # short-range attractive edges + cur_aff[3:] = np.clip(cur_aff[3:], -1, 0) # long-range repulsive edges (see the Mutex Watershed paper) + + mws_pred = mwatershed.agglom( + affinities=cur_aff, + offsets=( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [long_range, 0, 0], + [0, long_range, 0], + [0, 0, long_range], + ] + ), + ) + + # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again + # size filter. 
single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube, causing OOM when creating fragment_agglomeration + dusted = cc3d.dust( # does a cc first (reducing false mergers in add_to_agglomeration) + mws_pred, + threshold=2, + connectivity=6, + in_place=False, + ) + # relabeling to save IDs + pred_relabeled, N = cc3d.connected_components( + dusted, return_N=True, connectivity=6 + ) + + assert (pred_relabeled[mws_pred == 0] == 0).all() # 0 stays 0 + assert N <= np.iinfo(np.uint32).max + + pred = pred_relabeled.astype(np.uint32) + return pred + + def full_inference( + # RESOURCES ARGUMENTS: + chunk_cube_size: int = 1024, + compute_backend: str = "local", # AFFINITY PREDICTION ARGUMENTS: - img: Union[np.ndarray, zarr.Array], - model_path: str, + img: Union[np.ndarray, zarr.Array] = None, + model_path: str = None, aff_zarr_path: str = "aff_prediction.zarr", small_size: int = 128, do_overlap: bool = True, prediction_channels: int = 6, divide: int = 1, - chunk_cube_size: int = 1024, - compute_backend: str = "local", # POSTPROCESSING ARGUMENTS: postprocessing_type: str = "thresholding", + seg_zarr_path: str = "seg_prediction.zarr", thr: float = 0.5, - seg_zarr_path: str = "seg_prediction.zarr" + mws_bias_short: float = -0.5, + mws_bias_long: float = -0.5, ): aff = predict_aff( @@ -409,7 +569,8 @@ def full_inference( seg = compute_connected_component_segmentation(aff[:3] > thr) zarr.array(seg, store=seg_zarr_path) elif postprocessing_type == "mws": - raise NotImplementedError(f"Mutex Watershed is not implemented") + seg = mws(aff, seg_zarr_path, mws_bias_short, mws_bias_long) else: raise NotImplementedError(f"Postprocessing type {postprocessing_type} is not implemented") + print(f"Segmentation saved at {seg_zarr_path}.") From 4ced80f8e5160cab7726bbed80f15d49d76e9213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zuzana=20Urbanov=C3=A1?= Date: Fri, 29 Aug 2025 18:28:26 +0200 Subject: [PATCH 25/33] class-based inference --- debug_test_inference.py | 5 +- 
inference2.py | 606 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 610 insertions(+), 1 deletion(-) create mode 100644 inference2.py diff --git a/debug_test_inference.py b/debug_test_inference.py index f271e10..6488285 100644 --- a/debug_test_inference.py +++ b/debug_test_inference.py @@ -1,6 +1,7 @@ import zarr from inference import measure_stats, predict_aff, full_inference, thresholding +from inference2 import Thresholding def test_local_prediction(): @@ -57,7 +58,9 @@ def test_full_inference(): def test_thresholding(): aff = zarr.open("/cajal/scratch/projects/misc/zuzur/skeleton_recall/rerun_base/dsbase_s0_a0_25-03-20_19-44-08-067760/pred_aff_val_6.zarr/") - thresholding(aff, 0.7685, "test1.zarr", 300, "local") + postprocessor = Thresholding(1024, "local", 0.7685) + postprocessor.aff_to_seg(aff, zarr_path="test2.zarr") + #thresholding(aff, 0.7685, "test2.zarr", 1024, "local") if __name__ == "__main__": test_thresholding() diff --git a/inference2.py b/inference2.py new file mode 100644 index 0000000..16588b1 --- /dev/null +++ b/inference2.py @@ -0,0 +1,606 @@ +import shutil +from collections import defaultdict +from copy import deepcopy +from typing import Union, List, Tuple + +import cc3d +import numba +import numpy as np +import torch +import torch.utils +import zarr +import dask +from dask import compute, persist, delayed +from dask.distributed import Client, LocalCluster +from dask.diagnostics import ProgressBar +import dask.array as da +from distributed import progress +from filelock import FileLock +from numba import jit +from numpy.f2py.crackfortran import updatevars +from scipy.ndimage import distance_transform_cdt +from torch import autocast +from torch.nn.functional import sigmoid +from tqdm import tqdm +import mwatershed + + +class Utils: + @staticmethod + def get_coordinates(shape: Tuple[int, int, int], small_size: int, overlap: int = 0, + last_has_smaller_overlap: bool = True) -> List[Tuple[int, int, int]]: + """ + Get coordinates for 
smaller patches to process a big cube in memory. + Args: + shape: The shape of the input (x, y, z). + small_size: The size of the patches. + overlap: The overlap between patches. The default 0 means no overlap (next patch starts on the next pixel from the previous patch). For half-cube overlap set overlap=small_size//2, for 1-pixel overlap set overlap=1. + last_has_smaller_overlap: If the last patch with the specified size and overlap would exceed the big cube, move the patch so that it ends with the big cube, creating a bigger overlap in this patch. + Returns: + List of (x, y, z) coordinates (starting voxel of a patch) for processing of smaller patches. + """ + if overlap < 0 or overlap >= small_size: + raise ValueError(f"Overlap must be between 0 and {small_size}.") + offsets = [Utils.get_offsets(s, small_size, small_size - overlap, last_has_smaller_overlap) for s in shape] + xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] + return xyzs + + @staticmethod + def get_offsets(big_size, small_size, step, last_has_smaller_overlap): + offsets = list(range(0, big_size - small_size + 1, step)) + if small_size > big_size: + offsets.append(0) + elif offsets[-1] != big_size - small_size and last_has_smaller_overlap: + offsets.append(big_size - small_size) + elif offsets[-1] != big_size - small_size and not last_has_smaller_overlap: + offsets.append(len(offsets) * step) + return offsets + + @staticmethod + def chunk_xyzs(xyzs, chunk_cube_size=1024): + """ + Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. 
class AffinityPredictor:
    """Patch-wise affinity inference over a volume too large for one forward pass.

    The volume is covered with model-sized patches (optionally overlapping);
    patches are grouped into chunks that fit in memory, predicted, and the
    (weighted) predictions are accumulated into temporary ``sum_pred`` /
    ``sum_weight`` zarr datasets, which are normalised into the final
    affinity zarr at the end.
    """

    def __init__(self,
                 chunk_cube_size: int = 1024,
                 compute_backend: str = "local",
                 model: torch.nn.Module = None,
                 model_path: str = None,
                 small_size: int = 128,
                 do_overlap: bool = True,
                 prediction_channels: int = 6,
                 divide: int = 1,
                 ):
        self.chunk_cube_size = chunk_cube_size
        self.compute_backend = compute_backend

        self.model = model  # only for local prediction
        self.model_path = model_path  # distributed workers load the checkpoint themselves (a live model is not pickleable)
        self.small_size = small_size  # side length of one model input patch
        self.do_overlap = do_overlap  # overlap patches by half and blend with a center-weighted mask
        self.prediction_channels = prediction_channels
        self.divide = divide  # intensity normalisation divisor (e.g. 255 for uint8 images)

    def img_to_aff(self, img, zarr_path):
        """Predict affinities for ``img`` with the configured model and write them to ``zarr_path``."""
        print(f"Performing patched inference with do_overlap={self.do_overlap} for img of shape {img.shape} and dtype {img.dtype}")
        print(f"Parameters: cube size {self.chunk_cube_size}, compute backend {self.compute_backend}.")

        overlap = self.small_size // 2 if self.do_overlap else 0
        all_patch_coordinates = Utils.get_coordinates(img.shape[:3], self.small_size, overlap=overlap, last_has_smaller_overlap=True)
        chunked_patch_coordinates = Utils.chunk_xyzs(all_patch_coordinates, self.chunk_cube_size)

        tmp_path = "tmp_" + zarr_path
        tmp_group = zarr.open_group(tmp_path, mode='w')
        zarr_chunk_size = min(self.chunk_cube_size, 512)
        tmp_group.create_dataset('sum_pred', shape=(self.prediction_channels, *img.shape[:3]),
                                 chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4')
        tmp_group.create_dataset('sum_weight', shape=(1, *img.shape[:3]),
                                 chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4')

        if self.compute_backend == "local":
            for chunk in tqdm(chunked_patch_coordinates):
                self.predict_aff_patches_chunked(chunk, img, tmp_path)
                torch.cuda.empty_cache()
        else:
            if self.compute_backend == "local_cluster":
                from dask_cuda import LocalCUDACluster
                cluster = LocalCUDACluster(threads_per_worker=1)  # 1 worker per GPU
            elif self.compute_backend == "slurm":
                from dask_jobqueue import SLURMCluster
                cluster = SLURMCluster(
                    cores=8,
                    memory="400GB",
                    processes=1,
                    worker_extra_args=["--resources processes=1", "--nthreads=1"],
                    job_extra_directives=["--gres=gpu:1"],
                    walltime="1-00:00:00"
                )
                cluster.adapt(minimum_jobs=1, maximum_jobs=32)
            else:
                raise NotImplementedError(f"Compute backend {self.compute_backend} not available.")

            client = Client(cluster)
            print(f"Waiting for workers...")
            client.wait_for_workers(n_workers=1)
            print("Dask Client Dashboard:", client.dashboard_link)
            tasks = [dask.delayed(self.predict_aff_patches_chunked)(chunk, img, tmp_path)
                     for chunk in chunked_patch_coordinates]
            futures = persist(tasks)
            progress(futures)  # progress bar
            compute(futures)

        # normalise the accumulated predictions by the accumulated blending weights
        sum_pred = da.from_zarr(f"{tmp_path}/sum_pred")
        sum_weight = da.from_zarr(f"{tmp_path}/sum_weight")
        (sum_pred / sum_weight).to_zarr(zarr_path, overwrite=True)

        shutil.rmtree(tmp_path)

        return

    def predict_aff_patches_chunked(self, patch_coordinates, img, zarr_path):
        """
        Predict all patches of one chunk in memory and accumulate the results
        into the shared ``sum_pred`` / ``sum_weight`` zarr datasets.

        Args:
            patch_coordinates: patch start coordinates; their bounding box
                (extended by one patch size) must fit in memory, so use an
                adequate chunk size.
            img: full input volume (sliced lazily).
            zarr_path: path of the temporary accumulation group.
        """
        xs, ys, zs = zip(*patch_coordinates)
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)
        min_z, max_z = min(zs), max(zs)

        img_tmp = img[
            min_x: max_x + self.small_size,
            min_y: max_y + self.small_size,
            min_z: max_z + self.small_size,
        ]
        chunk_shape = (img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2])
        pred_tmp = np.zeros((self.prediction_channels, *chunk_shape), dtype=np.float32)
        weight_tmp = np.zeros((1, *chunk_shape), dtype=np.float32)
        single_pred_weight = self.get_single_pred_weight(self.do_overlap, self.small_size)

        if not self.model:
            # distributed worker: load the checkpoint locally
            from BANIS import BANIS
            print(self.model_path, flush=True)
            model = BANIS.load_from_checkpoint(self.model_path)
        else:
            model = self.model

        for x_global, y_global, z_global in patch_coordinates:
            x = x_global - min_x
            y = y_global - min_y
            z = z_global - min_z
            # assumes img carries a trailing channel axis that moveaxis brings to the front — TODO confirm
            patch = img_tmp[x: x + self.small_size, y: y + self.small_size, z: z + self.small_size]
            img_patch = torch.tensor(np.moveaxis(patch, -1, 0)[None]).to(model.device) / self.divide
            pred = Utils.scale_sigmoid(model(img_patch))[0, :self.prediction_channels]

            weight_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += \
                single_pred_weight if self.do_overlap else 1
            pred_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += \
                pred.detach().cpu().numpy() * (single_pred_weight[None] if self.do_overlap else 1)

        out_group = zarr.open_group(zarr_path, mode='a')
        weight_mask = out_group['sum_weight']
        full_pred = out_group['sum_pred']

        # file locks serialise concurrent `+=` updates from multiple workers
        with FileLock(f"{zarr_path}/sum_weight.lock"):
            weight_mask[
                :,
                min_x: max_x + self.small_size,
                min_y: max_y + self.small_size,
                min_z: max_z + self.small_size,
            ] += weight_tmp

        with FileLock(f"{zarr_path}/sum_pred.lock"):
            full_pred[
                :,
                min_x: max_x + self.small_size,
                min_y: max_y + self.small_size,
                min_z: max_z + self.small_size,
            ] += pred_tmp

    def get_single_pred_weight(self, do_overlap: bool, small_size: int) -> Union[np.ndarray, None]:
        """
        Get the blending weight for a single prediction.

        Args:
            do_overlap: Whether to perform overlapping predictions.
            small_size: The size of the patches.

        Returns:
            The weight array for a single prediction (low at the surface of the
            predicted cube, high in the center), or None if no overlap.
        """
        if do_overlap:
            # The weight (confidence/expected quality) of the predictions:
            # Low at the surface of the predicted cube, high in the center
            pred_weight_helper = np.pad(np.ones((small_size,) * 3), 1, mode='constant')
            return distance_transform_cdt(pred_weight_helper).astype(np.float32)[1:-1, 1:-1, 1:-1]
        else:
            return None
class Postprocessing:
    """Base class: turn an affinity volume of shape (channels, x, y, z) into an
    instance segmentation, processing one chunk at a time so arbitrarily large
    volumes fit in memory.

    Subclasses override ``segment_chunk`` with the in-memory segmentation
    algorithm. Chunks are generated with an overlap of one voxel; the shared
    boundary slice is used to detect fragments that span two chunks and must
    receive the same global id.
    """

    def __init__(self,
                 chunk_cube_size: int = 1024,
                 compute_backend: str = "local"
                 ):
        # chunk_cube_size: side length of the cubes processed in memory
        # compute_backend: "local", "local_cluster" (dask-cuda) or "slurm" (dask-jobqueue)
        self.chunk_cube_size = chunk_cube_size
        self.compute_backend = compute_backend

    def aff_to_seg(self, aff, zarr_path):
        """Full affinity -> segmentation pipeline; result is written to ``zarr_path``."""
        # chunks overlap by 1 voxel so neighboring chunks share a boundary slice
        chunks = Utils.get_coordinates(aff.shape[1:], self.chunk_cube_size, overlap=1, last_has_smaller_overlap=False)
        reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)}

        zarr_chunk_size = min(self.chunk_cube_size, 512)
        z_root = zarr.create(shape=(len(chunks), self.chunk_cube_size, self.chunk_cube_size, self.chunk_cube_size),
                             store="tmp_" + zarr_path, dtype='i4', overwrite=True,
                             chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size))

        # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN MEMORY
        self.patched_segment_affinities(aff, "tmp_" + zarr_path, chunks)

        # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS
        # BUGFIX: the per-chunk fragments live in the temporary store, so the
        # agglomeration must read "tmp_" + zarr_path; the final store does not
        # exist yet at this point.
        fragment_agglomeration = self.agglomerate_fragments(chunks, reverse_chunks, "tmp_" + zarr_path, aff.shape)

        # MERGE AND RELABEL INSTANCES GLOBALLY
        self.merge_and_relabel(fragment_agglomeration, "tmp_" + zarr_path, zarr_path, chunks, aff.shape)

        return

    def patched_segment_affinities(self, aff, zarr_path, chunks):
        """Segment every chunk independently into the patched (per-chunk) zarr array."""
        if self.compute_backend == "local":
            for i, chunk in enumerate(tqdm(chunks)):
                self.segment_chunk_wrapped(chunk, i, aff, zarr_path)
        else:
            if self.compute_backend == "local_cluster":
                from dask_cuda import LocalCUDACluster
                cluster = LocalCUDACluster(threads_per_worker=1)  # 1 worker per GPU
            elif self.compute_backend == "slurm":
                from dask_jobqueue import SLURMCluster
                cluster = SLURMCluster(
                    cores=8,
                    memory="400GB",
                    processes=1,
                    worker_extra_args=["--resources processes=1", "--nthreads=1"],
                    job_extra_directives=["--gres=gpu:1"],
                    walltime="1-00:00:00"
                )
                cluster.adapt(minimum_jobs=1, maximum_jobs=32)
            else:
                raise NotImplementedError(f"Compute backend {self.compute_backend} not available.")

            client = Client(cluster)
            print(f"Waiting for workers...")
            client.wait_for_workers(n_workers=1)
            print("Dask Client Dashboard:", client.dashboard_link)
            tasks = [dask.delayed(self.segment_chunk_wrapped)(chunk, i, aff, zarr_path) for (i, chunk) in enumerate(chunks)]
            futures = persist(tasks)
            progress(futures)  # progress bar
            compute(futures)

    def get_xyz_end(self, chunk, aff_shape):
        """
        Returns the end indices of a chunk, that correspond either to the chunk size, or align with the size of the affinities.
        """
        x, y, z = chunk
        x_end, y_end, z_end = (min(x + self.chunk_cube_size, aff_shape[1]),
                               min(y + self.chunk_cube_size, aff_shape[2]),
                               min(z + self.chunk_cube_size, aff_shape[3]))
        return (x_end, y_end, z_end)

    def agglomerate_fragments(self, chunks, reverse_chunks, zarr_path, aff_shape):
        """Build and flatten the cross-chunk fragment merge graph.

        Args:
            zarr_path: path of the *patched* (temporary) segmentation store.
        Returns:
            dict mapping (chunk_id, fragment_id) -> global instance id.
        """
        if self.compute_backend == "local":
            fragment_agglomeration = {}
            for i, chunk in enumerate(tqdm(chunks)):
                chunk_agglomeration = self.agglomerate_chunk(chunk, reverse_chunks, zarr_path, aff_shape)
                fragment_agglomeration.update(chunk_agglomeration)
                if len(fragment_agglomeration) > 10_000_000:
                    print("WARNING: fragment agglomeration too long, might cause problems!")
                    # TODO: solve this

            curr_id, fragment_agglomeration_flattened = self.flatten_agglomeration(fragment_agglomeration)
            print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened))
            fragment_agglomeration_flattened = self.add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, zarr_path)
            print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened))

        else:
            # TODO: add slurm (and measure memory)
            raise NotImplementedError(f"Compute backend {self.compute_backend} not available.")

        return fragment_agglomeration_flattened

    def agglomerate_chunk(self, chunk, reverse_chunks, zarr_path, aff_shape):
        """Match fragment ids across the 1-voxel overlap with the next chunk in x, y and z."""
        fragment_agglomeration = {}
        x, y, z = chunk
        x_end, y_end, z_end = self.get_xyz_end(chunk, aff_shape)
        z_root = zarr.open(zarr_path, mode='r')

        # for (x,y,z) get the last slice of the current cube (l, low) and the first slice of the next cube (h, high)
        # these slices overlap, so the voxels should have the same global id.
        # BUGFIX: chunks are generated with overlap=1, so the neighboring chunk
        # starts at +chunk_cube_size - 1 (not +chunk_cube_size, which is no key).

        if x_end < aff_shape[1]:
            chunk_l = reverse_chunks[chunk]
            chunk_h = reverse_chunks[x + self.chunk_cube_size - 1, y, z]
            result_l = z_root[chunk_l, -1:, :, :]
            result_h = z_root[chunk_h, :1, :, :]
            combined = np.stack([result_l.flatten(), result_h.flatten()]).T
            uniques = np.unique(combined, axis=0)
            fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h)

        if y_end < aff_shape[2]:
            chunk_l = reverse_chunks[chunk]
            chunk_h = reverse_chunks[x, y + self.chunk_cube_size - 1, z]
            result_l = z_root[chunk_l, :, -1:, :]
            result_h = z_root[chunk_h, :, :1, :]
            combined = np.stack([result_l.flatten(), result_h.flatten()]).T
            uniques = np.unique(combined, axis=0)
            fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h)

        if z_end < aff_shape[3]:
            chunk_l = reverse_chunks[chunk]
            chunk_h = reverse_chunks[x, y, z + self.chunk_cube_size - 1]
            result_l = z_root[chunk_l, :, :, -1:]
            result_h = z_root[chunk_h, :, :, :1]
            combined = np.stack([result_l.flatten(), result_h.flatten()]).T
            uniques = np.unique(combined, axis=0)
            fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h)

        return fragment_agglomeration

    def update_fragment_agglomeration(self, fragment_agglomeration, uniques, chunk_l, chunk_h):
        """Insert symmetric merge edges for every pair of co-located foreground ids."""
        for id_l, id_h in uniques:
            if id_l > 0 and id_h > 0:  # 0 is background, never merged
                fragment_agglomeration.setdefault((chunk_h, id_h), set()).add(
                    (chunk_l, id_l)
                )
                fragment_agglomeration.setdefault((chunk_l, id_l), set()).add(
                    (chunk_h, id_h)
                )
        return fragment_agglomeration

    def flatten_agglomeration(self, fragment_agglomeration):
        """
        Computes connected components in the fragment agglomeration graph, and assigns the fragments new ids starting from 1.
        Args:
            fragment_agglomeration: dictionary with keys (chunk_id, fragment_id), and values a set of (chunk_id, fragment_id) in another chunk (cube) that should be connected
        Returns:
            fragment_agglomeration_flattened: dictionary with keys (chunk_id, fragment_id) and values the global component index
        """
        cur_id = 1
        fragment_agglomeration_flattened = dict()
        for position_id in tqdm(fragment_agglomeration):  # (chunk, idx) = position_id
            if position_id not in fragment_agglomeration_flattened:
                # iterative BFS/DFS over the component containing position_id
                to_visit = {position_id}
                visited = set()
                while len(to_visit) > 0:
                    current = to_visit.pop()
                    if current not in visited:
                        visited.add(current)
                        for neighbor in fragment_agglomeration[current]:
                            to_visit.add(neighbor)
                for v in visited:
                    assert v not in fragment_agglomeration_flattened
                    fragment_agglomeration_flattened[v] = cur_id
                cur_id += 1

        return cur_id, fragment_agglomeration_flattened

    def add_all_fragments_to_agglomeration(self, fragment_agglomeration_flattened, cur_id, chunks, zarr_path):
        """Give every fragment that touches no chunk boundary its own fresh global id."""
        z_root = zarr.open(zarr_path)
        for i, chunk in enumerate(tqdm(chunks)):
            data = z_root[i, :, :, :]
            for idx in range(1, int(data.max()) + 1):  # assuming each chunk has contiguous indices from 0 to max
                if (i, idx) not in fragment_agglomeration_flattened:
                    fragment_agglomeration_flattened[(i, idx)] = cur_id
                    cur_id += 1
        return fragment_agglomeration_flattened

    def merge_and_relabel(self, fragment_agglomeration, zarr_patched, zarr_final, chunks, aff_shape):
        """Apply the global id mapping chunk by chunk and write the final segmentation."""
        zarr_chunk_size = min(self.chunk_cube_size, 512)
        z_root = zarr.open(zarr_patched)
        z_final = zarr.create(shape=aff_shape[1:],
                              store=zarr_final, dtype='i4', overwrite=True,
                              chunks=(zarr_chunk_size, zarr_chunk_size, zarr_chunk_size))

        if self.compute_backend == "local":
            for i, chunk in enumerate(tqdm(chunks)):
                x, y, z = chunk
                x_end, y_end, z_end = self.get_xyz_end(chunk, aff_shape)
                # BUGFIX: slot i stores the chunk starting at local index 0, so it
                # must be sliced relative to the chunk origin, not globally.
                data = z_root[i, : x_end - x, : y_end - y, : z_end - z]
                perm = [0]
                for idx in range(1, int(data.max()) + 1):  # assuming each chunk has contiguous indices from 0 to max
                    assert (i, idx) in fragment_agglomeration  # all fragments have a new index (contiguous from 0)
                    perm.append(fragment_agglomeration[(i, idx)])
                perm = np.array(perm, dtype=np.uint64)
                relabeled = perm[data]
                z_final[x: x_end, y: y_end, z: z_end] = relabeled

        else:
            raise NotImplementedError(f"Compute backend {self.compute_backend} not implemented.")

        shutil.rmtree(zarr_patched)

    def segment_chunk_wrapped(self, chunk, i, aff, zarr_path):
        """Segment one chunk and write it into slot ``i`` of the patched zarr array.

        BUGFIX: the array was opened with ``zarr.open(..., mode="w")`` inside a
        ``with`` block; mode "w" recreates (wipes) the pre-created store and zarr
        arrays are not context managers. Open the existing array read/write.
        """
        x, y, z = chunk
        x_end, y_end, z_end = self.get_xyz_end(chunk, aff.shape)
        curr_aff = aff[:, x : x_end, y : y_end, z : z_end]
        curr_seg = self.segment_chunk(curr_aff)
        z_root = zarr.open(zarr_path, mode="r+")
        z_root[i, : x_end - x, : y_end - y, : z_end - z] = curr_seg

    def segment_chunk(self, curr_aff):
        """
        In-memory segmentation of a chunk of affinities.
        Args:
            curr_aff: The affinities to segment (must fit in memory).
        Returns:
            Segmentation of the given affinities.
        """
        raise NotImplementedError(f"This method should be overridden in a subclass.")
+ """ + raise NotImplementedError(f"This method should be overridden in a subclass.") + + +class MutexWatershed(Postprocessing): + def __init__(self, chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long, long_range=10): + super().__init__(chunk_cube_size, compute_backend) + self.mws_bias_short = mws_bias_short + self.mws_bias_long = mws_bias_long + self.long_range = long_range + + def compute_mws_segmentation(self, cur_aff): + cur_aff = deepcopy(cur_aff).astype(np.float64) + cur_aff[:3] += self.mws_bias_short + cur_aff[3:] += self.mws_bias_long + + cur_aff[:3] = np.clip(cur_aff[:3], 0, 1) # short-range attractive edges + cur_aff[3:] = np.clip(cur_aff[3:], -1, 0) # long-range repulsive edges (see the Mutex Watershed paper) + + mws_pred = mwatershed.agglom( + affinities=cur_aff, + offsets=( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [self.long_range, 0, 0], + [0, self.long_range, 0], + [0, 0, self.long_range], + ] + ), + ) + + # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again + # size filter. 
class Thresholding(Postprocessing):
    """Chunk-wise segmentation by thresholding the three short-range affinity
    channels and running a 6-connected flood fill (numba-compiled)."""

    def __init__(self, chunk_cube_size, compute_backend, thr):
        super().__init__(chunk_cube_size, compute_backend)
        self.thr = thr  # affinity threshold; a voxel with no affinity above thr is background

    # BUGFIX: without @staticmethod the bound call in segment_chunk passes `self`
    # as `hard_aff`, and numba cannot compile an instance method in nopython mode.
    @staticmethod
    @jit(nopython=True)
    def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray:
        """
        Compute connected components from affinities.

        Args:
            hard_aff: The (thresholded, boolean) short range affinities. Shape: (3, x, y, z).

        Returns:
            The segmentation. Shape: (x, y, z).
        """
        visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean)
        seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32)
        cur_id = 1
        for i in range(visited.shape[0]):
            for j in range(visited.shape[1]):
                for k in range(visited.shape[2]):
                    if hard_aff[:, i, j, k].any() and not visited[i, j, k]:  # If foreground
                        # iterative DFS flood fill from this seed voxel
                        cur_to_visit = [(i, j, k)]
                        visited[i, j, k] = True
                        while cur_to_visit:
                            x, y, z = cur_to_visit.pop()
                            seg[x, y, z] = cur_id

                            # Check all neighbors (6-connectivity along each affinity channel)
                            if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]:
                                cur_to_visit.append((x + 1, y, z))
                                visited[x + 1, y, z] = True
                            if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]:
                                cur_to_visit.append((x, y + 1, z))
                                visited[x, y + 1, z] = True
                            if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]:
                                cur_to_visit.append((x, y, z + 1))
                                visited[x, y, z + 1] = True
                            if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]:
                                cur_to_visit.append((x - 1, y, z))
                                visited[x - 1, y, z] = True
                            if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]:
                                cur_to_visit.append((x, y - 1, z))
                                visited[x, y - 1, z] = True
                            if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]:
                                cur_to_visit.append((x, y, z - 1))
                                visited[x, y, z - 1] = True
                        cur_id += 1
        return seg

    def segment_chunk(self, curr_aff):
        return self.compute_connected_component_segmentation(curr_aff[:3] > self.thr)
"seg_prediction.zarr", + thr: float = 0.5, + mws_bias_short: float = -0.5, + mws_bias_long: float = -0.5, +): + affinity_predictor = AffinityPredictor( + chunk_cube_size=chunk_cube_size, + compute_backend=compute_backend, + model_path=model_path, + small_size=small_size, + do_overlap=do_overlap, + prediction_channels=prediction_channels, + divide=divide, + ) + affinity_predictor.img_to_aff(img, zarr_path=aff_zarr_path) + aff = zarr.open(aff_zarr_path, mode="r") + + if postprocessing_type == "thresholding": + postprocessor = Thresholding(chunk_cube_size, compute_backend, thr) + elif postprocessing_type == "mws": + postprocessor = MutexWatershed(chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long) + else: + raise NotImplementedError(f"Postprocessing type {postprocessing_type} is not implemented") + postprocessor.aff_to_seg(aff, zarr_path=seg_zarr_path) + seg = zarr.open(seg_zarr_path, mode="r") + + print(f"Segmentation saved at {seg_zarr_path}.") From 1c62b4677e1d302c166d3fe01e36e19d12a84b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zuzana=20Urbanov=C3=A1?= Date: Fri, 29 Aug 2025 18:32:02 +0200 Subject: [PATCH 26/33] fix bug --- inference2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference2.py b/inference2.py index 16588b1..bab0abe 100644 --- a/inference2.py +++ b/inference2.py @@ -514,7 +514,7 @@ def __init__(self, chunk_cube_size, compute_backend, thr): self.thr = thr @jit(nopython=True) - def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: + def compute_connected_component_segmentation(self, hard_aff: np.ndarray) -> np.ndarray: """ Compute connected components from affinities. 
From 6141a842fb61c6b59ec609557cd63b44be2d4524 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zuzana=20Urbanov=C3=A1?= Date: Fri, 29 Aug 2025 18:37:25 +0200 Subject: [PATCH 27/33] dont jit class --- inference2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/inference2.py b/inference2.py index bab0abe..a2e7473 100644 --- a/inference2.py +++ b/inference2.py @@ -513,8 +513,9 @@ def __init__(self, chunk_cube_size, compute_backend, thr): super().__init__(chunk_cube_size, compute_backend) self.thr = thr + @staticmethod @jit(nopython=True) - def compute_connected_component_segmentation(self, hard_aff: np.ndarray) -> np.ndarray: + def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: """ Compute connected components from affinities. From 807b2e8b972e01555fe2d9cc1d5cb25710f92446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zuzana=20Urbanov=C3=A1?= Date: Fri, 29 Aug 2025 18:42:12 +0200 Subject: [PATCH 28/33] autocomplete wrong --- inference2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference2.py b/inference2.py index a2e7473..70d5aa3 100644 --- a/inference2.py +++ b/inference2.py @@ -441,7 +441,7 @@ def segment_chunk_wrapped(self, chunk, i, aff, zarr_path): x_end, y_end, z_end = self.get_xyz_end(chunk, aff.shape) curr_aff = aff[:, x : x_end, y : y_end, z : z_end] curr_seg = self.segment_chunk(curr_aff) - with zarr.open(zarr_path, mode="w", chunks=curr_seg.shape) as z_root: + with zarr.open(zarr_path, mode="w") as z_root: z_root[i, : x_end - x, : y_end - y, : z_end - z] = curr_seg def segment_chunk(self, curr_aff): From 73878439df56900ac2e535c01a09f69699445064 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zuzana=20Urbanov=C3=A1?= Date: Sun, 31 Aug 2025 23:15:45 +0200 Subject: [PATCH 29/33] update and debug and profile --- debug_test_inference.py | 34 +++++++++++++++---- environment.yaml | 2 ++ inference.py | 9 ++---- inference2.py | 72 ++++++++++++++++++++++------------------- 4 files changed, 71 insertions(+), 46 
deletions(-) diff --git a/debug_test_inference.py b/debug_test_inference.py index 6488285..e8a1102 100644 --- a/debug_test_inference.py +++ b/debug_test_inference.py @@ -1,7 +1,7 @@ import zarr from inference import measure_stats, predict_aff, full_inference, thresholding -from inference2 import Thresholding +from inference2 import Thresholding, AffinityPredictor def test_local_prediction(): @@ -56,11 +56,33 @@ def test_full_inference(): full_inference(img_data, model_path, thr=0.7685) -def test_thresholding(): +def test_thresholding_old(): aff = zarr.open("/cajal/scratch/projects/misc/zuzur/skeleton_recall/rerun_base/dsbase_s0_a0_25-03-20_19-44-08-067760/pred_aff_val_6.zarr/") - postprocessor = Thresholding(1024, "local", 0.7685) - postprocessor.aff_to_seg(aff, zarr_path="test2.zarr") - #thresholding(aff, 0.7685, "test2.zarr", 1024, "local") + thresholding(aff, 0.7685, "test_old2.zarr", 1024, "local") + +def test_thresholding_new(): + aff = zarr.open("/cajal/scratch/projects/misc/zuzur/skeleton_recall/rerun_base/dsbase_s0_a0_25-03-20_19-44-08-067760/pred_aff_val_6.zarr/") + postprocessor = Thresholding(300, "local", 0.7685) + postprocessor.aff_to_seg(aff, zarr_path="test1.zarr") + +def test_prediction_old(): + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + + predict_aff(img_data, model_path=model_path, chunk_cube_size=1024, compute_backend="local", + zarr_path=f"/cajal/scratch/projects/misc/zuzur/test_0.zarr", do_overlap=True, + prediction_channels=3, divide=255, small_size=128) + + +def test_prediction_new(): + input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" + img_data = zarr.open(input_path, mode="r")["img"] + model_path = 
"/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" + + predictor = AffinityPredictor(model_path=model_path, chunk_cube_size=1024, compute_backend="local", do_overlap=True, + prediction_channels=3, divide=255, small_size=128) + predictor.img_to_aff(img_data, zarr_path=f"/cajal/scratch/projects/misc/zuzur/newtest_0.zarr") if __name__ == "__main__": - test_thresholding() + test_prediction_new() diff --git a/environment.yaml b/environment.yaml index 520faa1..1a60e4a 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,7 @@ dependencies: - bzip2=1.0.8 - ca-certificates=2024.8.30 - cython==3.0.11 + - dask=2025.7.0 - ld_impl_linux-64=2.43 - libexpat=2.6.3 - libffi=3.4.2 @@ -68,6 +69,7 @@ dependencies: - monai==1.3.2 - mpmath==1.3.0 - multidict==6.1.0 + - mwatershed==0.5.3 - networkx==3.3 - nibabel==5.3.0 - numba==0.60.0 diff --git a/inference.py b/inference.py index 7a5360e..b88d5e6 100644 --- a/inference.py +++ b/inference.py @@ -177,11 +177,8 @@ def predict_aff( chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') if compute_backend == "local": - if not model: - from BANIS import BANIS - model = BANIS.load_from_checkpoint(model_path) for chunk in tqdm(chunked_patch_coordinates): - predict_aff_patches_chunked(chunk, img, model, zarr_path + "_tmp", small_size, do_overlap, prediction_channels, divide) + predict_aff_patches_chunked(chunk, img, model_path, zarr_path + "_tmp", small_size, do_overlap, prediction_channels, divide) torch.cuda.empty_cache() # TODO: does this help? 
else: if compute_backend == "local_cluster": @@ -314,7 +311,7 @@ def predict_aff_patches_chunked(patch_coordinates, img, model_path, zarr_path, s single_pred_weight = get_single_pred_weight(do_overlap, small_size) from BANIS import BANIS - print(model_path, flush=True) + print(f"model path: {model_path}", flush=True) model = BANIS.load_from_checkpoint(model_path) for x_global, y_global, z_global in patch_coordinates: @@ -467,7 +464,7 @@ def thresholding(aff, thr, zarr_path, chunk_cube_size, compute_backend): for i, chunk in enumerate(tqdm(chunks)): x, y, z = chunk x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) - data = z_root["instances_patched"][i, :x_end, :y_end, :z_end] + data = z_root["instances_patched"][i, : x_end - x, : y_end - y, : z_end - z] perm = [0] for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max assert (i, idx) in fragment_agglomeration_flattened # all fragments have a new index (congiguous from 0) diff --git a/inference2.py b/inference2.py index 70d5aa3..de865cd 100644 --- a/inference2.py +++ b/inference2.py @@ -76,6 +76,17 @@ def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: """Scale sigmoid to avoid numerical issues in high confidence fp16.""" return sigmoid(0.2 * x) + @staticmethod + def get_xyz_end(chunk, chunk_cube_size, aff_shape): + """ + Returns the end indices of a chunk, that correspond either to the chunk size, or align with the size of the affinities. 
+ """ + x, y, z = chunk + x_end, y_end, z_end = (min(x + chunk_cube_size, aff_shape[1]), + min(y + chunk_cube_size, aff_shape[2]), + min(z + chunk_cube_size, aff_shape[3])) + return (x_end, y_end, z_end) + class AffinityPredictor: def __init__(self, @@ -246,27 +257,28 @@ def __init__(self, def aff_to_seg(self, aff, zarr_path): chunks = Utils.get_coordinates(aff.shape[1:], self.chunk_cube_size, overlap=1, last_has_smaller_overlap=False) reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)} + patched_zarr_path = "tmp_" + zarr_path zarr_chunk_size = min(self.chunk_cube_size, 512) z_root = zarr.create(shape=(len(chunks), self.chunk_cube_size, self.chunk_cube_size, self.chunk_cube_size), - store="tmp_" + zarr_path, dtype='i4', overwrite=True, + store=patched_zarr_path, dtype='i4', overwrite=True, chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size)) # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN MEMORY - self.patched_segment_affinities(aff, "tmp_" + zarr_path, chunks) + self.patched_segment_affinities(aff, patched_zarr_path, chunks) # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS - fragment_agglomeration = self.agglomerate_fragments(chunks, reverse_chunks, zarr_path, aff.shape) + fragment_agglomeration = self.agglomerate_fragments(chunks, reverse_chunks, patched_zarr_path, aff.shape) # MERGE AND RELABEL INSTANCES GLOBALLY - self.merge_and_relabel(fragment_agglomeration, "tmp_" + zarr_path, zarr_path, chunks, aff.shape) + self.merge_and_relabel(fragment_agglomeration, patched_zarr_path, zarr_path, chunks, aff.shape) return - def patched_segment_affinities(self, aff, zarr_path, chunks): + def patched_segment_affinities(self, aff, patched_zarr_path, chunks): if self.compute_backend == "local": for i, chunk in enumerate(tqdm(chunks)): - self.segment_chunk_wrapped(chunk, i, aff, zarr_path) + self.segment_chunk_wrapped(chunk, i, aff, patched_zarr_path) else: if self.compute_backend == "local_cluster": from dask_cuda import LocalCUDACluster @@ 
-289,34 +301,26 @@ def patched_segment_affinities(self, aff, zarr_path, chunks): print(f"Waiting for workers...") client.wait_for_workers(n_workers=1) print("Dask Client Dashboard:", client.dashboard_link) - tasks = [dask.delayed(self.segment_chunk_wrapped)(chunk, i, aff, zarr_path) for (i, chunk) in enumerate(chunks)] + tasks = [dask.delayed(self.segment_chunk_wrapped)(chunk, i, aff, patched_zarr_path) for (i, chunk) in enumerate(chunks)] futures = persist(tasks) progress(futures) # progress bar compute(futures) - def get_xyz_end(self, chunk, aff_shape): - """ - Returns the end indices of a chunk, that correspond either to the chunk size, or align with the size of the affinities. - """ - x, y, z = chunk - x_end, y_end, z_end = (min(x + self.chunk_cube_size, aff_shape[1]), - min(y + self.chunk_cube_size, aff_shape[2]), - min(z + self.chunk_cube_size, aff_shape[3])) - return (x_end, y_end, z_end) - - def agglomerate_fragments(self, chunks, reverse_chunks, zarr_path, aff_shape): + def agglomerate_fragments(self, chunks, reverse_chunks, patched_zarr_path, aff_shape): if self.compute_backend == "local": fragment_agglomeration = {} for i, chunk in enumerate(tqdm(chunks)): - chunk_agglomeration = self.agglomerate_chunk(chunk, reverse_chunks, zarr_path, aff_shape) - fragment_agglomeration.update(chunk_agglomeration) + chunk_agglomeration = self.agglomerate_chunk(chunk, reverse_chunks, patched_zarr_path, aff_shape) + for node, nbrs in chunk_agglomeration.items(): + for nbr in nbrs: + fragment_agglomeration.setdefault(node, set()).add(nbr) if len(fragment_agglomeration) > 10_000_000: print("WARNING: fragment agglomeration too long, might cause problems!") # TODO: solve this curr_id, fragment_agglomeration_flattened = self.flatten_agglomeration(fragment_agglomeration) print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) - fragment_agglomeration_flattened = self.add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, 
curr_id, chunks, zarr_path) + fragment_agglomeration_flattened = self.add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, patched_zarr_path) print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) else: @@ -325,18 +329,18 @@ def agglomerate_fragments(self, chunks, reverse_chunks, zarr_path, aff_shape): return fragment_agglomeration_flattened - def agglomerate_chunk(self, chunk, reverse_chunks, zarr_path, aff_shape): + def agglomerate_chunk(self, chunk, reverse_chunks, patched_zarr_path, aff_shape): fragment_agglomeration = {} x, y, z = chunk - x_end, y_end, z_end = self.get_xyz_end(chunk, aff_shape) - z_root = zarr.open(zarr_path, mode='r') + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) + z_root = zarr.open(patched_zarr_path, mode='r') # for (x,y,z) get the last slice of the current cube (l, low) and the first slice of the next cube (h, high) # these slices overlap, so the voxels should have the same global id if x_end < aff_shape[1]: chunk_l = reverse_chunks[chunk] - chunk_h = reverse_chunks[x + self.chunk_cube_size, y, z] + chunk_h = reverse_chunks[x + self.chunk_cube_size - 1, y, z] result_l = z_root[chunk_l, -1:, :, :] result_h = z_root[chunk_h, :1, :, :] combined = np.stack([result_l.flatten(), result_h.flatten()]).T @@ -345,7 +349,7 @@ def agglomerate_chunk(self, chunk, reverse_chunks, zarr_path, aff_shape): if y_end < aff_shape[2]: chunk_l = reverse_chunks[chunk] - chunk_h = reverse_chunks[x, y + self.chunk_cube_size, z] + chunk_h = reverse_chunks[x, y + self.chunk_cube_size - 1, z] result_l = z_root[chunk_l, :, -1:, :] result_h = z_root[chunk_h, :, :1, :] combined = np.stack([result_l.flatten(), result_h.flatten()]).T @@ -354,7 +358,7 @@ def agglomerate_chunk(self, chunk, reverse_chunks, zarr_path, aff_shape): if z_end < aff_shape[3]: chunk_l = reverse_chunks[chunk] - chunk_h = reverse_chunks[x, y, z + self.chunk_cube_size] + chunk_h = reverse_chunks[x, y, 
z + self.chunk_cube_size - 1] result_l = z_root[chunk_l, :, :, -1:] result_h = z_root[chunk_h, :, :, :1] combined = np.stack([result_l.flatten(), result_h.flatten()]).T @@ -401,8 +405,8 @@ def flatten_agglomeration(self, fragment_agglomeration): return cur_id, fragment_agglomeration_flattened - def add_all_fragments_to_agglomeration(self, fragment_agglomeration_flattened, cur_id, chunks, zarr_path): - z_root = zarr.open(zarr_path) + def add_all_fragments_to_agglomeration(self, fragment_agglomeration_flattened, cur_id, chunks, patched_zarr_path): + z_root = zarr.open(patched_zarr_path) for i, chunk in enumerate(tqdm(chunks)): data = z_root[i, :, :, :] for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max @@ -421,8 +425,8 @@ def merge_and_relabel(self, fragment_agglomeration, zarr_patched, zarr_final, ch if self.compute_backend == "local": for i, chunk in enumerate(tqdm(chunks)): x, y, z = chunk - x_end, y_end, z_end = self.get_xyz_end(chunk, aff_shape) - data = z_root[i, :x_end, :y_end, :z_end] + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) + data = z_root[i, : x_end - x, : y_end - y, : z_end - z] perm = [0] for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max assert (i, idx) in fragment_agglomeration # all fragments have a new index (contiguous from 0) @@ -438,11 +442,11 @@ def merge_and_relabel(self, fragment_agglomeration, zarr_patched, zarr_final, ch def segment_chunk_wrapped(self, chunk, i, aff, zarr_path): x, y, z = chunk - x_end, y_end, z_end = self.get_xyz_end(chunk, aff.shape) + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff.shape) curr_aff = aff[:, x : x_end, y : y_end, z : z_end] curr_seg = self.segment_chunk(curr_aff) - with zarr.open(zarr_path, mode="w") as z_root: - z_root[i, : x_end - x, : y_end - y, : z_end - z] = curr_seg + z_root = zarr.open(zarr_path, mode="r+") + z_root[i, : x_end - x, : y_end 
- y, : z_end - z] = curr_seg def segment_chunk(self, curr_aff): """ From 03dc9776bf9e46beb1c4235c1f8af856916455e0 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 10 Sep 2025 14:37:06 +0200 Subject: [PATCH 30/33] object-based inference --- README.md | 12 + debug_parched_inference_copy.py | 1032 ----------------------------- debug_progress.py | 40 -- debug_retrain.py | 162 ----- debug_test_inference.py | 88 --- debug_visualilze.py | 132 ---- environment.yaml | 3 + inference.py | 1091 +++++++++++++++++-------------- 8 files changed, 600 insertions(+), 1960 deletions(-) delete mode 100644 debug_parched_inference_copy.py delete mode 100644 debug_progress.py delete mode 100644 debug_retrain.py delete mode 100644 debug_test_inference.py delete mode 100644 debug_visualilze.py diff --git a/README.md b/README.md index 81d39f0..8af933e 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,18 @@ python slurm_job_scheduler.py Adding an `auto_resubmit` argument to `config.yaml` allows Slurm to automatically resubmit jobs that reach the Slurm time limit (see `aff_train.sh`). +## Prediction + +To predict segmentation from an image: + +```bash +python inference --img_path /path/to/image.zarr --model_path /path/to/model.ckpt --chunk_cube_size 3000 +``` + +The `chunk_cube_size` parameter sets the maximum cube size that can be loaded in memory. +If you have enough memory available, set it to a bigger value, if you are tight with memory, set a lower value (in exchange for increased computation time). +See [inference.py](inference.py) for other parameters. 
+ ## Evaluation To evaluate a predicted segmentation (`.zarr` or `.npy`): diff --git a/debug_parched_inference_copy.py b/debug_parched_inference_copy.py deleted file mode 100644 index ad8852c..0000000 --- a/debug_parched_inference_copy.py +++ /dev/null @@ -1,1032 +0,0 @@ -import os -import pickle -import shutil -import time -from collections import Counter -from concurrent.futures import ThreadPoolExecutor -from copy import deepcopy -from datetime import timedelta -from functools import partial -import hydra -from omegaconf import DictConfig, open_dict -import joblib -import itertools -import gc - -import cc3d -import configargparse -import fastremap -import mwatershed -import numpy as np -import pandas as pd -import zarr -from dask import config as dask_cfg -from dask_jobqueue import SLURMCluster -from dask.distributed import Client, LocalCluster, as_completed -from tqdm import tqdm -from numba import jit -from datetime import datetime -import psutil, os - - -from metrics import compute_metrics - -# this only changes the configuration in the local process, and not subprocesses (like remote workers) -dask_cfg.set( - { - "distributed.scheduler.worker-ttl": "1h", # required because mwatershed blocks for a long time - "distributed.comm.timeouts.connect": "1h", - "distributed.comm.timeouts.tcp": "1h", - "distributed.admin.tick.limit": "30s" # increase time before triggering a warning (default limit of 3s remains in workers - https://github.com/dask/distributed/issues/3882) - } -) - - -def chunk_list(list_to_chunk, chunk_size): # todo:use zarr bag instead? - return [ - list_to_chunk[i: i + chunk_size] - for i in range(0, len(list_to_chunk), chunk_size) - ] - - -def patched_thresholding(aff, conf): - """ - Creates a segmentation from an affinity map. - Segmentation is created patchwise using thresholding or mutex watershed, - and subsequently merging segments that span multiple patches. 
- """ - ijk_to_idx, patch_to_coords = get_mappings(aff, conf.patch_size) - - # Predict segments for all patches - if conf.debug_patched_seg_path: - print(f"Using patched segmentation from {conf.debug_patched_seg_path}") - patched_seg = zarr.open(conf.debug_patched_seg_path, mode="r") - else: - patched_seg = segment_patches( - aff, - conf, - patch_to_coords, - ijk_to_idx - ) - - # Agglomerate segments at the edges of neighboring patches - if conf.debug_fragment_agglomeration_path: - print(f"Using fragment agglomeration from {conf.debug_fragment_agglomeration_path}") - if not os.path.normpath(conf.debug_fragment_agglomeration_path) == os.path.normpath(f"{conf.path_root}/agglo_pkl_chunks"): - print(f"CONF PATH {conf.debug_fragment_agglomeration_path} NOT EQUAL TO {conf.path_root}/agglo_pkl_chunks - MIGHT CAUSE PROBLEMS") - else: - fragment_agglomeration_flattened = None - print(f"NOT LOADING PATH {conf.debug_fragment_agglomeration_path} - WILL DO IT IN THE WORKERS") - fragment_agglomeration_flattened = None - else: - fragment_agglomeration_flattened = compute_fragment_agglomeration( - patched_seg, - aff, - conf, - ijk_to_idx, - patch_to_coords, - ) - - # Unify indexing of all patches, including the merged agglomerations at the border - if conf.debug_relabeled_seg_path: - print(f"Using relabeled agglomerated segmentation from {conf.debug_relabeled_seg_path}") - agglomerated_seg = zarr.open(conf.debug_relabeled_seg_path, mode="r") - else: - agglomerated_seg = relabel_globally( - fragment_agglomeration_flattened, - patched_seg, - aff, - conf, - patch_to_coords, - ijk_to_idx, - ) - - # Filter out segments that are too small - filtered_seg = size_filter_relabel(agglomerated_seg, conf) - - # Delete intermediary files that are not needed anymore - if conf.delete_files: - try: - zarr.DirectoryStore(f"{conf.path_root}/patched_seg.zarr").rmdir() - os.rmdir(f"{conf.path_root}/agglo_pkl_chunks") - zarr.DirectoryStore(f"{conf.path_root}/agglomerated_seg.zarr").rmdir() - 
os.remove(f"{conf.path_root}/id_mapping.csv") - shutil.rmtree(conf.path_root + "/dask-worker-space/", ignore_errors=True) - except: - print("Exception while deleting files.") - - return filtered_seg - - -def get_mappings(aff, patch_size): - """ - Returns coordinates of patches, and their indices. - """ - # x,y,z: coordinates - # i,j,k: patch indices (i.e. i*n <= x < (i+1)*n) - # idx: patch index (i.e. i * len(ys) * len(zs) + j * len(zs) + k) - - xs = list(range(0, aff.shape[1], patch_size)) - ys = list(range(0, aff.shape[2], patch_size)) - zs = list(range(0, aff.shape[3], patch_size)) - - ijk_to_idx = { - (i, j, k): i * len(ys) * len(zs) + j * len(zs) + k - for i in range(len(xs)) - for j in range(len(ys)) - for k in range(len(zs)) - } - - patch_to_coords = { - (i, j, k): ( - (xs[i], xs[i + 1] if i + 1 < len(xs) else None), - (ys[j], ys[j + 1] if j + 1 < len(ys) else None), - (zs[k], zs[k + 1] if k + 1 < len(zs) else None), - ) - for i in range(len(xs)) - for j in range(len(ys)) - for k in range(len(zs)) - } - - return ijk_to_idx, patch_to_coords - - -def segment_patches(aff, conf, patch_to_coords, ijk_to_idx): - """ - Predict segmentation for each patch independently. 
- Returns segmentation of shape (len(patch_to_coords), patch_size + overlap, patch_size + overlap, patch_size + overlap) - """ - - print(f"Computing patch segmentation...") - print(f"Store: {conf.path_root}/patched_seg.zarr") - patched_seg = zarr.zeros( - ( - len(patch_to_coords), - conf.patch_size + conf.overlap, - conf.patch_size + conf.overlap, - conf.patch_size + conf.overlap, - ), - chunks=(1, conf.patch_size + conf.overlap, conf.patch_size + conf.overlap, conf.patch_size + conf.overlap), - dtype=np.uint32, - store=f"{conf.path_root}/patched_seg.zarr", - ) - print(patched_seg.shape) - - ijks_chunked = chunk_list( - list(patch_to_coords.keys()), - chunk_size=1 # faster than bigger chunks - ) - - def chunked_thresholding(ijks): - for ijk in ijks: - threshold_ijk(ijk, aff, patched_seg, patch_to_coords, ijk_to_idx, conf) - - if not conf.use_parallelization: - result = list(map(chunked_thresholding, tqdm(ijks_chunked))) - # we don't need the result, the iteration writes the segmentation directly into the zarr array - else: - if conf.use_slurm: - cluster = SLURMCluster( - cores=8, - memory="500GB", - processes=1, - worker_extra_args=["--resources processes=1"], - log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/segment/", - walltime="3:00:00" # default is 30mins and then worker gets killed, chunked ijks can take more time - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - - else: - cluster = LocalCluster( - n_workers=min(os.cpu_count(), 16), - threads_per_worker=1, - local_directory=conf.path_root + "/dask-worker-space/" - # to avoid independent runs deleting each other's directories - ) - - with Client(cluster) as client: - print("Dask threshold Client Dashboard:", client.dashboard_link) - - start_time = time.time() - futures = client.map( - chunked_thresholding, - ijks_chunked, - batch_size=1, - resources={"processes": 1}, - ) - for future in tqdm(as_completed(futures), total=len(ijks_chunked), smoothing=0): - future.release() - pass # tqdm 
progress bar is nicer and shows remaining time - print(f"Computing patch segmentations took {timedelta(seconds=int(time.time() - start_time))}") - #cluster.close() - - return patched_seg - - -def threshold_ijk(ijk, aff, patched_seg, patch_to_coords, ijk_to_idx, conf): - """ - Predicts segmentation from the affinities for one patch, and writes the result into patched_seg. - """ - i, j, k = ijk - #dask.distributed.print(f"processing: {ijk_to_idx[i, j, k]}") - ((x_start, x_end), (y_start, y_end), (z_start, z_end)) = patch_to_coords[(i, j, k)] - cur_aff = aff[ - :, - max(0, x_start - conf.overlap - conf.surrounding): ( - x_end + conf.surrounding if x_end is not None else None), - max(0, y_start - conf.overlap - conf.surrounding): ( - y_end + conf.surrounding if y_end is not None else None), - max(0, z_start - conf.overlap - conf.surrounding): ( - z_end + conf.surrounding if z_end is not None else None), - ] - - cur_aff[np.isnan(cur_aff)] = 0.0 - cur_aff = np.clip(cur_aff, 0.0, 1.0) # todo: enforce clip + not nan + not inf in aff inference - - # extend on all cut off sides - cur_aff_tmp = cur_aff - cur_aff = np.zeros( - ( - aff.shape[0], - conf.patch_size + (conf.overlap + 2 * conf.surrounding), - conf.patch_size + (conf.overlap + 2 * conf.surrounding), - conf.patch_size + (conf.overlap + 2 * conf.surrounding), - ) - ) - - x_start_tmp = (conf.overlap + conf.surrounding) if x_start == 0 else 0 - y_start_tmp = (conf.overlap + conf.surrounding) if y_start == 0 else 0 - z_start_tmp = (conf.overlap + conf.surrounding) if z_start == 0 else 0 - cur_aff[ - :, - x_start_tmp: x_start_tmp + cur_aff_tmp.shape[1], - y_start_tmp: y_start_tmp + cur_aff_tmp.shape[2], - z_start_tmp: z_start_tmp + cur_aff_tmp.shape[3], - ] = cur_aff_tmp - - if conf.mws: - cur_aff = deepcopy(cur_aff).astype(np.float64) - cur_aff[:3] += conf.mws_bias_short - cur_aff[3:] += conf.mws_bias_long if conf.mws_bias_long is not None else 0.0 - - cur_aff[:3] = np.clip(cur_aff[:3], 0, 1) - cur_aff[3:] = 
np.clip(cur_aff[3:], -1, 0) - - mws_pred = mwatershed.agglom( - affinities=cur_aff if conf.mws_bias_long is not None else cur_aff[:3], - offsets=( - [ - [1, 0, 0], - [0, 1, 0], - [0, 0, 1], - [conf.long_range, 0, 0], - [0, conf.long_range, 0], - [0, 0, conf.long_range], - ] - if conf.mws_bias_long is not None - else [[1, 0, 0], [0, 1, 0], [0, 0, 1]] - ), - ) - - # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again - # size filter. single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube, causing OOM when creating fragment_agglomeration - dusted = cc3d.dust( # does a cc first (reducing false mergers in add_to_agglomeration) - mws_pred, - threshold=2, - connectivity=6, - in_place=False, - ) - # relabeling to save IDs - pred_relabeled, N = cc3d.connected_components( - dusted, return_N=True, connectivity=6 - ) - - assert (pred_relabeled[mws_pred == 0] == 0).all() # 0 stays 0 - assert N <= np.iinfo(np.uint32).max - - pred = pred_relabeled.astype(np.uint32) - - else: - pred = conn_comps(cur_aff >= conf.thr) - - pred_no_surrounding = ( - pred[ - conf.surrounding:-conf.surrounding, - conf.surrounding:-conf.surrounding, - conf.surrounding:-conf.surrounding, - ] - if conf.surrounding > 0 - else pred - ) - patched_seg[ijk_to_idx[i, j, k]] = pred_no_surrounding - print(f"processed: {ijk_to_idx[i, j, k]}") - return - -@jit(nopython=True) -def conn_comps(hard_aff): - visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.bool_) - seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) - cur_id = 1 - cur_id_used = False - for i in range(visited.shape[0]): - for j in range(visited.shape[1]): - for k in range(visited.shape[2]): - if hard_aff[ - :, i, j, k - ].any() and not visited[i, j, k]: # if foreground - cur_to_visit = [(i, j, k)] # todo: use 3 array.array instead? or np.array and append? 
- visited[i, j, k] = True - while len(cur_to_visit) > 0: - x, y, z = cur_to_visit.pop() - # if not visited[x, y, z]: - # visited[x, y, z] = True - seg[x, y, z] = cur_id - cur_id_used = True - if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]: - cur_to_visit.append((x + 1, y, z)) - visited[x + 1, y, z] = True - if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]: - cur_to_visit.append((x, y + 1, z)) - visited[x, y + 1, z] = True - if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]: - cur_to_visit.append((x, y, z + 1)) - visited[x, y, z + 1] = True - if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]: - cur_to_visit.append((x - 1, y, z)) - visited[x - 1, y, z] = True - if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]: - cur_to_visit.append((x, y - 1, z)) - visited[x, y - 1, z] = True - if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]: - cur_to_visit.append((x, y, z - 1)) - visited[x, y, z - 1] = True - if cur_id_used: - cur_id += 1 - cur_id_used = False - return seg - - - - -def compute_fragment_agglomeration( - patched_seg, - aff, - conf, - ijk_to_idx, - patch_to_coords, -): - """ - From the patched segmentation, merges fragments at the border of adjacent patches. - Computes flattened agglomeration, a dict with the keys (i, j, k, idx), where (i, j, k) is a patch and idx is an id in that patch, and values of a global id for this fragment. 
- """ - - print("Computing fragment agglomeration...") - data_chunked = list( - enumerate(chunk_list(list(patch_to_coords.items()), chunk_size=8)) - ) - - if conf.use_slurm: - cluster = SLURMCluster( - cores=16, - memory="500GB", - processes=1, - log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/agglo/", - walltime="3:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - else: - cluster = LocalCluster( - n_workers=min(os.cpu_count(), 16), - threads_per_worker=1, - local_directory=conf.path_root + "/dask-worker-space/" - ) - - with Client(cluster) as client: - print("Dask Client Dashboard:", client.dashboard_link) - - start_time = time.time() - futures = client.map( - partial( - compute_agglomeration_part, - patched_seg=patched_seg, - ijk_to_idx=ijk_to_idx, - conf=conf, - aff=aff - ), - data_chunked - ) - - fragment_agglomeration = {} - # agglomerate all fragments from all chunks as they complete - for future, frag_aggl in tqdm(as_completed(futures, with_results=True), total=len(data_chunked), smoothing=0): - for k, v in frag_aggl.items(): - fragment_agglomeration.setdefault(k, set()).update(v) - - print(f"Computing fragment agglomeration in patches took {timedelta(seconds=time.time() - start_time)}") - #cluster.close() - - fragment_agglomeration_flattened = flatten_agglomeration(fragment_agglomeration, f"{conf.path_root}/agglo_pkl_chunks") - - return fragment_agglomeration_flattened - - -def compute_agglomeration_part( - idx_samples, - patched_seg, - ijk_to_idx, - conf, - aff -): - """ - Merges neighboring voxels from different cubes, creates a graph with connected fragments. - Args: - idx_samples: idx of the chunk, chunked patch indices - - Returns: - fragment_agglomeration: Dictionary representing a graph between vertices (i, j, k, idx) - where (i, j, k) is a patch index, and idx a fragment id in this patch. - An edge means the fragments should be merged. 
- """ - print("Entered compute_agglomeration_part", flush=True) - idx, samples = idx_samples - fragment_agglomeration = {} - for sample in samples: - print(f"{datetime.now()}: computing {sample}", flush=True) - (i, j, k), ((x_start, x_end), (y_start, y_end), (z_start, z_end)) = sample - - # for x,y,z get the last slice of the current cube (l, low) and the first slice of the next cube (h, high) - # they overlap, the voxels should have the same id - - if x_end is not None: - if conf.overlap > 0: - result_l = patched_seg[ijk_to_idx[i, j, k], -conf.overlap:] - result_h = patched_seg[ijk_to_idx[i + 1, j, k], :conf.overlap] - uniques = compute_uniques(conf.do_overlap_filter, result_h, result_l, min_overlap=conf.min_overlap) - else: - # merge according to short range affinities between each pair of IDs in neighboring cubes - cur_aff = ( - aff[0, x_end - 1: x_end, y_start:y_end, z_start:z_end] >= conf.merge_thr - ) - # todo: this simple thresholding can re-introduce catastrophic mergers. - # tune thr? use mean instead of max? better use overlap strategy? 
- - result_l = patched_seg[ijk_to_idx[i, j, k], -1:][ - : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] - ][cur_aff] - result_h = patched_seg[ijk_to_idx[i + 1, j, k], :1][ - : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] - ][cur_aff] - combined = np.stack([result_l, result_h]).T - uniques = np.unique(combined, axis=0) - - for id_l, id_h in uniques: - if id_l > 0 and id_h > 0: - fragment_agglomeration.setdefault((i + 1, j, k, id_h), set()).add( - (i, j, k, id_l) - ) - fragment_agglomeration.setdefault((i, j, k, id_l), set()).add( - (i + 1, j, k, id_h) - ) - - if y_end is not None: - if conf.overlap > 0: - result_l = patched_seg[ijk_to_idx[i, j, k], :, -conf.overlap:] - result_h = patched_seg[ijk_to_idx[i, j + 1, k], :, :conf.overlap] - uniques = compute_uniques(conf.do_overlap_filter, result_h, result_l, min_overlap=conf.min_overlap) - else: - cur_aff = ( - aff[1, x_start:x_end, y_end - 1: y_end, z_start:z_end] >= conf.merge_thr - ) - result_l = patched_seg[ijk_to_idx[i, j, k], :, -1:][ - : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] - ][cur_aff] - result_h = patched_seg[ijk_to_idx[i, j + 1, k], :, :1][ - : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] - ][cur_aff] - combined = np.stack([result_l, result_h]).T - uniques = np.unique(combined, axis=0) - - for id_l, id_h in uniques: - if id_l > 0 and id_h > 0: - fragment_agglomeration.setdefault((i, j + 1, k, id_h), set()).add( - (i, j, k, id_l) - ) - fragment_agglomeration.setdefault((i, j, k, id_l), set()).add( - (i, j + 1, k, id_h) - ) - - if z_end is not None: - if conf.overlap > 0: - result_l = patched_seg[ijk_to_idx[i, j, k], :, :, -conf.overlap:] - result_h = patched_seg[ijk_to_idx[i, j, k + 1], :, :, :conf.overlap] - uniques = compute_uniques(conf.do_overlap_filter, result_h, result_l, min_overlap=conf.min_overlap) - else: - cur_aff = ( - aff[2, x_start:x_end, y_start:y_end, z_end - 1: z_end] >= conf.merge_thr - ) - result_l = patched_seg[ijk_to_idx[i, j, k], :, :, 
-1:][ - : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] - ][cur_aff] - result_h = patched_seg[ijk_to_idx[i, j, k + 1], :, :, :1][ - : cur_aff.shape[0], : cur_aff.shape[1], : cur_aff.shape[2] - ][cur_aff] - combined = np.stack([result_l, result_h]).T - uniques = np.unique(combined, axis=0) - - for id_l, id_h in uniques: - if id_l > 0 and id_h > 0: - fragment_agglomeration.setdefault((i, j, k + 1, id_h), set()).add( - (i, j, k, id_l) - ) - fragment_agglomeration.setdefault((i, j, k, id_l), set()).add( - (i, j, k + 1, id_h) - ) - print(f"{datetime.now()}: done", flush=True) - return fragment_agglomeration - - -def compute_uniques(do_overlap_filter, result_h, result_l, min_overlap=0.9): - if do_overlap_filter: - result_l_ccs = cc3d.connected_components(result_l, connectivity=6) - result_h_ccs = cc3d.connected_components(result_h, connectivity=6) - - l_ccs_to_l = np.unique( - np.stack([result_l_ccs.flatten(), result_l.flatten()]), axis=1 - ) - l_ccs_to_l = {l_ccs: l for l_ccs, l in l_ccs_to_l.T} - - h_ccs_to_h = np.unique( - np.stack([result_h_ccs.flatten(), result_h.flatten()]), axis=1 - ) - h_ccs_to_h = {h_ccs: h for h_ccs, h in h_ccs_to_h.T} - - combined_ccs = np.stack([result_l_ccs.flatten(), result_h_ccs.flatten()]) - uniques_ccs, counts_ccs = np.unique(combined_ccs, axis=1, return_counts=True) - uniques_ccs = uniques_ccs.T - # uniques_ccs = exact_overlap_filter(uniques_ccs) - uniques_ccs = mutual_largest_overlap_filter( - counts_ccs, uniques_ccs, - min_overlap=min_overlap - ) - - uniques = [ - (l_ccs_to_l[l_ccs], h_ccs_to_h[h_ccs]) for l_ccs, h_ccs in uniques_ccs - ] - - else: - combined = np.stack([result_l.flatten(), result_h.flatten()]).T - uniques, counts = np.unique(combined, axis=0, return_counts=True) - return uniques - - -def exact_overlap_filter(uniques): - # keep only non-zero IDs with mutually exact corresponding count to reduce merge errors - - l_partners = {} - h_partners = {} - for id_l, id_h in uniques: - l_partners.setdefault(id_l, 
[]).append(id_h) - h_partners.setdefault(id_h, []).append(id_l) - - uniques = [ - (id_l, id_h) - for id_l, id_h in uniques - if ( - (len(l_partners[id_l]) == len(h_partners[id_h]) == 1) - and id_l != 0 - and id_h != 0 - ) - ] - return uniques - - -def mutual_largest_overlap_filter( - counts, - uniques, - min_overlap=0.5, - # 1.0: exact overlap, 0.0: any overlap, 0.5: at least half of total count -): - # keep only non-zero IDs with mutually largest corresponding count to reduce merge errors - - # todo: Merge only perfect matches? i.e. except for 0 there are no other IDs in the uniques (ignore counts) - - highest_count_l = {} - highest_count_h = {} - - total_count_l = {} - total_count_h = {} - - for (id_l, id_h), count in zip(uniques, counts): - total_count_l[id_l] = total_count_l.get(id_l, 0) + count - total_count_h[id_h] = total_count_h.get(id_h, 0) + count - - # if id_l > 0 and id_h > 0: don't filter background here: if there is more overlap with background than with another ID, it should not be merged - cur_highest_count_l, cur_highest_id_l = highest_count_l.setdefault( - id_l, (-1, -1) - ) - cur_highest_count_h, cur_highest_id_h = highest_count_h.setdefault( - id_h, (-1, -1) - ) - - if count > cur_highest_count_l: - highest_count_l[id_l] = (count, id_h) - if count > cur_highest_count_h: - highest_count_h[id_h] = (count, id_l) - # uniques = [(id_l, id_h) for id_l, (count, id_h) in highest_count_l.items()] + [ - # (id_l, id_h) for id_h, (count, id_l) in highest_count_h.items() - # ] # for non ccs case: but snakes get split because only 1 assignment per ID but should be several - - uniques = [ - (id_l, id_h) - for id_l, (count, id_h) in highest_count_l.items() - if highest_count_h[id_h][1] == id_l - and count >= min_overlap * total_count_l[id_l] - and count >= min_overlap * total_count_h[id_h] - and id_l != 0 - and id_h != 0 - and count >= 2 # single voxel branches get split - ] - - return np.array(uniques) - - -def flatten_agglomeration(fragment_agglomeration, 
output_dir): - """ - Computes connected components in the fragment agglomeration graph, and relabels the fragments with ids starting from 1. - Args: - fragment_agglomeration: dictionary with keys (i, j, k, id) indicating cube (i, j, k) and component id in that cube, and values a set of (i, j, k, id) in other cubes that should be connected - Returns: - fragment_agglomeration_flattened: dictionary with keys (i, j, k, id) and values the global component index - """ - cur_id = 1 - fragment_agglomeration_flattened = dict() - fragment_agglomeration_final = dict() - flattened_ids = set() - chunk_n = 0 - os.makedirs(output_dir, exist_ok=True) - for position_id in tqdm(fragment_agglomeration): # (i, j, k, idx) = position_id - if position_id not in flattened_ids: - to_visit = {position_id} - visited = set() - while len(to_visit) > 0: - current = to_visit.pop() - if current not in visited: - visited.add(current) - for neighbor in fragment_agglomeration[current]: - to_visit.add(neighbor) - for v in visited: - assert v not in fragment_agglomeration_flattened - fragment_agglomeration_flattened[v] = cur_id - flattened_ids.add(v) - if len(fragment_agglomeration_flattened) >= 10_000_000: - file_path = os.path.join(output_dir, f"chunk_{chunk_n:02}.pkl") - with open(file_path, "wb") as f: - pickle.dump(fragment_agglomeration_flattened, f) - print(f"Saved {len(fragment_agglomeration_flattened)} items to {file_path}") - fragment_agglomeration_final.update(fragment_agglomeration_flattened) - fragment_agglomeration_flattened = dict() - chunk_n += 1 - cur_id += 1 - - if fragment_agglomeration_flattened: - file_path = os.path.join(output_dir, f"chunk_{chunk_n:02}.pkl") - with open(file_path, "wb") as f: - pickle.dump(fragment_agglomeration_flattened, f) - print(f"Saved final {len(fragment_agglomeration_flattened)} items to {file_path}") - fragment_agglomeration_final.update(fragment_agglomeration_flattened) - - return fragment_agglomeration_final - - -def 
relabel_cube_batched_wrapped(kwargs): - return relabel_cube_batched(**kwargs) - - -def relabel_globally( - fragment_agglomeration_flattened, - patched_seg, - aff, - conf, - patch_to_coords, - ijk_to_idx, -): - """ - Unite indexing within multiple cubes - if an object spans multiple cubes, it should have the same index everywhere - """ - print(f"Global relabeling...") - print(f"Agglomerated segmentation: {conf.path_root}/agglomerated_seg.zarr") - agglomerated_seg = zarr.zeros( - (aff.shape[1:]), - chunks=(conf.patch_size, conf.patch_size, conf.patch_size), - dtype=np.uint64, # cheap because of zarr compression - store=f"{conf.path_root}/agglomerated_seg.zarr", - ) - - cubes = list(patch_to_coords.items()) - cubes_batched = chunk_list(cubes, 100) - - if not conf.use_parallelization: - result = list(tqdm(map( - partial( - relabel_cube, - patched_seg=patched_seg, - fragment_agglomeration_flattened=fragment_agglomeration_flattened, - ijk_to_idx=ijk_to_idx, - agglomerated_seg=agglomerated_seg, - conf=conf - ), - cubes), total=len(cubes))) - else: - if not conf.use_slurm: - with ThreadPoolExecutor(max_workers=32) as executor: - result = list(tqdm(executor.map( - partial( - relabel_cube, - patched_seg=patched_seg, - fragment_agglomeration_flattened=fragment_agglomeration_flattened, - ijk_to_idx=ijk_to_idx, - agglomerated_seg=agglomerated_seg, - conf=conf - ), - cubes, - chunksize=8 - ), total=len(cubes), smoothing=0)) - else: - cluster = SLURMCluster( - cores=32, - memory="500GB", - processes=1, - worker_extra_args=["--resources", "processes=1"], - log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/relabel/", - walltime="24:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - - with Client(cluster) as client: - print("Dask relabeling Client Dashboard:", client.dashboard_link) - - start_time = time.time() - print(len(ijk_to_idx)) - with open(f"{conf.path_root}/ijk_to_idx.pkl", "wb") as f: - pickle.dump(ijk_to_idx, f) - #ijk_to_idx_future = 
client.scatter(ijk_to_idx, broadcast=True) - #print(f"Broadcasted ijk_to_idx in {timedelta(seconds=int(time.time() - start_time))}") - #print(list(ijk_to_idx.items())[:10]) - configs = [ - { - "cubes": cubes, - "patched_seg": patched_seg, - "fragment_agglomeration_chunks_path": f"{conf.path_root}/agglo_pkl_chunks", - #"ijk_to_idx": ijk_to_idx_future, - "agglomerated_seg": agglomerated_seg, - "conf": conf, - } - for cubes in cubes_batched - ] - futures = client.map( - relabel_cube_batched_wrapped, - configs, - resources={'processes': 1}, - #batch_size=1 - ) - for _ in tqdm(as_completed(futures), total=len(cubes_batched), smoothing=0): - pass # tqdm progress bar - print(f"Relabeling fragments took {timedelta(seconds=int(time.time() - start_time))}") - #cluster.close() - - return agglomerated_seg - - -def relabel_cube_batched(cubes, patched_seg, fragment_agglomeration_chunks_path, agglomerated_seg, conf): - print(f"{datetime.now()}: start relabel_cube_batched", flush=True) - with open(f"{conf.path_root}/ijk_to_idx.pkl", "rb") as f: - ijk_to_idx = pickle.load(f) - print(f"{datetime.now()}: ijk_to_idx loaded", flush=True) - fragment_agglomeration_flattened = dict() - chunk_files = os.listdir(fragment_agglomeration_chunks_path) - for chunk_file in tqdm(chunk_files): - with open(os.path.join(fragment_agglomeration_chunks_path, chunk_file), "rb") as f: - chunk = pickle.load(f) - fragment_agglomeration_flattened.update(chunk) - print(f"{datetime.now()}: fragment_agglomeration_flattened loaded {len(fragment_agglomeration_flattened)}", flush=True) - - print(f"{datetime.now()}: Relabeling cube chunk", flush=True) - for cube in cubes: - relabel_cube(cube, patched_seg, fragment_agglomeration_flattened, ijk_to_idx, agglomerated_seg, conf) - print(f"{datetime.now()}: End relabeling cube chunk", flush=True) - - -def relabel_cube(cube, patched_seg, fragment_agglomeration_flattened, ijk_to_idx, agglomerated_seg, conf): - """ - If an object spans multiple cubes, relabel the indices to 
be the same - """ - (i, j, k), ((x_start, x_end), (y_start, y_end), (z_start, z_end)) = cube - # todo: for dask: fragment_agglomeration_flattened is big, load it in here from disk (once for several items?) - - cube = patched_seg[ijk_to_idx[i, j, k]] - perm = [0] - for idx in range(1, int(cube.max()) + 1): # assuming cube has continuous indices from 0 to max - if (i, j, k, idx) in fragment_agglomeration_flattened: # object (idx) continued in neighboring cube -> already has a unique id - perm.append(fragment_agglomeration_flattened[i, j, k, idx]) - else: # object only in this cube - # use upper 32 bits to indicate cube, lower 32 bits to indicate id - perm.append((ijk_to_idx[i, j, k] + 1) * np.uint64(2 ** 32) + idx) - perm = np.array(perm, dtype=np.uint64) - - relabeled = perm[cube[conf.overlap:, conf.overlap:, conf.overlap:]] - if len(perm) > 1: - print(cube.shape, np.max(cube), (i,j,k), ijk_to_idx[i,j,k]) - print(len(perm), perm[1] if len(perm) > 1 else "out of bounds") - print(np.max(relabeled)) - # can't just use agglomerated_seg[x_start:x_end, y_start:y_end, z_start:z_end] = relabeled - # because chunks at the boundary can be smaller - cur_shape = agglomerated_seg[x_start:x_end, y_start:y_end, z_start:z_end].shape - # this is exactly 1 chunk (chunk-borders) -> no race conditions / overwriting - agglomerated_seg[x_start:x_end, y_start:y_end, z_start:z_end] = relabeled[: cur_shape[0], : cur_shape[1], - : cur_shape[2]] - - -def size_filter_relabel(seg, conf): - """ - Filters out segments that are too small (less than minsize voxels) - and relabels the remaining segments contiguously from 1. 
- """ - start_time = time.time() - - block_indices = [(i, j, k) for i in range(seg.cdata_shape[0]) for j in range(seg.cdata_shape[1]) for k in - range(seg.cdata_shape[2])] - block_indices_batched = chunk_list(block_indices, 16) - combined_counter = Counter() - - print("Counting occurences of fragments...") - if not conf.use_parallelization: - result = list(tqdm(map( - partial(batched_unique, seg=seg), - block_indices_batched, - ), total=len(block_indices_batched), smoothing=0)) - for counter in tqdm(result, total=len(block_indices_batched), smoothing=0): - combined_counter.update(counter) - else: - if not conf.use_slurm: - with ThreadPoolExecutor(max_workers=8) as executor: - result = list(tqdm(executor.map( - partial(batched_unique, seg=seg), - block_indices_batched, - ), total=len(block_indices_batched), smoothing=0)) - for counter in tqdm(result, total=len(block_indices_batched), smoothing=0): - combined_counter.update(counter) - else: - cluster = SLURMCluster( - cores=32, - memory="800GB", - log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/count/", - processes=1, - worker_extra_args=["--resources processes=1"], - walltime="12:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - with Client(cluster) as client: - print("Dask counting Client Dashboard:", client.dashboard_link) - futures = client.map(batched_unique, block_indices_batched, seg=seg, resources={'processes': 1}) - - for future in (pbar := tqdm(as_completed(futures), total=len(block_indices_batched), smoothing=0)): - counter = future.result() - filtered = {k: v for k, v in counter.items() if v > 10} - combined_counter.update(filtered) - del counter - del future - mem = psutil.Process().memory_full_info() - rss = mem.rss / 1e9 - vms = mem.vms / 1e9 - pbar.set_postfix(rss=f"{rss:.1f} GB", vms=f"{vms:.1f} GB") - gc.collect() - - #cluster.close() - - remaining_ids = {id for id, count in combined_counter.items() if count > conf.minsize} - id_mapping_remaining = {old_id: new_id for 
new_id, old_id in enumerate(sorted(remaining_ids))} - assert id_mapping_remaining[0] == 0 - with open(f"{conf.path_root}/id_mapping.csv", 'w') as f: - for k, v in id_mapping_remaining.items(): - f.write(f"{k} {v}\n") - - print(f"Store: {conf.path_root}/relabeled_seg.zarr") - relabeled_seg = zarr.zeros( - seg.shape, - chunks=seg.chunks, - dtype=np.uint32, - store=f"{conf.path_root}/relabeled_seg.zarr", - ) - - print("Filtering out small fragments and relabeling contiguously...") - if not conf.use_parallelization: - result = list(tqdm(map(partial(batched_relabel, seg=seg, relabeled_seg=relabeled_seg, conf=conf), - block_indices_batched))) - else: - if not conf.use_slurm: - with ThreadPoolExecutor(max_workers=16) as executor: - result = list(tqdm(executor.map(partial(batched_relabel, seg=seg, relabeled_seg=relabeled_seg, conf=conf), - block_indices_batched))) - else: - cluster = SLURMCluster( - cores=16, - memory="500GB", - log_directory=f"/cajal/scratch/projects/misc/zuzur/slurm_logs/filter/", - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - with Client(cluster) as client: - print("Dask relabeling Client Dashboard:", client.dashboard_link) - futures = client.map(partial(batched_relabel, seg=seg, relabeled_seg=relabeled_seg, conf=conf), - block_indices_batched) - for _ in tqdm(as_completed(futures), total=len(block_indices_batched), smoothing=0): - pass - #cluster.close() - print(f"Filtering small fragments and relabeling took {timedelta(seconds=int(time.time() - start_time))}") - return relabeled_seg - - -def batched_unique(block_indices, seg): - print(f"{datetime.now()}: start count", flush=True) - c = Counter() - for idx in block_indices: - chunk_series = pd.Series(seg.blocks[idx].ravel()) - c.update(chunk_series.value_counts().to_dict()) - print(f"{datetime.now()}: end count", flush=True) - return c - - -def batched_relabel(block_indices, seg, relabeled_seg, conf): - mapping = {} - with open(f"{conf.path_root}/id_mapping.csv", 'r') as f: - for line in f: - 
key, value = line.split() - mapping[int(key)] = int(value) - for block_index in block_indices: - block = seg.blocks[block_index] - masked_block = fastremap.mask_except(block, list(mapping.keys())) - relabeled_block = fastremap.remap(masked_block, mapping) - relabeled_seg.blocks[block_index] = relabeled_block - return None - - -def main(conf): - print(conf) - aff = zarr.open(conf.aff_path, mode="r") - - if len(conf.thresholds) > 0: - thr = conf.thresholds[0] - - path_root = f"{conf.path_base}/{f'thr_{thr}'}/" - print(f"Root path: {path_root}") - os.makedirs(path_root, exist_ok=True) - - with open_dict(conf): - conf.path_root = path_root - conf.thr = thr - - start_time = time.time() - segmentation = patched_thresholding( - aff, - conf - ) - print(f"Patched thresholding took {timedelta(seconds=int(time.time() - start_time))}") - - elif len(conf.mws_biases_short) > 0: - biases = list(itertools.product(conf.mws_biases_short, conf.mws_biases_long)) - for (short, long) in biases: - print(f"SEGMENTATION FOR {short}, {long}") - path_root = f"{conf.path_base}/{f'mws_{short}_{long}'}/" - print(f"Root path: {path_root}") - os.makedirs(path_root, exist_ok=False) - with open_dict(conf): - conf.path_root = path_root - conf.mws_bias_short = short - conf.mws_bias_long = long - - start_time = time.time() - segmentation = patched_thresholding( - aff, - conf - ) - print(f"Patched thresholding took {timedelta(seconds=int(time.time() - start_time))}") - pass - return - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise configargparse.ArgumentTypeError('Boolean value expected.') - - -@hydra.main(config_path=".") -def main_wrapper(conf: DictConfig): - return main(conf) - -if __name__ == "__main__": - main_wrapper() \ No newline at end of file diff --git a/debug_progress.py b/debug_progress.py deleted file mode 100644 index 
fa69dd1..0000000 --- a/debug_progress.py +++ /dev/null @@ -1,40 +0,0 @@ -import dask -from distributed import LocalCluster, progress, as_completed -from tqdm import tqdm -from tqdm.dask import TqdmCallback -from dask import delayed, compute, persist -from dask.distributed import Client -import time - - -# Create some simulated tasks -def work(x): - time.sleep(1) - return x * x - -if __name__ == '__main__': - # Start a local distributed cluster - cluster = LocalCluster(n_workers=1, threads_per_worker=1) - client = Client(cluster) - - print("computing with persist") - tasks = [dask.delayed(work)(i) for i in range(20)] - x = persist(tasks) # start computation in the background - progress(x) - results1 = client.gather(x) - - print("computing with tqdm (doesn't work)") - tasks = [dask.delayed(work)(i) for i in range(20)] - with TqdmCallback(desc="Distributed compute", total=len(tasks), mininterval=0.5): - results = client.compute(tasks, sync=True) - - print("computing with as_completed") - tasks = [dask.delayed(work)(i) for i in range(20)] - futures = client.compute(tasks) - for future in tqdm( - as_completed(futures), - total=len(futures), - smoothing=0, - desc="Predicting chunks" - ): - pass diff --git a/debug_retrain.py b/debug_retrain.py deleted file mode 100644 index a8f6fc3..0000000 --- a/debug_retrain.py +++ /dev/null @@ -1,162 +0,0 @@ -import os - -import numpy as np -import torch -from pytorch_lightning import Trainer -from torch.utils.data import Dataset, DataLoader -from tqdm import tqdm -import zarr - -from BANIS import BANIS, parse_args - -from data import comp_affinities, load_data - - -def train_model_with_samples(model_path, dataloader): - model = BANIS.load_from_checkpoint(model_path) - - trainer = Trainer( - max_steps=1000, - accelerator="gpu", - devices=1, - ) - trainer.fit(model, dataloader) - - -def compare_model_weights(path_a, path_b): - model_a = BANIS.load_from_checkpoint(path_a) - model_b = BANIS.load_from_checkpoint(path_b) - - state_a = 
model_a.state_dict() - state_b = model_b.state_dict() - - diffs = {} - total_diff = 0.0 - total_norm = 0.0 - total_cos_sim = 0.0 - biggest_diff = 0 - nonfinite_a = 0 - nonfinite_b = 0 - total_params = 0 - - n_layers = 0 - - for key in tqdm(state_a): - #print(f"getting {key}") - param_a = state_a[key].flatten() - param_b = state_b[key].flatten() - - diff = torch.abs(param_a - param_b).mean().item() - if diff > biggest_diff: - biggest_diff = diff - norm = torch.norm(param_a - param_b).item() - cos_sim = torch.nn.functional.cosine_similarity(param_a.unsqueeze(0), param_b.unsqueeze(0)).item() - - diffs[key] = { - 'mean_abs_diff': diff, - 'l2_norm': norm, - 'cosine_similarity': cos_sim, - } - - total = param_a.numel() - nonfinite_params_a = (~torch.isfinite(param_a)).sum().item() - nonfinite_params_b = (~torch.isfinite(param_b)).sum().item() - total_params += total - nonfinite_a += nonfinite_params_a - nonfinite_b += nonfinite_params_b - - total_diff += diff - total_norm += norm - total_cos_sim += cos_sim - n_layers += 1 - - avg_diff = total_diff / n_layers - avg_norm = total_norm / n_layers - avg_cos_sim = total_cos_sim / n_layers - - return { - 'avg_mean_abs_diff': avg_diff, - 'avg_l2_norm': avg_norm, - 'avg_cosine_similarity': avg_cos_sim, - 'biggest_diff': biggest_diff, - 'nonfinite_a': nonfinite_a, - 'nonfinite_b': nonfinite_b, - 'total_params': total_params, - #'layerwise': diffs - } - - #print(compare_model_weights( - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt", - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=115000.ckpt" - # )) - ## {'avg_mean_abs_diff': nan, 'avg_l2_norm': nan, 'avg_cosine_similarity': nan, 'biggest_diff': 0, 'nonfinite_a': 0, 'nonfinite_b': 62993031, 'total_params': 62993031} - #print(compare_model_weights( - # 
"/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=100000.ckpt", - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=105000.ckpt" - # )) - ## {'avg_mean_abs_diff': 0.017999131043465913, 'avg_l2_norm': 2.9829090611515583, 'avg_cosine_similarity': 0.9655602666255757, 'biggest_diff': 0.06504751741886139} - #print(compare_model_weights( - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=105000.ckpt", - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt" - # )) - ## {'avg_mean_abs_diff': 0.01891954702438203, 'avg_l2_norm': 3.084119017241339, 'avg_cosine_similarity': 0.9619980035223629, 'biggest_diff': 0.09889261424541473} - #print(compare_model_weights( - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=90000.ckpt", - # "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt" - # )) - ## {'avg_mean_abs_diff': 0.03352864947696027, 'avg_l2_norm': 5.531613261509161, 'avg_cosine_similarity': 0.9207515456647084, 'biggest_diff': 0.1563815325498581} - - -def prepare_good_samples(): - args = parse_args() - train_data, val_data, n_channels = load_data(args) - return train_data - - -class SimpleDataset(Dataset): - def __init__(self, samples): - self.samples = samples - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - sample = self.samples[idx] - return { - "img": torch.from_numpy(sample["img"]), - "aff": torch.from_numpy(sample["aff"]), - "seg": torch.from_numpy(sample["seg"]), - } - -def prepare_bad_samples(): - samples = [] - runs_root = "/cajal/scratch/projects/misc/zuzur/ss3/" - for run in os.listdir(runs_root): - if 
run.startswith("debug1GPU-seed"): - for candidate in os.listdir(os.path.join(runs_root, run)): - if candidate.endswith("0_img.zarr"): - img = zarr.open(os.path.join(runs_root, run, candidate)) - seg_name = candidate.replace("img", "seg") - seg = zarr.open(os.path.join(runs_root, run, seg_name)) - aff, _ = comp_affinities(seg[:]) - data = { - "img": img.astype(np.float16), - "seg": seg, - "aff": aff, - } - samples.append(data) - if len(samples) >= 1000: - return SimpleDataset(samples) - return SimpleDataset(samples) - - -if __name__ == "__main__": - good_model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=110000.ckpt" - bad_model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=115000.ckpt" - early_model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=50000.ckpt" - - bad_samples = DataLoader(prepare_bad_samples(), batch_size=1, num_workers=8, shuffle=True, drop_last=True) - good_samples = DataLoader(prepare_good_samples(), batch_size=1, num_workers=8, shuffle=True, drop_last=True) - - train_model_with_samples(good_model_path, good_samples) - train_model_with_samples(good_model_path, bad_samples) diff --git a/debug_test_inference.py b/debug_test_inference.py deleted file mode 100644 index e8a1102..0000000 --- a/debug_test_inference.py +++ /dev/null @@ -1,88 +0,0 @@ -import zarr - -from inference import measure_stats, predict_aff, full_inference, thresholding -from inference2 import Thresholding, AffinityPredictor - - -def test_local_prediction(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - - model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - from BANIS import BANIS - - 
model = BANIS.load_from_checkpoint(model_path) - - all_stats = {} - - for chunk_cube_size in [200, 400, 512, 750, 1024, 1500, 3000]: - measured_predict_aff = measure_stats(predict_aff) - - result, stats = measured_predict_aff(img_data, model, chunk_cube_size=chunk_cube_size, compute_backend="local", - zarr_path=f"/cajal/scratch/projects/misc/zuzur/test{chunk_cube_size}.zarr", do_overlap=True, - prediction_channels=3, divide=255, small_size=model.hparams.small_size) - - all_stats[chunk_cube_size] = stats - print(f"chunk size {chunk_cube_size}: {stats}") - - print(all_stats) - for (value, stat) in all_stats.items(): - print(f"{value}: {stat}") - - -def test_slurm_prediction(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - - model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - from BANIS import BANIS - model = BANIS.load_from_checkpoint(model_path) - - measured_predict_aff = measure_stats(predict_aff) - # only one run - runtime dependent on number of available slurm nodes - result, stats = measured_predict_aff(img_data, model_path=model_path, chunk_cube_size=512, compute_backend="slurm", - zarr_path=f"/cajal/scratch/projects/misc/zuzur/test_slurm.zarr", do_overlap=True, - prediction_channels=3, divide=255, small_size=model.hparams.small_size) - - print(stats) - -def test_full_inference(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - - model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - from BANIS import BANIS - model = BANIS.load_from_checkpoint(model_path) - - full_inference(img_data, model_path, thr=0.7685) - -def test_thresholding_old(): - aff = 
zarr.open("/cajal/scratch/projects/misc/zuzur/skeleton_recall/rerun_base/dsbase_s0_a0_25-03-20_19-44-08-067760/pred_aff_val_6.zarr/") - thresholding(aff, 0.7685, "test_old2.zarr", 1024, "local") - -def test_thresholding_new(): - aff = zarr.open("/cajal/scratch/projects/misc/zuzur/skeleton_recall/rerun_base/dsbase_s0_a0_25-03-20_19-44-08-067760/pred_aff_val_6.zarr/") - postprocessor = Thresholding(300, "local", 0.7685) - postprocessor.aff_to_seg(aff, zarr_path="test1.zarr") - -def test_prediction_old(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - - predict_aff(img_data, model_path=model_path, chunk_cube_size=1024, compute_backend="local", - zarr_path=f"/cajal/scratch/projects/misc/zuzur/test_0.zarr", do_overlap=True, - prediction_channels=3, divide=255, small_size=128) - - -def test_prediction_new(): - input_path = "/cajal/nvmescratch/projects/NISB/base/val/seed100/data.zarr" - img_data = zarr.open(input_path, mode="r")["img"] - model_path = "/cajal/scratch/projects/misc/zuzur/ss3/debug1GPU-seed0-batch_size1-small_size128/default/checkpoints/epoch=0-step=70000.ckpt" - - predictor = AffinityPredictor(model_path=model_path, chunk_cube_size=1024, compute_backend="local", do_overlap=True, - prediction_channels=3, divide=255, small_size=128) - predictor.img_to_aff(img_data, zarr_path=f"/cajal/scratch/projects/misc/zuzur/newtest_0.zarr") - -if __name__ == "__main__": - test_prediction_new() diff --git a/debug_visualilze.py b/debug_visualilze.py deleted file mode 100644 index 4b8e79a..0000000 --- a/debug_visualilze.py +++ /dev/null @@ -1,132 +0,0 @@ -import argparse -import os -import pickle -from typing import Tuple -from collections import defaultdict, Counter, deque - -import dask.array as da -import neuroglancer -import numpy as np -from 
dask.array import clip -from dask_image.ndfilters import gaussian -from neuroglancer import CoordinateSpace, LocalVolume, Viewer, SegmentationLayer -import zarr -from tqdm import tqdm -import networkx as nx -from networkx import connected_components, subgraph, convert_node_labels_to_integers - -from data import comp_affinities - -""" -Visualizes where the errors in the prediction are. -""" - -class SkeletonSource(neuroglancer.skeleton.SkeletonSource): - def __init__(self, dimensions, skel): - super().__init__(dimensions) - self.skel = skel - - def get_skeleton(self, i): - print(f"Getting skeleton for {i}") - cv_s = self.skel[i] - try: - s = neuroglancer.skeleton.Skeleton(vertex_positions=(cv_s.vertices / [9,9,20]), edges=cv_s.edges) - except Exception as e: - print(e) - return s - - - -# Coordinate spaces -COORDS = { - "standard": CoordinateSpace(names=['x', 'y', 'z'], units=['nm', 'nm', 'nm'], scales=[9, 9, 20]), - "standard_c": CoordinateSpace(names=["x", "y", "z", "c^"], units=["nm", "nm", "nm", ""], scales=[9, 9, 20, 1]), - "liconn": CoordinateSpace(names=['x', 'y', 'z'], units=['nm', 'nm', 'nm'], scales=[9, 9, 12]), - "liconn_c": CoordinateSpace(names=["x", "y", "z", "c^"], units=["nm", "nm", "nm", ""], scales=[9, 9, 12, 1]), - "aff": CoordinateSpace(names=[ "c^", "x", "y", "z"], units=["", "nm", "nm", "nm"], scales=[1, 9, 9, 20]), -} - - -def load_data(data_path: str): - """Load image, segmentation, and skeleton data.""" - seg = da.from_zarr(os.path.join(data_path, "data.zarr", "seg")).astype(np.uint32)[500:1000, 500:1000, 500:1000] - img = da.from_zarr(os.path.join(data_path, "data.zarr", "img"))[500:1000, 500:1000, 500:1000] - skel = da.from_zarr(os.path.join(data_path, "data.zarr", "skel")).astype(np.uint32) - with open(os.path.join(data_path, "skeleton_dense.pkl"), 'rb') as f: - skel_pkl = pickle.load(f) - return img, seg, skel, skel_pkl - - -def add_image_layer(s, name: str, img: da.Array, c_res: CoordinateSpace): - """Add an image layer to the 
viewer.""" - layer = LocalVolume(img, dimensions=c_res) - s.layers.append(name=f'img_{name}', layer=layer) - - -def add_segmentation_layer(s, name: str, seg: da.Array, skel: dict, res: CoordinateSpace): - """Add a segmentation layer to the viewer.""" - layer = SegmentationLayer( - source=[LocalVolume(seg, dimensions=res, volume_type="segmentation"), SkeletonSource(res, skel)], - skeleton_shader='void main() { emitRGB(vec3(.3, .8, .76)); }', - mesh_silhouette_rendering=2.0 - ) - layer.skeleton_rendering.mode3d = "lines" #"lines_and_points" - s.layers.append(name=f'seg_{name}', layer=layer) - - -def create_viewer(args) -> Viewer: - """Create and configure the Neuroglancer viewer.""" - neuroglancer.set_server_bind_address('localhost', args.port) - viewer = Viewer() - - with viewer.txn() as s: - img, seg, skel, skel_pkl = load_data(args.data_path) - - coord_space = COORDS["standard_c"] - add_image_layer(s, "gt", img, coord_space) - - seg_space = COORDS["standard"] - add_segmentation_layer(s, "gt", seg, skel_pkl, seg_space) - - if True: - aff, _ = comp_affinities(seg) - aff = da.from_array(aff).astype(np.float32) - s.layers["gt_aff"] = neuroglancer.ImageLayer( - source=neuroglancer.LocalVolume( - aff[:3], dimensions=COORDS["aff"], voxel_offset=[0, 0, 0, 0] - ), - shader="""void main() { - emitRGB(vec3(toNormalized(getDataValue(0)), - toNormalized(getDataValue(1)), - toNormalized(getDataValue(2)))); - }""", - ) - - pred_aff = da.from_zarr(args.pred_path).astype(np.float32) - s.layers["pred_aff"] = neuroglancer.ImageLayer( - source=neuroglancer.LocalVolume( - pred_aff[:3], dimensions=COORDS["aff"], voxel_offset=[0, 0, 0, 0] - ), - shader="""void main() { - emitRGB(vec3(toNormalized(getDataValue(0)), - toNormalized(getDataValue(1)), - toNormalized(getDataValue(2)))); - }""", - ) - - - print("If on a remote server, remember port forwarding. 
Meshes may take time to load.") - - return viewer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Neuroglancer Viewer for NISB project") - parser.add_argument("--data_path", type=str, default="/cajal/scratch/users/zuzur/NISB_corrected/base/val/seed100", help="Directory which contains data.zarr with segmentation + EM image + skeleton, and skeleton.pkl") - parser.add_argument("--pred_path", type=str, default="/cajal/scratch/projects/misc/zuzur/test.zarr/aff") - parser.add_argument("--port", type=int, default=8589, help="Port to run the viewer") - args = parser.parse_args() - - viewer = create_viewer(args) - print(viewer.get_viewer_url()) - input("Press Enter to quit") diff --git a/environment.yaml b/environment.yaml index 1a60e4a..a303a50 100644 --- a/environment.yaml +++ b/environment.yaml @@ -41,9 +41,11 @@ dependencies: - batchgenerators==0.25 - certifi==2024.8.30 - charset-normalizer==3.4.0 + - cloud_volume==12.4.1 - connected-components-3d==3.19.0 - contourpy==1.3.0 - cycler==0.12.1 + - dask_jobqueue==0.9.0 - dicom2nifti==2.5.0 - fasteners==0.19 - filelock==3.16.1 @@ -71,6 +73,7 @@ dependencies: - multidict==6.1.0 - mwatershed==0.5.3 - networkx==3.3 + - neuroglancer==2.40.1 - nibabel==5.3.0 - numba==0.60.0 - numcodecs==0.13.1 diff --git a/inference.py b/inference.py index b88d5e6..adbd9e5 100644 --- a/inference.py +++ b/inference.py @@ -3,6 +3,7 @@ from copy import deepcopy from typing import Union, List, Tuple +import argparse import cc3d import numba import numpy as np @@ -25,514 +26,555 @@ import mwatershed -def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: - """Scale sigmoid to avoid numerical issues in high confidence fp16.""" - return sigmoid(0.2 * x) - - -def measure_stats(func): - import os - import time - from datetime import timedelta - import tracemalloc - import threading - import psutil - - def monitor_memory(interval=0.1, result=None): - proc = psutil.Process(os.getpid()) - peak = 0 - while not 
getattr(monitor_memory, "stop", False): - rss = proc.memory_info().rss - peak = max(peak, rss) - time.sleep(interval) - if result is not None: - result["peak"] = peak # Save peak memory to shared dict - - def wrapper(*args, **kwargs): - memory_stats = {} - thread = threading.Thread(target=monitor_memory, kwargs={"interval": 0.1, "result": memory_stats}) - thread.start() - torch.cuda.reset_peak_memory_stats() - tracemalloc.start() - start = time.time() - - result = func(*args, **kwargs) - - end = time.time() - elapsed = timedelta(seconds=end - start) - current, peak = tracemalloc.get_traced_memory() - max_mem = torch.cuda.max_memory_reserved() - monitor_memory.stop = True - thread.join() - - stats = { - "time": f"{elapsed}", - "peak_python_mem": f"{peak / 1024**2:.2f} MB", - "max_cuda_mem": f"{max_mem / 1024 ** 2:.2f} MB", - "rss_mem": f"{memory_stats['peak'] / 1024 ** 2:.2f} MB" - } - - return result, stats - - return wrapper - - -@jit(nopython=True) -def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: - """ - Compute connected components from affinities. - - Args: - hard_aff: The (thresholded, boolean) short range affinities. Shape: (3, x, y, z). - - Returns: - The segmentation. Shape: (x, y, z). 
- """ - visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) - seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) - cur_id = 1 - for i in range(visited.shape[0]): - for j in range(visited.shape[1]): - for k in range(visited.shape[2]): - if hard_aff[:, i, j, k].any() and not visited[i, j, k]: # If foreground - cur_to_visit = [(i, j, k)] - visited[i, j, k] = True - while cur_to_visit: - x, y, z = cur_to_visit.pop() - seg[x, y, z] = cur_id - - # Check all neighbors - if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]: - cur_to_visit.append((x + 1, y, z)) - visited[x + 1, y, z] = True - if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]: - cur_to_visit.append((x, y + 1, z)) - visited[x, y + 1, z] = True - if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]: - cur_to_visit.append((x, y, z + 1)) - visited[x, y, z + 1] = True - if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]: - cur_to_visit.append((x - 1, y, z)) - visited[x - 1, y, z] = True - if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]: - cur_to_visit.append((x, y - 1, z)) - visited[x, y - 1, z] = True - if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]: - cur_to_visit.append((x, y, z - 1)) - visited[x, y, z - 1] = True - cur_id += 1 - return seg - - -@torch.no_grad() -@autocast(device_type="cuda") -def predict_aff( - img: Union[np.ndarray, zarr.Array], - model: torch.nn.Module = None, - model_path: str = None, - zarr_path: str = "aff_prediction.zarr", - small_size: int = 128, - do_overlap: bool = True, - prediction_channels: int = 6, - divide: int = 1, - chunk_cube_size: int = 1024, - compute_backend: str = "local" -): - """ - Perform patched affinity prediction with a model on an image. - - Args: - img: The input image. Shape: (x, y, z, channel). - model: The model to use for predictions (only for local prediction). 
- model_path: Path to the model checkpoint to use for predictions (if model not specified). - zarr_path: Output path to save the prediction in zarr format. - small_size: The size of the patches. Defaults to 128. - do_overlap: Whether to perform overlapping predictions. Defaults to True: - half of patch size for all 3 axes. - prediction_channels: The number of channels in the output (additional model output - dimensions are discarded). Defaults to 6 (3 short + 3 long range affinities). - divide: The divisor for the image. Typically, 1 or 255 if img in [0, 255] - chunk_cube_size: The maximal side length of a cube held in memory. - compute_backend: Type of computation / dask backend. One of: - - - "local": uses a cycle on the local machine (default) - - "local_cluster": uses a localGPUcluster to utilize all local GPUs without SLURM - - "slurm": uses a slurm cluster with all available nodes - - Returns: - The full prediction. Shape: (channel, x, y, z). - """ - print( - f"Performing patched inference with do_overlap={do_overlap} for img of shape {img.shape} and dtype {img.dtype}") - print(f"Parameters: cube size {chunk_cube_size}, compute backend {compute_backend}.") - - all_patch_coordinates = get_coordinates(img.shape[:3], small_size, overlap = small_size // 2 if do_overlap else 0, last_has_smaller_overlap=True) - chunked_patch_coordinates = chunk_xyzs(all_patch_coordinates, chunk_cube_size) - - z = zarr.open_group(zarr_path + "_tmp", mode='w') - zarr_chunk_size = min(chunk_cube_size, 512) - z.create_dataset('sum_pred', shape=(prediction_channels, *img.shape[:3]), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - - if compute_backend == "local": - for chunk in tqdm(chunked_patch_coordinates): - predict_aff_patches_chunked(chunk, img, model_path, zarr_path + "_tmp", small_size, do_overlap, 
prediction_channels, divide) - torch.cuda.empty_cache() # TODO: does this help? - else: - if compute_backend == "local_cluster": - from dask_cuda import LocalCUDACluster - cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU - elif compute_backend == "slurm": - from dask_jobqueue import SLURMCluster - cluster = SLURMCluster( - cores=8, - memory="400GB", - processes=1, - worker_extra_args=["--resources processes=1", "--nthreads=1"], - job_extra_directives=["--gres=gpu:1"], - walltime="1-00:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - +class Utils: + @staticmethod + def get_coordinates(shape: Tuple[int, int, int], small_size: int, overlap: int = 0, + last_has_smaller_overlap: bool = True) -> List[Tuple[int, int, int]]: + """ + Get coordinates for smaller patches to process a big cube in memory. + Args: + shape: The shape of the input (x, y, z). + small_size: The size of the patches. + overlap: The overlap between patches. The default 0 means no overlap (next patch starts on the next pixel from the previous patch). For half-cube overlap set overlap=small_size//2, for 1-pixel overlap set overlap=1. + last_has_smaller_overlap: If the last patch with the specified size and overlap would exceed the big cube, move the patch so that it ends with the big cube, creating a bigger overlap in this patch. + Returns: + List of (x, y, z) coordinates (starting voxel of a patch) for processing of smaller patches. 
+ """ + if overlap < 0 or overlap >= small_size: + raise ValueError(f"Overlap must be between 0 and {small_size}.") + offsets = [Utils.get_offsets(s, small_size, small_size - overlap, last_has_smaller_overlap) for s in shape] + xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] + return xyzs + + @staticmethod + def get_offsets(big_size, small_size, step, last_has_smaller_overlap): + offsets = list(range(0, big_size - small_size + 1, step)) + if small_size > big_size: + offsets.append(0) + elif offsets[-1] != big_size - small_size and last_has_smaller_overlap: + offsets.append(big_size - small_size) + elif offsets[-1] != big_size - small_size and not last_has_smaller_overlap: + offsets.append(len(offsets) * step) + return offsets + + @staticmethod + def chunk_xyzs(xyzs, chunk_cube_size=1024): + """ + Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. + Args: + xyzs: list of all coordinates + chunk_cube_size: side length of each chunk + Returns: + chunked coordinates + """ + chunks = defaultdict(list) + for x, y, z in xyzs: + chunks[(x // chunk_cube_size, y // chunk_cube_size, z // chunk_cube_size)].append((x, y, z)) + return list(chunks.values()) + + @staticmethod + def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: + """Scale sigmoid to avoid numerical issues in high confidence fp16.""" + return sigmoid(0.2 * x) + + @staticmethod + def get_xyz_end(chunk, chunk_cube_size, aff_shape): + """ + Returns the end indices of a chunk, that correspond either to the chunk size, or align with the size of the affinities. 
+ """ + x, y, z = chunk + x_end, y_end, z_end = (min(x + chunk_cube_size, aff_shape[1]), + min(y + chunk_cube_size, aff_shape[2]), + min(z + chunk_cube_size, aff_shape[3])) + return (x_end, y_end, z_end) + + +class AffinityPredictor: + def __init__(self, + chunk_cube_size: int = 1024, + compute_backend: str = "local", + model: torch.nn.Module = None, + model_path: str = None, + small_size: int = 128, + do_overlap: bool = True, + prediction_channels: int = 6, + divide: int = 1, + ): + self.chunk_cube_size = chunk_cube_size + self.compute_backend = compute_backend + + self.model = model # only for local prediction + self.model_path = model_path # loads model in the worker in case of distributed inference (model not pickleable) + self.small_size = small_size + self.do_overlap = do_overlap + self.prediction_channels = prediction_channels + self.divide = divide + + def img_to_aff(self, img, zarr_path): + """ + Complete prediction of affinities from the input image, with the model previously specified in AffinityPredictor. 
+ """ + print(f"Performing patched inference with do_overlap={self.do_overlap} for img of shape {img.shape} and dtype {img.dtype}") + print(f"Parameters: cube size {self.chunk_cube_size}, compute backend {self.compute_backend}.") + + all_patch_coordinates = Utils.get_coordinates(img.shape[:3], self.small_size, overlap=self.small_size // 2 if self.do_overlap else 0, last_has_smaller_overlap=True) + chunked_patch_coordinates = Utils.chunk_xyzs(all_patch_coordinates, self.chunk_cube_size) + + z = zarr.open_group(zarr_path + "_tmp", mode='w') + zarr_chunk_size = min(self.chunk_cube_size, 512) + z.create_dataset('sum_pred', shape=(self.prediction_channels, *img.shape[:3]), chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') + + if self.compute_backend == "local": + for chunk in tqdm(chunked_patch_coordinates, desc="chunks"): + self.predict_aff_patches_chunked(chunk, img, zarr_path + "_tmp") + torch.cuda.empty_cache() else: - raise NotImplementedError(f"Compute backend {compute_backend} not available.") - - client = Client(cluster) - print(f"Waiting for workers...") - client.wait_for_workers(n_workers=1) - print("Dask Client Dashboard:", client.dashboard_link) - tasks = [dask.delayed(predict_aff_patches_chunked)(chunk, img, model_path, zarr_path + "_tmp", small_size, do_overlap, prediction_channels, divide) - for chunk in chunked_patch_coordinates - ] - futures = persist(tasks) - progress(futures) # progress bar - compute(futures) - - tmp_sum_pred = da.from_zarr(f"{zarr_path}_tmp/sum_pred") - tmp_sum_weight = da.from_zarr(f"{zarr_path}_tmp/sum_weight") - aff = tmp_sum_pred / tmp_sum_weight - aff.to_zarr(zarr_path, overwrite=True) - - shutil.rmtree(zarr_path + "_tmp") - - return zarr.open(zarr_path, mode="r") - - -def get_coordinates(shape: Tuple[int, int, int], small_size: int, overlap: int = 0, 
last_has_smaller_overlap: bool = True) -> List[Tuple[int, int, int]]: - """ - Get coordinates for smaller patches to process a big cube in memory. - Args: - shape: The shape of the input (x, y, z). - small_size: The size of the patches. - overlap: The overlap between patches. The default 0 means no overlap (next patch starts on the next pixel from the previous patch). For half-cube overlap set overlap=small_size//2, for 1-pixel overlap set overlap=1. - last_has_smaller_overlap: If the last patch with the specified size and overlap would exceed the big cube, move the patch so that it ends with the big cube, creating a bigger overlap in this patch. - Returns: - List of (x, y, z) coordinates (starting voxel of a patch) for processing of smaller patches. - """ - if overlap < 0 or overlap >= small_size: - raise ValueError(f"Overlap must be between 0 and {small_size}.") - offsets = [get_offsets(s, small_size, small_size-overlap, last_has_smaller_overlap) for s in shape] - xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] - return xyzs - - -def get_offsets(big_size, small_size, step, last_has_smaller_overlap): - offsets = list(range(0, big_size - small_size + 1, step)) - if small_size > big_size: - offsets.append(0) - elif offsets[-1] != big_size - small_size and last_has_smaller_overlap: - offsets.append(big_size - small_size) - elif offsets[-1] != big_size - small_size and not last_has_smaller_overlap: - offsets.append(len(offsets) * step) - return offsets - - -def get_single_pred_weight(do_overlap: bool, small_size: int) -> Union[np.ndarray, None]: - """ - Get the weight for a single prediction. - - Args: - do_overlap: Whether to perform overlapping predictions. - small_size: The size of the patches. - - Returns: - The weight array for a single prediction, or None if no overlap. 
- """ - if do_overlap: - # The weight (confidence/expected quality) of the predictions: - # Low at the surface of the predicted cube, high in the center - pred_weight_helper = np.pad(np.ones((small_size,) * 3), 1, mode='constant') - return distance_transform_cdt(pred_weight_helper).astype(np.float32)[1:-1, 1:-1, 1:-1] - else: - return None - - -def chunk_xyzs(xyzs, chunk_cube_size=1024): - """ - Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. - Args: - xyzs: list of all coordinates - chunk_cube_size: side length of each chunk - Returns: - chunked coordinates - """ - chunks = defaultdict(list) - for x, y, z in xyzs: - chunks[(x // chunk_cube_size, y // chunk_cube_size, z // chunk_cube_size)].append((x, y, z)) - return list(chunks.values()) - - -@torch.no_grad() -@autocast(device_type="cuda") -def predict_aff_patches_chunked(patch_coordinates, img, model_path, zarr_path, small_size, do_overlap, prediction_channels, divide): - """ - Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. - Args: - patch_coordinates: List of patch coordinates. The extension of the coordinates must fit in memory (use adequate chunk size). - Returns: - Affinity prediction of the input chunk. 
- """ - max_x = max(x for x, y, z in patch_coordinates) - max_y = max(y for x, y, z in patch_coordinates) - max_z = max(z for x, y, z in patch_coordinates) - min_x = min(x for x, y, z in patch_coordinates) - min_y = min(y for x, y, z in patch_coordinates) - min_z = min(z for x, y, z in patch_coordinates) - - img_tmp = img[ - min_x: max_x + small_size, - min_y: max_y + small_size, - min_z: max_z + small_size, - ] - pred_tmp = np.zeros((prediction_channels, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) - weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) - single_pred_weight = get_single_pred_weight(do_overlap, small_size) - - from BANIS import BANIS - print(f"model path: {model_path}", flush=True) - model = BANIS.load_from_checkpoint(model_path) - - for x_global, y_global, z_global in patch_coordinates: - x = x_global - min_x - y = y_global - min_y - z = z_global - min_z - img_patch = torch.tensor(np.moveaxis( - img_tmp[x: x + small_size, y: y + small_size, z: z + small_size], - -1, 0)[None]).to(model.device) / divide - pred = scale_sigmoid(model(img_patch))[0, :prediction_channels] - - weight_tmp[:, x: x + small_size, y: y + small_size, - z: z + small_size] += single_pred_weight if do_overlap else 1 - pred_tmp[:, x: x + small_size, y: y + small_size, z: z + small_size] += pred.detach().cpu().numpy() * ( - single_pred_weight[None] if do_overlap else 1) - - z = zarr.open_group(zarr_path, mode='a') - weight_mask = z['sum_weight'] - full_pred = z['sum_pred'] - - with FileLock(f"{zarr_path}/sum_weight.lock"): - weight_mask[ - :, - min_x: max_x + small_size, - min_y: max_y + small_size, - min_z: max_z + small_size, - ] += weight_tmp - - with FileLock(f"{zarr_path}/sum_pred.lock"): - full_pred[ - :, - min_x: max_x + small_size, - min_y: max_y + small_size, - min_z: max_z + small_size, - ] += pred_tmp - - -def update_fragment_agglomeration(fragment_agglomeration, matching_l, matching_h, chunk_l, 
chunk_h): - combined = np.stack([matching_l, matching_h]).T - uniques = np.unique(combined, axis=0) - for id_l, id_h in uniques: - if id_l > 0 and id_h > 0: - fragment_agglomeration.setdefault((chunk_h, id_h), set()).add((chunk_l, id_l)) - fragment_agglomeration.setdefault((chunk_l, id_l), set()).add((chunk_h, id_h)) - return fragment_agglomeration - - -def flatten_agglomeration(fragment_agglomeration): - """ - Computes connected components in the fragment agglomeration graph, and assigns the fragments new ids starting from 1. - Args: - fragment_agglomeration: dictionary with keys (chunk_id, fragment_id), and values a set of (chunk_id, fragment_id) in another chunk (cube) that should be connected - Returns: - fragment_agglomeration_flattened: dictionary with keys (chunk_id, fragment_id) and values the global component index - """ - cur_id = 1 - fragment_agglomeration_flattened = dict() - for position_id in tqdm(fragment_agglomeration): # (chunk, idx) = position_id - if position_id not in fragment_agglomeration_flattened: - to_visit = {position_id} - visited = set() - while len(to_visit) > 0: - current = to_visit.pop() - if current not in visited: - visited.add(current) - for neighbor in fragment_agglomeration[current]: - to_visit.add(neighbor) - for v in visited: - assert v not in fragment_agglomeration_flattened - fragment_agglomeration_flattened[v] = cur_id - cur_id += 1 - - return cur_id, fragment_agglomeration_flattened - - -def add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, cur_id, chunks, zarr_path): - z = zarr.open(f"{zarr_path}_tmp/instances_patched") - for i, chunk in enumerate(tqdm(chunks)): - data = z[i, :, :, :] - for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max - if (i, idx) not in fragment_agglomeration_flattened: - fragment_agglomeration_flattened[(i, idx)] = cur_id - cur_id += 1 - return fragment_agglomeration_flattened - - -def thresholding(aff, thr, zarr_path, 
chunk_cube_size, compute_backend): - chunks = get_coordinates(aff.shape[1:], chunk_cube_size, overlap=0, last_has_smaller_overlap=False) - reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)} - - z_root = zarr.open_group(zarr_path + "_tmp", mode='w') - zarr_chunk_size = min(chunk_cube_size, 512) - z_root.create_dataset('instances_patched', shape=(len(chunks), chunk_cube_size, chunk_cube_size, chunk_cube_size), - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='i4') - - # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN MEMORY - if compute_backend == "local": - for i, chunk in enumerate(tqdm(chunks)): - x, y, z = chunk - x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) - curr_aff = aff[:3, x : x_end, y : y_end, z : z_end] - curr_seg = compute_connected_component_segmentation(curr_aff > thr) - z_root["instances_patched"][i, : x_end - x, : y_end - y, : z_end - z] = curr_seg - else: - raise NotImplementedError(f"Compute backend {compute_backend} not implemented.") + if self.compute_backend == "local_cluster": + from dask_cuda import LocalCUDACluster + cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU + elif self.compute_backend == "slurm": + from dask_jobqueue import SLURMCluster + cluster = SLURMCluster( + cores=8, + memory="400GB", + processes=1, + worker_extra_args=["--resources processes=1", "--nthreads=1"], + job_extra_directives=["--gres=gpu:1"], + walltime="1-00:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) + + else: + raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") + + client = Client(cluster) + print(f"Waiting for workers...") + client.wait_for_workers(n_workers=1) + print("Dask Client Dashboard:", client.dashboard_link) + tasks = [dask.delayed(self.predict_aff_patches_chunked)(chunk, img, zarr_path + "_tmp") for chunk in chunked_patch_coordinates] + futures = persist(tasks) + 
progress(futures) # progress bar + compute(futures) + + tmp_sum_pred = da.from_zarr(f"{zarr_path}_tmp/sum_pred") + tmp_sum_weight = da.from_zarr(f"{zarr_path}_tmp/sum_weight") + aff = tmp_sum_pred / tmp_sum_weight + aff.to_zarr(zarr_path, overwrite=True) + + shutil.rmtree(zarr_path + "_tmp") + + return + + def predict_aff_patches_chunked(self, patch_coordinates, img, zarr_path): + """ + Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. + Args: + patch_coordinates: List of patch coordinates. The extension of the coordinates must fit in memory (use adequate chunk size). + Returns: + Affinity prediction of the input chunk. + """ + max_x = max(x for x, y, z in patch_coordinates) + max_y = max(y for x, y, z in patch_coordinates) + max_z = max(z for x, y, z in patch_coordinates) + min_x = min(x for x, y, z in patch_coordinates) + min_y = min(y for x, y, z in patch_coordinates) + min_z = min(z for x, y, z in patch_coordinates) + + img_tmp = img[ + min_x: max_x + self.small_size, + min_y: max_y + self.small_size, + min_z: max_z + self.small_size, + ] + pred_tmp = np.zeros((self.prediction_channels, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) + weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) + single_pred_weight = self.get_single_pred_weight(self.do_overlap, self.small_size) + + if not self.model: + from BANIS import BANIS + print(self.model_path, flush=True) + model = BANIS.load_from_checkpoint(self.model_path) + else: + model = self.model + + for x_global, y_global, z_global in tqdm(patch_coordinates, desc=f'cube ({min_x}, {max_x + self.small_size}), ({min_y}, {max_y + self.small_size}), ({min_z}, {max_z + self.small_size})'): + x = x_global - min_x + y = y_global - min_y + z = z_global - min_z + img_patch = torch.tensor(np.moveaxis(img_tmp[x: x + self.small_size, y: y + self.small_size, z: z + self.small_size], -1, 0)[None]).to(model.device) / 
self.divide + pred = Utils.scale_sigmoid(model(img_patch))[0, :self.prediction_channels] + + weight_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += single_pred_weight if self.do_overlap else 1 + pred_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += pred.detach().cpu().numpy() * (single_pred_weight[None] if self.do_overlap else 1) + + z = zarr.open_group(zarr_path, mode='a') + weight_mask = z['sum_weight'] + full_pred = z['sum_pred'] + + with FileLock(f"{zarr_path}/sum_weight.lock"): + weight_mask[ + :, + min_x: max_x + self.small_size, + min_y: max_y + self.small_size, + min_z: max_z + self.small_size, + ] += weight_tmp + + with FileLock(f"{zarr_path}/sum_pred.lock"): + full_pred[ + :, + min_x: max_x + self.small_size, + min_y: max_y + self.small_size, + min_z: max_z + self.small_size, + ] += pred_tmp + + def get_single_pred_weight(self, do_overlap: bool, small_size: int) -> Union[np.ndarray, None]: + """ + Get the weight for a single prediction. + + Args: + do_overlap: Whether to perform overlapping predictions. + small_size: The size of the patches. + + Returns: + The weight array for a single prediction, or None if no overlap. 
+ """ + if do_overlap: + # The weight (confidence/expected quality) of the predictions: + # Low at the surface of the predicted cube, high in the center + pred_weight_helper = np.pad(np.ones((small_size,) * 3), 1, mode='constant') + return distance_transform_cdt(pred_weight_helper).astype(np.float32)[1:-1, 1:-1, 1:-1] + else: + return None - # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS - if compute_backend == "local": - fragment_agglomeration = {} - for i, chunk in enumerate(tqdm(chunks)): - x, y, z = chunk - x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) - - # merge according to short range affinities between each pair of IDs in neighboring cubes - if x_end < aff.shape[1]: - chunk_h = reverse_chunks[x + chunk_cube_size, y, z] - border_aff = aff[0, x_end - 1 : x_end, y : y_end, z : z_end] >= thr - matching_ids_l = z_root["instances_patched"][i, -1:, :, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] - matching_ids_h = z_root["instances_patched"][chunk_h, -1:, :, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] - fragment_agglomeration = update_fragment_agglomeration(fragment_agglomeration, matching_ids_l, matching_ids_h, i, chunk_h) - - if y_end < aff.shape[2]: - chunk_h = reverse_chunks[x, y + chunk_cube_size, z] - border_aff = aff[0, x : x_end, y_end - 1 : y_end, z : z_end] >= thr - matching_ids_l = z_root["instances_patched"][i, :, -1:, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] - matching_ids_h = z_root["instances_patched"][chunk_h, :, -1:, :][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] - fragment_agglomeration = update_fragment_agglomeration(fragment_agglomeration, matching_ids_l, matching_ids_h, i, chunk_h) - - if z_end < aff.shape[3]: - chunk_h = reverse_chunks[x, y, z + chunk_cube_size] - border_aff = aff[0, x : 
x_end, y : y_end, z_end - 1 : z_end] >= thr - matching_ids_l = z_root["instances_patched"][i, :, :, -1:][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] - matching_ids_h = z_root["instances_patched"][chunk_h, :, :, -1:][:border_aff.shape[0], :border_aff.shape[1], :border_aff.shape[2]][border_aff] - fragment_agglomeration = update_fragment_agglomeration(fragment_agglomeration, matching_ids_l, matching_ids_h, i, chunk_h) - - curr_id, fragment_agglomeration_flattened = flatten_agglomeration(fragment_agglomeration) - print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) - fragment_agglomeration_flattened = add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, zarr_path) - print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) - else: - raise NotImplementedError(f"Compute backend {compute_backend} not implemented.") - - # MERGE AND RELABEL INSTANCES GLOBALLY - z_final = zarr.create(shape=aff.shape[1:], - chunks=(zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='i4', - store=zarr_path, overwrite=True) - - if compute_backend == "local": - for i, chunk in enumerate(tqdm(chunks)): - x, y, z = chunk - x_end, y_end, z_end = min(x + chunk_cube_size, aff.shape[1]), min(y + chunk_cube_size, aff.shape[2]), min(z + chunk_cube_size, aff.shape[3]) - data = z_root["instances_patched"][i, : x_end - x, : y_end - y, : z_end - z] - perm = [0] - for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max - assert (i, idx) in fragment_agglomeration_flattened # all fragments have a new index (congiguous from 0) - perm.append(fragment_agglomeration_flattened[(i, idx)]) - perm = np.array(perm, dtype=np.uint64) - relabeled = perm[data] - z_final[x : x_end, y : y_end, z : z_end] = relabeled +class Postprocessing: + def __init__(self, + chunk_cube_size: int = 1024, + compute_backend: str = "local" + ): + 
self.chunk_cube_size = chunk_cube_size + self.compute_backend = compute_backend - else: - raise NotImplementedError(f"Compute backend {compute_backend} not implemented.") - - shutil.rmtree(zarr_path + "_tmp") - - -def compute_mws_segmentation(cur_aff, mws_bias_short, mws_bias_long, long_range=10): - """ - Mutex Watershed segmentation. - Args: - cur_aff: An affinity array with 3 short-range and 3 long-range affinities (size must fit in memory). - mws_bias_short: Short-range bias - mws_bias_long: Long-range bias - Returns: - Segmentation of the affinities. - """ - cur_aff = deepcopy(cur_aff).astype(np.float64) - cur_aff[:3] += mws_bias_short - cur_aff[3:] += mws_bias_long - - cur_aff[:3] = np.clip(cur_aff[:3], 0, 1) # short-range attractive edges - cur_aff[3:] = np.clip(cur_aff[3:], -1, 0) # long-range repulsive edges (see the Mutex Watershed paper) - - mws_pred = mwatershed.agglom( - affinities=cur_aff, - offsets=( - [ - [1, 0, 0], - [0, 1, 0], - [0, 0, 1], - [long_range, 0, 0], - [0, long_range, 0], - [0, 0, long_range], - ] - ), - ) + def aff_to_seg(self, aff, zarr_path): + chunks = Utils.get_coordinates(aff.shape[1:], self.chunk_cube_size, overlap=1, last_has_smaller_overlap=False) + reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)} + patched_zarr_path = zarr_path + "_tmp" - # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again - # size filter. 
single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube, causing OOM when creating fragment_agglomeration - dusted = cc3d.dust( # does a cc first (reducing false mergers in add_to_agglomeration) - mws_pred, - threshold=2, - connectivity=6, - in_place=False, - ) - # relabeling to save IDs - pred_relabeled, N = cc3d.connected_components( - dusted, return_N=True, connectivity=6 - ) + zarr_chunk_size = min(self.chunk_cube_size, 512) + z_root = zarr.create(shape=(len(chunks), self.chunk_cube_size, self.chunk_cube_size, self.chunk_cube_size), + store=patched_zarr_path, dtype='i4', overwrite=True, + chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size)) + + # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN MEMORY + self.patched_segment_affinities(aff, patched_zarr_path, chunks) + + # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS + fragment_agglomeration, max_id = self.agglomerate_fragments(chunks, reverse_chunks, patched_zarr_path, aff.shape) + + # MERGE AND RELABEL INSTANCES GLOBALLY + self.merge_and_relabel(fragment_agglomeration, max_id, patched_zarr_path, zarr_path, chunks, aff.shape) + + return + + def patched_segment_affinities(self, aff, patched_zarr_path, chunks): + if self.compute_backend == "local": + for i, chunk in enumerate(tqdm(chunks)): + self.segment_chunk_wrapped(chunk, i, aff, patched_zarr_path) + else: + if self.compute_backend == "local_cluster": + from dask_cuda import LocalCUDACluster + cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU + elif self.compute_backend == "slurm": + from dask_jobqueue import SLURMCluster + cluster = SLURMCluster( + cores=8, + memory="400GB", + processes=1, + worker_extra_args=["--resources processes=1", "--nthreads=1"], + job_extra_directives=["--gres=gpu:1"], + walltime="1-00:00:00" + ) + cluster.adapt(minimum_jobs=1, maximum_jobs=32) + else: + raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") + + client = Client(cluster) + 
print(f"Waiting for workers...") + client.wait_for_workers(n_workers=1) + print("Dask Client Dashboard:", client.dashboard_link) + tasks = [dask.delayed(self.segment_chunk_wrapped)(chunk, i, aff, patched_zarr_path) for (i, chunk) in enumerate(chunks)] + futures = persist(tasks) + progress(futures) # progress bar + compute(futures) + + def agglomerate_fragments(self, chunks, reverse_chunks, patched_zarr_path, aff_shape): + if self.compute_backend == "local": + fragment_agglomeration = {} + for i, chunk in enumerate(tqdm(chunks)): + chunk_agglomeration = self.agglomerate_chunk(chunk, reverse_chunks, patched_zarr_path, aff_shape) + for node, nbrs in chunk_agglomeration.items(): + for nbr in nbrs: + fragment_agglomeration.setdefault(node, set()).add(nbr) + if len(fragment_agglomeration) > 10_000_000: + print("WARNING: fragment agglomeration too long, might cause problems!") + # TODO: solve this + + curr_id, fragment_agglomeration_flattened = self.flatten_agglomeration(fragment_agglomeration) + #print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) + #fragment_agglomeration_flattened = self.add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, patched_zarr_path) + #print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) + + else: + # TODO: add slurm (and measure memory) + raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") - assert (pred_relabeled[mws_pred == 0] == 0).all() # 0 stays 0 - assert N <= np.iinfo(np.uint32).max + return fragment_agglomeration_flattened, curr_id - pred = pred_relabeled.astype(np.uint32) - return pred + def agglomerate_chunk(self, chunk, reverse_chunks, patched_zarr_path, aff_shape): + fragment_agglomeration = {} + x, y, z = chunk + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) + z_root = zarr.open(patched_zarr_path, mode='r') + + # for (x,y,z) get the last slice of the 
current cube (l, low) and the first slice of the next cube (h, high) + # these slices overlap, so the voxels should have the same global id + + if x_end < aff_shape[1]: + chunk_l = reverse_chunks[chunk] + chunk_h = reverse_chunks[x + self.chunk_cube_size - 1, y, z] + result_l = z_root[chunk_l, -1:, :, :] + result_h = z_root[chunk_h, :1, :, :] + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques = np.unique(combined, axis=0) + fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) + + if y_end < aff_shape[2]: + chunk_l = reverse_chunks[chunk] + chunk_h = reverse_chunks[x, y + self.chunk_cube_size - 1, z] + result_l = z_root[chunk_l, :, -1:, :] + result_h = z_root[chunk_h, :, :1, :] + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques = np.unique(combined, axis=0) + fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) + + if z_end < aff_shape[3]: + chunk_l = reverse_chunks[chunk] + chunk_h = reverse_chunks[x, y, z + self.chunk_cube_size - 1] + result_l = z_root[chunk_l, :, :, -1:] + result_h = z_root[chunk_h, :, :, :1] + combined = np.stack([result_l.flatten(), result_h.flatten()]).T + uniques = np.unique(combined, axis=0) + fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) + + return fragment_agglomeration + + def update_fragment_agglomeration(self, fragment_agglomeration, uniques, chunk_l, chunk_h): + for id_l, id_h in uniques: + if id_l > 0 and id_h > 0: + fragment_agglomeration.setdefault((chunk_h, id_h), set()).add( + (chunk_l, id_l) + ) + fragment_agglomeration.setdefault((chunk_l, id_l), set()).add( + (chunk_h, id_h) + ) + return fragment_agglomeration + + def flatten_agglomeration(self, fragment_agglomeration): + """ + Computes connected components in the fragment agglomeration graph, and assigns the fragments new ids starting from 1. 
+ Args: + fragment_agglomeration: dictionary with keys (chunk_id, fragment_id), and values a set of (chunk_id, fragment_id) in another chunk (cube) that should be connected + Returns: + fragment_agglomeration_flattened: dictionary with keys (chunk_id, fragment_id) and values the global component index + """ + cur_id = 1 + fragment_agglomeration_flattened = dict() + for position_id in tqdm(fragment_agglomeration): # (chunk, idx) = position_id + if position_id not in fragment_agglomeration_flattened: + to_visit = {position_id} + visited = set() + while len(to_visit) > 0: + current = to_visit.pop() + if current not in visited: + visited.add(current) + for neighbor in fragment_agglomeration[current]: + to_visit.add(neighbor) + for v in visited: + assert v not in fragment_agglomeration_flattened + fragment_agglomeration_flattened[v] = cur_id + cur_id += 1 + + return cur_id, fragment_agglomeration_flattened + + #def add_all_fragments_to_agglomeration(self, fragment_agglomeration_flattened, cur_id, chunks, patched_zarr_path): + # z_root = zarr.open(patched_zarr_path) + # for i, chunk in enumerate(tqdm(chunks)): + # data = z_root[i, :, :, :] + # for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max + # if (i, idx) not in fragment_agglomeration_flattened: + # fragment_agglomeration_flattened[(i, idx)] = cur_id + # cur_id += 1 + # return fragment_agglomeration_flattened + + def merge_and_relabel(self, fragment_agglomeration, max_id, zarr_patched, zarr_final, chunks, aff_shape): + zarr_chunk_size = min(self.chunk_cube_size, 512) + z_root = zarr.open(zarr_patched) + z_final = zarr.create(shape=aff_shape[1:], + store=zarr_final, dtype='i4', overwrite=True, + chunks=(zarr_chunk_size, zarr_chunk_size, zarr_chunk_size)) + + if self.compute_backend == "local": + for i, chunk in enumerate(tqdm(chunks)): + x, y, z = chunk + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) + data = z_root[i, : x_end - x, : 
y_end - y, : z_end - z] + perm = [0] + for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max + if not (i, idx) in fragment_agglomeration: + max_id += 1 + perm.append(max_id) + else: + perm.append(fragment_agglomeration[(i, idx)]) + perm = np.array(perm, dtype=np.uint64) + relabeled = perm[data] + z_final[x: x_end, y: y_end, z: z_end] = relabeled + + else: + raise NotImplementedError(f"Compute backend {self.compute_backend} not implemented.") + + shutil.rmtree(zarr_patched) + + def segment_chunk_wrapped(self, chunk, i, aff, zarr_path): + x, y, z = chunk + x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff.shape) + curr_aff = aff[:, x : x_end, y : y_end, z : z_end] + curr_seg = self.segment_chunk(curr_aff) + z_root = zarr.open(zarr_path, mode="r+") + z_root[i, : x_end - x, : y_end - y, : z_end - z] = curr_seg + + def segment_chunk(self, curr_aff): + """ + In-memory segmentation of a chunk of affinities. + Args: + curr_aff: The affinities to segment (must fit in memory). + Returns: + Segmentation of the given affinities. 
+ """ + raise NotImplementedError(f"This method should be overridden in a subclass.") + + +class MutexWatershed(Postprocessing): + def __init__(self, chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long, long_range=10): + super().__init__(chunk_cube_size, compute_backend) + self.mws_bias_short = mws_bias_short + self.mws_bias_long = mws_bias_long + self.long_range = long_range + + def compute_mws_segmentation(self, cur_aff): + cur_aff = deepcopy(cur_aff).astype(np.float64) + cur_aff[:3] += self.mws_bias_short + cur_aff[3:] += self.mws_bias_long + + cur_aff[:3] = np.clip(cur_aff[:3], 0, 1) # short-range attractive edges + cur_aff[3:] = np.clip(cur_aff[3:], -1, 0) # long-range repulsive edges (see the Mutex Watershed paper) + + mws_pred = mwatershed.agglom( + affinities=cur_aff, + offsets=( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [self.long_range, 0, 0], + [0, self.long_range, 0], + [0, 0, self.long_range], + ] + ), + ) + + # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again + # size filter. 
single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube, causing OOM when creating fragment_agglomeration + dusted = cc3d.dust( # does a cc first (reducing false mergers in add_to_agglomeration) + mws_pred, + threshold=2, + connectivity=6, + in_place=False, + ) + # relabeling to save IDs + pred_relabeled, N = cc3d.connected_components( + dusted, return_N=True, connectivity=6 + ) + + assert (pred_relabeled[mws_pred == 0] == 0).all() # 0 stays 0 + assert N <= np.iinfo(np.uint32).max + + pred = pred_relabeled.astype(np.uint32) + return pred + + def segment_chunk(self, curr_aff): + return self.compute_mws_segmentation(curr_aff) + + + +class Thresholding(Postprocessing): + def __init__(self, chunk_cube_size, compute_backend, thr): + super().__init__(chunk_cube_size, compute_backend) + self.thr = thr + + @staticmethod + @jit(nopython=True) + def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: + """ + Compute connected components from affinities. + + Args: + hard_aff: The (thresholded, boolean) short range affinities. Shape: (3, x, y, z). + + Returns: + The segmentation. Shape: (x, y, z). 
+ """ + visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) + seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) + cur_id = 1 + for i in range(visited.shape[0]): + for j in range(visited.shape[1]): + for k in range(visited.shape[2]): + if hard_aff[:, i, j, k].any() and not visited[i, j, k]: # If foreground + cur_to_visit = [(i, j, k)] + visited[i, j, k] = True + while cur_to_visit: + x, y, z = cur_to_visit.pop() + seg[x, y, z] = cur_id + + # Check all neighbors + if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]: + cur_to_visit.append((x + 1, y, z)) + visited[x + 1, y, z] = True + if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]: + cur_to_visit.append((x, y + 1, z)) + visited[x, y + 1, z] = True + if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]: + cur_to_visit.append((x, y, z + 1)) + visited[x, y, z + 1] = True + if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]: + cur_to_visit.append((x - 1, y, z)) + visited[x - 1, y, z] = True + if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]: + cur_to_visit.append((x, y - 1, z)) + visited[x, y - 1, z] = True + if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]: + cur_to_visit.append((x, y, z - 1)) + visited[x, y, z - 1] = True + cur_id += 1 + return seg + + def segment_chunk(self, curr_aff): + return self.compute_connected_component_segmentation(curr_aff[:3] > self.thr) def full_inference( # RESOURCES ARGUMENTS: - chunk_cube_size: int = 1024, + chunk_cube_size: int = 3000, compute_backend: str = "local", # AFFINITY PREDICTION ARGUMENTS: img: Union[np.ndarray, zarr.Array] = None, @@ -549,25 +591,62 @@ def full_inference( mws_bias_short: float = -0.5, mws_bias_long: float = -0.5, ): - - aff = predict_aff( - img, + affinity_predictor = AffinityPredictor( + chunk_cube_size=chunk_cube_size, + compute_backend=compute_backend, model_path=model_path, 
- zarr_path=aff_zarr_path, small_size=small_size, do_overlap=do_overlap, prediction_channels=prediction_channels, divide=divide, - chunk_cube_size=chunk_cube_size, - compute_backend=compute_backend ) + affinity_predictor.img_to_aff(img, zarr_path=aff_zarr_path) + aff = zarr.open(aff_zarr_path, mode="r") if postprocessing_type == "thresholding": - seg = compute_connected_component_segmentation(aff[:3] > thr) - zarr.array(seg, store=seg_zarr_path) + postprocessor = Thresholding(chunk_cube_size, compute_backend, thr) elif postprocessing_type == "mws": - seg = mws(aff, seg_zarr_path, mws_bias_short, mws_bias_long) + postprocessor = MutexWatershed(chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long) else: raise NotImplementedError(f"Postprocessing type {postprocessing_type} is not implemented") + postprocessor.aff_to_seg(aff, zarr_path=seg_zarr_path) + seg = zarr.open(seg_zarr_path, mode="r") print(f"Segmentation saved at {seg_zarr_path}.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--chunk_cube_size", type=int, default=3000, help="The maximal side length of a cube held in memory.") + parser.add_argument("--compute_backend", type=str, default="local", help="Compute backend to use: local, slurm, or local_cluster.") + parser.add_argument("--img_path", type=str, help="The image to segment (path to zarr).") + parser.add_argument("--model_path", type=str, help="The path to the trained model.") + parser.add_argument("--aff_zarr_path", type=str, default="aff_prediction.zarr", help="Where to save the predicted affinities.") + parser.add_argument("--small_size", type=int, default=128, help="Size of the small patches for affinity prediction (model parameter).") + parser.add_argument("--do_overlap", type=bool, default=True, help="Use overlapping patches for affinity prediction for better precision.") + parser.add_argument("--prediction_channels", type=int, default=6, help="The number of prediction channels. 
Defaults to 6 (3 short + 3 long range affinities).") + parser.add_argument("--divide", type=int, default=255, help="The divisor for the image. Typically, 1 or 255 if img in [0, 255].") + parser.add_argument("--postprocessing_type", type=str, default="thresholding", help="Type of postprocessing to use: thresholding, or mws (mutex watershed).") + parser.add_argument("--seg_zarr_path", type=str, default="seg_prediction.zarr", help="Where to save the final segmentation.") + parser.add_argument("--thr", type=float, default=0.5, help="Threshold in case of thresholding.") + parser.add_argument("--mws_bias_short", type=float, default=-0.5, help="Short-range bias for mutex watershed.") + parser.add_argument("--mws_bias_long", type=float, default=-0.5, help="Long-range bias for mutex watershed.") + + args = parser.parse_args() + + img = zarr.open(args.img_path, mode="r")["img"] + full_inference( + chunk_cube_size=args.chunk_cube_size, + compute_backend=args.compute_backend, + img=img, + model_path=args.model_path, + aff_zarr_path=args.aff_zarr_path, + small_size=args.small_size, + do_overlap=args.do_overlap, + prediction_channels=args.prediction_channels, + divide=args.divide, + postprocessing_type=args.postprocessing_type, + seg_zarr_path=args.seg_zarr_path, + thr=args.thr, + mws_bias_short=args.mws_bias_short, + mws_bias_long=args.mws_bias_long, + ) \ No newline at end of file From 11a9075fa0618c0a3c184453366e4c036b007a8c Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 10 Sep 2025 14:37:56 +0200 Subject: [PATCH 31/33] forgotten file --- inference2.py | 611 -------------------------------------------------- 1 file changed, 611 deletions(-) delete mode 100644 inference2.py diff --git a/inference2.py b/inference2.py deleted file mode 100644 index de865cd..0000000 --- a/inference2.py +++ /dev/null @@ -1,611 +0,0 @@ -import shutil -from collections import defaultdict -from copy import deepcopy -from typing import Union, List, Tuple - -import cc3d -import numba -import 
numpy as np -import torch -import torch.utils -import zarr -import dask -from dask import compute, persist, delayed -from dask.distributed import Client, LocalCluster -from dask.diagnostics import ProgressBar -import dask.array as da -from distributed import progress -from filelock import FileLock -from numba import jit -from numpy.f2py.crackfortran import updatevars -from scipy.ndimage import distance_transform_cdt -from torch import autocast -from torch.nn.functional import sigmoid -from tqdm import tqdm -import mwatershed - - -class Utils: - @staticmethod - def get_coordinates(shape: Tuple[int, int, int], small_size: int, overlap: int = 0, - last_has_smaller_overlap: bool = True) -> List[Tuple[int, int, int]]: - """ - Get coordinates for smaller patches to process a big cube in memory. - Args: - shape: The shape of the input (x, y, z). - small_size: The size of the patches. - overlap: The overlap between patches. The default 0 means no overlap (next patch starts on the next pixel from the previous patch). For half-cube overlap set overlap=small_size//2, for 1-pixel overlap set overlap=1. - last_has_smaller_overlap: If the last patch with the specified size and overlap would exceed the big cube, move the patch so that it ends with the big cube, creating a bigger overlap in this patch. - Returns: - List of (x, y, z) coordinates (starting voxel of a patch) for processing of smaller patches. 
- """ - if overlap < 0 or overlap >= small_size: - raise ValueError(f"Overlap must be between 0 and {small_size}.") - offsets = [Utils.get_offsets(s, small_size, small_size - overlap, last_has_smaller_overlap) for s in shape] - xyzs = [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] - return xyzs - - @staticmethod - def get_offsets(big_size, small_size, step, last_has_smaller_overlap): - offsets = list(range(0, big_size - small_size + 1, step)) - if small_size > big_size: - offsets.append(0) - elif offsets[-1] != big_size - small_size and last_has_smaller_overlap: - offsets.append(big_size - small_size) - elif offsets[-1] != big_size - small_size and not last_has_smaller_overlap: - offsets.append(len(offsets) * step) - return offsets - - @staticmethod - def chunk_xyzs(xyzs, chunk_cube_size=1024): - """ - Chunks the patch coordinates into chunks containing coordinates from the same part of the big cube. - Args: - xyzs: list of all coordinates - chunk_cube_size: side length of each chunk - Returns: - chunked coordinates - """ - chunks = defaultdict(list) - for x, y, z in xyzs: - chunks[(x // chunk_cube_size, y // chunk_cube_size, z // chunk_cube_size)].append((x, y, z)) - return list(chunks.values()) - - @staticmethod - def scale_sigmoid(x: torch.Tensor) -> torch.Tensor: - """Scale sigmoid to avoid numerical issues in high confidence fp16.""" - return sigmoid(0.2 * x) - - @staticmethod - def get_xyz_end(chunk, chunk_cube_size, aff_shape): - """ - Returns the end indices of a chunk, that correspond either to the chunk size, or align with the size of the affinities. 
- """ - x, y, z = chunk - x_end, y_end, z_end = (min(x + chunk_cube_size, aff_shape[1]), - min(y + chunk_cube_size, aff_shape[2]), - min(z + chunk_cube_size, aff_shape[3])) - return (x_end, y_end, z_end) - - -class AffinityPredictor: - def __init__(self, - chunk_cube_size: int = 1024, - compute_backend: str = "local", - model: torch.nn.Module = None, - model_path: str = None, - small_size: int = 128, - do_overlap: bool = True, - prediction_channels: int = 6, - divide: int = 1, - ): - self.chunk_cube_size = chunk_cube_size - self.compute_backend = compute_backend - - self.model = model # only for local prediction - self.model_path = model_path # loads model in the worker in case of distributed inference (model not pickleable) - self.small_size = small_size - self.do_overlap = do_overlap - self.prediction_channels = prediction_channels - self.divide = divide - - def img_to_aff(self, img, zarr_path): - """ - Complete prediction of affinities from the input image, with the model previously specified in AffinityPredictor. 
- """ - print(f"Performing patched inference with do_overlap={self.do_overlap} for img of shape {img.shape} and dtype {img.dtype}") - print(f"Parameters: cube size {self.chunk_cube_size}, compute backend {self.compute_backend}.") - - all_patch_coordinates = Utils.get_coordinates(img.shape[:3], self.small_size, overlap=self.small_size // 2 if self.do_overlap else 0, last_has_smaller_overlap=True) - chunked_patch_coordinates = Utils.chunk_xyzs(all_patch_coordinates, self.chunk_cube_size) - - z = zarr.open_group("tmp_" + zarr_path, mode='w') - zarr_chunk_size = min(self.chunk_cube_size, 512) - z.create_dataset('sum_pred', shape=(self.prediction_channels, *img.shape[:3]), chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - z.create_dataset('sum_weight', shape=(1, *img.shape[:3]), chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size), dtype='f4') - - if self.compute_backend == "local": - for chunk in tqdm(chunked_patch_coordinates): - self.predict_aff_patches_chunked(chunk, img, "tmp_" + zarr_path) - torch.cuda.empty_cache() - else: - if self.compute_backend == "local_cluster": - from dask_cuda import LocalCUDACluster - cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU - elif self.compute_backend == "slurm": - from dask_jobqueue import SLURMCluster - cluster = SLURMCluster( - cores=8, - memory="400GB", - processes=1, - worker_extra_args=["--resources processes=1", "--nthreads=1"], - job_extra_directives=["--gres=gpu:1"], - walltime="1-00:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - - else: - raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") - - client = Client(cluster) - print(f"Waiting for workers...") - client.wait_for_workers(n_workers=1) - print("Dask Client Dashboard:", client.dashboard_link) - tasks = [dask.delayed(self.predict_aff_patches_chunked)(chunk, img, "tmp_" + zarr_path) for chunk in chunked_patch_coordinates] - futures = persist(tasks) - 
progress(futures) # progress bar - compute(futures) - - tmp_sum_pred = da.from_zarr(f"tmp_{zarr_path}/sum_pred") - tmp_sum_weight = da.from_zarr(f"tmp_{zarr_path}/sum_weight") - aff = tmp_sum_pred / tmp_sum_weight - aff.to_zarr(zarr_path, overwrite=True) - - shutil.rmtree("tmp_" + zarr_path) - - return - - def predict_aff_patches_chunked(self, patch_coordinates, img, zarr_path): - """ - Patch-wise predicts affinities in-memory, using coordinates of all patches inside a chunk. - Args: - patch_coordinates: List of patch coordinates. The extension of the coordinates must fit in memory (use adequate chunk size). - Returns: - Affinity prediction of the input chunk. - """ - max_x = max(x for x, y, z in patch_coordinates) - max_y = max(y for x, y, z in patch_coordinates) - max_z = max(z for x, y, z in patch_coordinates) - min_x = min(x for x, y, z in patch_coordinates) - min_y = min(y for x, y, z in patch_coordinates) - min_z = min(z for x, y, z in patch_coordinates) - - img_tmp = img[ - min_x: max_x + self.small_size, - min_y: max_y + self.small_size, - min_z: max_z + self.small_size, - ] - pred_tmp = np.zeros((self.prediction_channels, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) - weight_tmp = np.zeros((1, img_tmp.shape[0], img_tmp.shape[1], img_tmp.shape[2]), dtype=np.float32) - single_pred_weight = self.get_single_pred_weight(self.do_overlap, self.small_size) - - if not self.model: - from BANIS import BANIS - print(self.model_path, flush=True) - model = BANIS.load_from_checkpoint(self.model_path) - else: - model = self.model - - for x_global, y_global, z_global in patch_coordinates: - x = x_global - min_x - y = y_global - min_y - z = z_global - min_z - img_patch = torch.tensor(np.moveaxis(img_tmp[x: x + self.small_size, y: y + self.small_size, z: z + self.small_size], -1, 0)[None]).to(model.device) / self.divide - pred = Utils.scale_sigmoid(model(img_patch))[0, :self.prediction_channels] - - weight_tmp[:, x: x + self.small_size, y: y + 
self.small_size, z: z + self.small_size] += single_pred_weight if self.do_overlap else 1 - pred_tmp[:, x: x + self.small_size, y: y + self.small_size, z: z + self.small_size] += pred.detach().cpu().numpy() * (single_pred_weight[None] if self.do_overlap else 1) - - z = zarr.open_group(zarr_path, mode='a') - weight_mask = z['sum_weight'] - full_pred = z['sum_pred'] - - with FileLock(f"{zarr_path}/sum_weight.lock"): - weight_mask[ - :, - min_x: max_x + self.small_size, - min_y: max_y + self.small_size, - min_z: max_z + self.small_size, - ] += weight_tmp - - with FileLock(f"{zarr_path}/sum_pred.lock"): - full_pred[ - :, - min_x: max_x + self.small_size, - min_y: max_y + self.small_size, - min_z: max_z + self.small_size, - ] += pred_tmp - - def get_single_pred_weight(self, do_overlap: bool, small_size: int) -> Union[np.ndarray, None]: - """ - Get the weight for a single prediction. - - Args: - do_overlap: Whether to perform overlapping predictions. - small_size: The size of the patches. - - Returns: - The weight array for a single prediction, or None if no overlap. 
- """ - if do_overlap: - # The weight (confidence/expected quality) of the predictions: - # Low at the surface of the predicted cube, high in the center - pred_weight_helper = np.pad(np.ones((small_size,) * 3), 1, mode='constant') - return distance_transform_cdt(pred_weight_helper).astype(np.float32)[1:-1, 1:-1, 1:-1] - else: - return None - - -class Postprocessing: - def __init__(self, - chunk_cube_size: int = 1024, - compute_backend: str = "local" - ): - self.chunk_cube_size = chunk_cube_size - self.compute_backend = compute_backend - - def aff_to_seg(self, aff, zarr_path): - chunks = Utils.get_coordinates(aff.shape[1:], self.chunk_cube_size, overlap=1, last_has_smaller_overlap=False) - reverse_chunks = {chunk: i for i, chunk in enumerate(chunks)} - patched_zarr_path = "tmp_" + zarr_path - - zarr_chunk_size = min(self.chunk_cube_size, 512) - z_root = zarr.create(shape=(len(chunks), self.chunk_cube_size, self.chunk_cube_size, self.chunk_cube_size), - store=patched_zarr_path, dtype='i4', overwrite=True, - chunks=(1, zarr_chunk_size, zarr_chunk_size, zarr_chunk_size)) - - # SEGMENT AFFINITIES IN CHUNKS THAT FIT IN MEMORY - self.patched_segment_affinities(aff, patched_zarr_path, chunks) - - # FIND GROUPS OF FRAGMENTS THAT SHOULD BE MERGED BETWEEN CHUNKS - fragment_agglomeration = self.agglomerate_fragments(chunks, reverse_chunks, patched_zarr_path, aff.shape) - - # MERGE AND RELABEL INSTANCES GLOBALLY - self.merge_and_relabel(fragment_agglomeration, patched_zarr_path, zarr_path, chunks, aff.shape) - - return - - def patched_segment_affinities(self, aff, patched_zarr_path, chunks): - if self.compute_backend == "local": - for i, chunk in enumerate(tqdm(chunks)): - self.segment_chunk_wrapped(chunk, i, aff, patched_zarr_path) - else: - if self.compute_backend == "local_cluster": - from dask_cuda import LocalCUDACluster - cluster = LocalCUDACluster(threads_per_worker=1) # 1 worker per GPU - elif self.compute_backend == "slurm": - from dask_jobqueue import SLURMCluster - 
cluster = SLURMCluster( - cores=8, - memory="400GB", - processes=1, - worker_extra_args=["--resources processes=1", "--nthreads=1"], - job_extra_directives=["--gres=gpu:1"], - walltime="1-00:00:00" - ) - cluster.adapt(minimum_jobs=1, maximum_jobs=32) - else: - raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") - - client = Client(cluster) - print(f"Waiting for workers...") - client.wait_for_workers(n_workers=1) - print("Dask Client Dashboard:", client.dashboard_link) - tasks = [dask.delayed(self.segment_chunk_wrapped)(chunk, i, aff, patched_zarr_path) for (i, chunk) in enumerate(chunks)] - futures = persist(tasks) - progress(futures) # progress bar - compute(futures) - - def agglomerate_fragments(self, chunks, reverse_chunks, patched_zarr_path, aff_shape): - if self.compute_backend == "local": - fragment_agglomeration = {} - for i, chunk in enumerate(tqdm(chunks)): - chunk_agglomeration = self.agglomerate_chunk(chunk, reverse_chunks, patched_zarr_path, aff_shape) - for node, nbrs in chunk_agglomeration.items(): - for nbr in nbrs: - fragment_agglomeration.setdefault(node, set()).add(nbr) - if len(fragment_agglomeration) > 10_000_000: - print("WARNING: fragment agglomeration too long, might cause problems!") - # TODO: solve this - - curr_id, fragment_agglomeration_flattened = self.flatten_agglomeration(fragment_agglomeration) - print("MERGING CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) - fragment_agglomeration_flattened = self.add_all_fragments_to_agglomeration(fragment_agglomeration_flattened, curr_id, chunks, patched_zarr_path) - print("ALL CHUNKS FLATTENED AGGLOMERATION LENGTH", len(fragment_agglomeration_flattened)) - - else: - # TODO: add slurm (and measure memory) - raise NotImplementedError(f"Compute backend {self.compute_backend} not available.") - - return fragment_agglomeration_flattened - - def agglomerate_chunk(self, chunk, reverse_chunks, patched_zarr_path, aff_shape): - 
fragment_agglomeration = {} - x, y, z = chunk - x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) - z_root = zarr.open(patched_zarr_path, mode='r') - - # for (x,y,z) get the last slice of the current cube (l, low) and the first slice of the next cube (h, high) - # these slices overlap, so the voxels should have the same global id - - if x_end < aff_shape[1]: - chunk_l = reverse_chunks[chunk] - chunk_h = reverse_chunks[x + self.chunk_cube_size - 1, y, z] - result_l = z_root[chunk_l, -1:, :, :] - result_h = z_root[chunk_h, :1, :, :] - combined = np.stack([result_l.flatten(), result_h.flatten()]).T - uniques = np.unique(combined, axis=0) - fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) - - if y_end < aff_shape[2]: - chunk_l = reverse_chunks[chunk] - chunk_h = reverse_chunks[x, y + self.chunk_cube_size - 1, z] - result_l = z_root[chunk_l, :, -1:, :] - result_h = z_root[chunk_h, :, :1, :] - combined = np.stack([result_l.flatten(), result_h.flatten()]).T - uniques = np.unique(combined, axis=0) - fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) - - if z_end < aff_shape[3]: - chunk_l = reverse_chunks[chunk] - chunk_h = reverse_chunks[x, y, z + self.chunk_cube_size - 1] - result_l = z_root[chunk_l, :, :, -1:] - result_h = z_root[chunk_h, :, :, :1] - combined = np.stack([result_l.flatten(), result_h.flatten()]).T - uniques = np.unique(combined, axis=0) - fragment_agglomeration = self.update_fragment_agglomeration(fragment_agglomeration, uniques, chunk_l, chunk_h) - - return fragment_agglomeration - - def update_fragment_agglomeration(self, fragment_agglomeration, uniques, chunk_l, chunk_h): - for id_l, id_h in uniques: - if id_l > 0 and id_h > 0: - fragment_agglomeration.setdefault((chunk_h, id_h), set()).add( - (chunk_l, id_l) - ) - fragment_agglomeration.setdefault((chunk_l, id_l), set()).add( - (chunk_h, id_h) - ) - 
return fragment_agglomeration - - def flatten_agglomeration(self, fragment_agglomeration): - """ - Computes connected components in the fragment agglomeration graph, and assigns the fragments new ids starting from 1. - Args: - fragment_agglomeration: dictionary with keys (chunk_id, fragment_id), and values a set of (chunk_id, fragment_id) in another chunk (cube) that should be connected - Returns: - fragment_agglomeration_flattened: dictionary with keys (chunk_id, fragment_id) and values the global component index - """ - cur_id = 1 - fragment_agglomeration_flattened = dict() - for position_id in tqdm(fragment_agglomeration): # (chunk, idx) = position_id - if position_id not in fragment_agglomeration_flattened: - to_visit = {position_id} - visited = set() - while len(to_visit) > 0: - current = to_visit.pop() - if current not in visited: - visited.add(current) - for neighbor in fragment_agglomeration[current]: - to_visit.add(neighbor) - for v in visited: - assert v not in fragment_agglomeration_flattened - fragment_agglomeration_flattened[v] = cur_id - cur_id += 1 - - return cur_id, fragment_agglomeration_flattened - - def add_all_fragments_to_agglomeration(self, fragment_agglomeration_flattened, cur_id, chunks, patched_zarr_path): - z_root = zarr.open(patched_zarr_path) - for i, chunk in enumerate(tqdm(chunks)): - data = z_root[i, :, :, :] - for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max - if (i, idx) not in fragment_agglomeration_flattened: - fragment_agglomeration_flattened[(i, idx)] = cur_id - cur_id += 1 - return fragment_agglomeration_flattened - - def merge_and_relabel(self, fragment_agglomeration, zarr_patched, zarr_final, chunks, aff_shape): - zarr_chunk_size = min(self.chunk_cube_size, 512) - z_root = zarr.open(zarr_patched) - z_final = zarr.create(shape=aff_shape[1:], - store=zarr_final, dtype='i4', overwrite=True, - chunks=(zarr_chunk_size, zarr_chunk_size, zarr_chunk_size)) - - if 
self.compute_backend == "local": - for i, chunk in enumerate(tqdm(chunks)): - x, y, z = chunk - x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff_shape) - data = z_root[i, : x_end - x, : y_end - y, : z_end - z] - perm = [0] - for idx in range(1, int(data.max()) + 1): # assuming each chunk has contiguous indices from 0 to max - assert (i, idx) in fragment_agglomeration # all fragments have a new index (contiguous from 0) - perm.append(fragment_agglomeration[(i, idx)]) - perm = np.array(perm, dtype=np.uint64) - relabeled = perm[data] - z_final[x: x_end, y: y_end, z: z_end] = relabeled - - else: - raise NotImplementedError(f"Compute backend {self.compute_backend} not implemented.") - - shutil.rmtree(zarr_patched) - - def segment_chunk_wrapped(self, chunk, i, aff, zarr_path): - x, y, z = chunk - x_end, y_end, z_end = Utils.get_xyz_end(chunk, self.chunk_cube_size, aff.shape) - curr_aff = aff[:, x : x_end, y : y_end, z : z_end] - curr_seg = self.segment_chunk(curr_aff) - z_root = zarr.open(zarr_path, mode="r+") - z_root[i, : x_end - x, : y_end - y, : z_end - z] = curr_seg - - def segment_chunk(self, curr_aff): - """ - In-memory segmentation of a chunk of affinities. - Args: - curr_aff: The affinities to segment (must fit in memory). - Returns: - Segmentation of the given affinities. 
- """ - raise NotImplementedError(f"This method should be overridden in a subclass.") - - -class MutexWatershed(Postprocessing): - def __init__(self, chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long, long_range=10): - super().__init__(chunk_cube_size, compute_backend) - self.mws_bias_short = mws_bias_short - self.mws_bias_long = mws_bias_long - self.long_range = long_range - - def compute_mws_segmentation(self, cur_aff): - cur_aff = deepcopy(cur_aff).astype(np.float64) - cur_aff[:3] += self.mws_bias_short - cur_aff[3:] += self.mws_bias_long - - cur_aff[:3] = np.clip(cur_aff[:3], 0, 1) # short-range attractive edges - cur_aff[3:] = np.clip(cur_aff[3:], -1, 0) # long-range repulsive edges (see the Mutex Watershed paper) - - mws_pred = mwatershed.agglom( - affinities=cur_aff, - offsets=( - [ - [1, 0, 0], - [0, 1, 0], - [0, 0, 1], - [self.long_range, 0, 0], - [0, self.long_range, 0], - [0, 0, self.long_range], - ] - ), - ) - - # mwatershed is wasteful with IDs (not contiguous) -> filter out single voxel objects and relabel again - # size filter. 
single voxel objects are irrelevant for merging, take ~95% of IDs in an example cube, causing OOM when creating fragment_agglomeration - dusted = cc3d.dust( # does a cc first (reducing false mergers in add_to_agglomeration) - mws_pred, - threshold=2, - connectivity=6, - in_place=False, - ) - # relabeling to save IDs - pred_relabeled, N = cc3d.connected_components( - dusted, return_N=True, connectivity=6 - ) - - assert (pred_relabeled[mws_pred == 0] == 0).all() # 0 stays 0 - assert N <= np.iinfo(np.uint32).max - - pred = pred_relabeled.astype(np.uint32) - return pred - - def segment_chunk(self, curr_aff): - return self.compute_mws_segmentation(curr_aff) - - - -class Thresholding(Postprocessing): - def __init__(self, chunk_cube_size, compute_backend, thr): - super().__init__(chunk_cube_size, compute_backend) - self.thr = thr - - @staticmethod - @jit(nopython=True) - def compute_connected_component_segmentation(hard_aff: np.ndarray) -> np.ndarray: - """ - Compute connected components from affinities. - - Args: - hard_aff: The (thresholded, boolean) short range affinities. Shape: (3, x, y, z). - - Returns: - The segmentation. Shape: (x, y, z). 
- """ - visited = np.zeros(tuple(hard_aff.shape[1:]), dtype=numba.boolean) - seg = np.zeros(tuple(hard_aff.shape[1:]), dtype=np.uint32) - cur_id = 1 - for i in range(visited.shape[0]): - for j in range(visited.shape[1]): - for k in range(visited.shape[2]): - if hard_aff[:, i, j, k].any() and not visited[i, j, k]: # If foreground - cur_to_visit = [(i, j, k)] - visited[i, j, k] = True - while cur_to_visit: - x, y, z = cur_to_visit.pop() - seg[x, y, z] = cur_id - - # Check all neighbors - if x + 1 < visited.shape[0] and hard_aff[0, x, y, z] and not visited[x + 1, y, z]: - cur_to_visit.append((x + 1, y, z)) - visited[x + 1, y, z] = True - if y + 1 < visited.shape[1] and hard_aff[1, x, y, z] and not visited[x, y + 1, z]: - cur_to_visit.append((x, y + 1, z)) - visited[x, y + 1, z] = True - if z + 1 < visited.shape[2] and hard_aff[2, x, y, z] and not visited[x, y, z + 1]: - cur_to_visit.append((x, y, z + 1)) - visited[x, y, z + 1] = True - if x - 1 >= 0 and hard_aff[0, x - 1, y, z] and not visited[x - 1, y, z]: - cur_to_visit.append((x - 1, y, z)) - visited[x - 1, y, z] = True - if y - 1 >= 0 and hard_aff[1, x, y - 1, z] and not visited[x, y - 1, z]: - cur_to_visit.append((x, y - 1, z)) - visited[x, y - 1, z] = True - if z - 1 >= 0 and hard_aff[2, x, y, z - 1] and not visited[x, y, z - 1]: - cur_to_visit.append((x, y, z - 1)) - visited[x, y, z - 1] = True - cur_id += 1 - return seg - - def segment_chunk(self, curr_aff): - return self.compute_connected_component_segmentation(curr_aff[:3] > self.thr) - - -def full_inference( - # RESOURCES ARGUMENTS: - chunk_cube_size: int = 1024, - compute_backend: str = "local", - # AFFINITY PREDICTION ARGUMENTS: - img: Union[np.ndarray, zarr.Array] = None, - model_path: str = None, - aff_zarr_path: str = "aff_prediction.zarr", - small_size: int = 128, - do_overlap: bool = True, - prediction_channels: int = 6, - divide: int = 1, - # POSTPROCESSING ARGUMENTS: - postprocessing_type: str = "thresholding", - seg_zarr_path: str = 
"seg_prediction.zarr", - thr: float = 0.5, - mws_bias_short: float = -0.5, - mws_bias_long: float = -0.5, -): - affinity_predictor = AffinityPredictor( - chunk_cube_size=chunk_cube_size, - compute_backend=compute_backend, - model_path=model_path, - small_size=small_size, - do_overlap=do_overlap, - prediction_channels=prediction_channels, - divide=divide, - ) - affinity_predictor.img_to_aff(img, zarr_path=aff_zarr_path) - aff = zarr.open(aff_zarr_path, mode="r") - - if postprocessing_type == "thresholding": - postprocessor = Thresholding(chunk_cube_size, compute_backend, thr) - elif postprocessing_type == "mws": - postprocessor = MutexWatershed(chunk_cube_size, compute_backend, mws_bias_short, mws_bias_long) - else: - raise NotImplementedError(f"Postprocessing type {postprocessing_type} is not implemented") - postprocessor.aff_to_seg(aff, zarr_path=seg_zarr_path) - seg = zarr.open(seg_zarr_path, mode="r") - - print(f"Segmentation saved at {seg_zarr_path}.") From 288a6c7f1f61b5987be29245b275d15f6e10bed7 Mon Sep 17 00:00:00 2001 From: Zuzana Urbanova Date: Wed, 10 Sep 2025 14:42:45 +0200 Subject: [PATCH 32/33] return to original config --- aff_train.sh | 12 ++++++------ config.yaml | 43 +++++++++++++++++++++---------------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/aff_train.sh b/aff_train.sh index 074bcfb..90430d1 100644 --- a/aff_train.sh +++ b/aff_train.sh @@ -1,14 +1,14 @@ #!/bin/bash -l -#SBATCH --nodes=2 -#SBATCH --gres=gpu:4 -#SBATCH --ntasks-per-node=4 +#SBATCH --nodes=1 +#SBATCH --gres=gpu:1 +#SBATCH --ntasks-per-node=1 #SBATCH --time=7-00 -#SBATCH --cpus-per-task=16 -#SBATCH --mem=1000G +#SBATCH --cpus-per-task=32 +#SBATCH --mem=500G #SBATCH --signal=B:USR1@300 #SBATCH --open-mode=append -#SBATCH --partition=p.large +#SBATCH --partition=p.share mamba activate nisb diff --git a/config.yaml b/config.yaml index 33cd0d2..2c9209e 100644 --- a/config.yaml +++ b/config.yaml @@ -5,20 +5,20 @@ params: - 1e-2 seed: - 0 - #- 1 - #- 2 - #- 3 
- #- 4 + - 1 + - 2 + - 3 + - 4 long_range: - 10 batch_size: - - 1 + - 8 scheduler: - true model_id: - - "L" + - "S" kernel_size: - - 5 + - 3 synthetic: - 1.0 drop_slice_prob: @@ -32,35 +32,34 @@ params: affine: - 0.5 n_steps: - - 1_000_000 + - 50000 small_size: - 128 data_setting: - #- "base" - #- "liconn" - #- "multichannel" - #- "neg_guidance" - #- "no_touch_thick" - #- "pos_guidance" - #- "slice_perturbed" - #- "touching_thin" + - "base" + - "liconn" + - "multichannel" + - "neg_guidance" + - "no_touch_thick" + - "pos_guidance" + - "slice_perturbed" + - "touching_thin" - "train_100" base_data_path: - "/cajal/nvmescratch/projects/NISB/" save_path: - #- "/cajal/scratch/projects/misc/riegerfr/aff_nis/" - - "/cajal/scratch/projects/misc/zuzur/xl_banis" + - "/cajal/scratch/projects/misc/riegerfr/aff_nis/" exp_name: - - "xl_test" + - "exp" real_data_path: #https://colab.research.google.com/github/funkelab/lsd/blob/master/lsd/tutorial/notebooks/lsd_data_download.ipynb - "/cajal/scratch/projects/misc/mdraw/data/funke/zebrafinch/training/" auto_resubmit: - - True + - False distributed: - False compile: - - False + - True validate_extern: - True augment: - - False \ No newline at end of file + - True \ No newline at end of file From cdbfb374a0ad46fcf96394f8a09ff8d7be010f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zuzana=20Urbanov=C3=A1?= Date: Wed, 10 Sep 2025 16:16:17 +0200 Subject: [PATCH 33/33] new inference modules --- BANIS.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/BANIS.py b/BANIS.py index b4321b7..d5bb6de 100644 --- a/BANIS.py +++ b/BANIS.py @@ -1,13 +1,13 @@ import argparse import gc import os +import shutil + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" from collections import defaultdict from datetime import datetime from typing import Any, Dict -import random -import numpy as np import pytorch_lightning as pl import torch import torchvision @@ -23,7 +23,7 @@ from tqdm 
import tqdm from data import load_data -from inference import scale_sigmoid, compute_connected_component_segmentation, predict_aff +from inference import AffinityPredictor, Thresholding from metrics import compute_metrics @@ -35,7 +35,7 @@ class BANIS(LightningModule): def __init__(self, **kwargs: Any): super().__init__() self.save_hyperparameters() - print(f"hparams: \n{self.hparams}") + # print(f"hparams: \n{self.hparams}") self.model = create_mednext_v1( num_input_channels=self.hparams.num_input_channels, @@ -163,8 +163,17 @@ def full_cube_inference(self, mode: str, global_step=None): img_data = zarr.open(os.path.join(seed_path, "data.zarr"), mode="r")["img"] - aff_pred = predict_aff(img_data, model=self, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", do_overlap=True, prediction_channels=3, divide=255, - small_size=self.hparams.small_size, compute_backend="local") + affinity_predictor = AffinityPredictor( + chunk_cube_size=3000, # can be adjusted + compute_backend="local", + model=self, + small_size=self.hparams.small_size, + do_overlap=True, + prediction_channels=3, + divide=255, + ) + affinity_predictor.img_to_aff(img_data, zarr_path=f"{self.hparams.save_dir}/pred_aff_{mode}.zarr") + aff_pred = zarr.open(f"{self.hparams.save_dir}/pred_aff_{mode}.zarr", mode="r") self._evaluate_thresholds(aff_pred, os.path.join(seed_path, "skeleton.pkl"), mode, global_step) @@ -179,9 +188,9 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, skel_path: str, mode: str, torch.cuda.empty_cache() print(f"threshold {thr}") - pred_seg = compute_connected_component_segmentation( - aff_pred[:3] > thr # hard affinities - ) + postprocessor = Thresholding(3000, "local", thr) + postprocessor.aff_to_seg(aff_pred, f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr") + pred_seg = zarr.open(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr", mode="r") metrics = compute_metrics(pred_seg, skel_path) @@ -201,9 +210,11 @@ def _evaluate_thresholds(self, aff_pred: zarr.Array, 
skel_path: str, mode: str, self.best_thr_so_far[mode] = thr with open(f"{self.hparams.save_dir}/best_thr_{mode}.txt", "w") as f: f.write(str(self.best_thr_so_far[mode])) - seg_pred = zarr.array(pred_seg, dtype=np.uint32, - store=f"{self.hparams.save_dir}/pred_seg_{mode}.zarr", - chunks=(512, 512, 512), overwrite=True) + if os.path.exists(f"{self.hparams.save_dir}/pred_seg_{mode}.zarr"): + shutil.rmtree(f"{self.hparams.save_dir}/pred_seg_{mode}.zarr") + os.replace(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr", f"{self.hparams.save_dir}/pred_seg_{mode}.zarr") + else: + shutil.rmtree(f"{self.hparams.save_dir}/pred_seg_{mode}_tmp.zarr") best_voi = min(best_voi, metrics["voi_sum"]) self.safe_add_scalar(f"{mode}_best_nerl", best_nerl, global_step) @@ -266,9 +277,9 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm) total_norm_after = torch.norm(torch.stack([p.grad.norm(2) for p in self.parameters() if p.grad is not None])) - self.log("clipped_gradients/total_norm", total_norm_after.item(), on_step=True) + self.log("gradients/total_norm_clipped", total_norm_after.item(), on_step=True) max_grad_after = max([p.grad.abs().max().item() for p in self.parameters() if p.grad is not None]) - self.log("clipped_gradients/max_grad", max_grad_after) + self.log("gradients/max_grad_clipped", max_grad_after) def main():