From 47a8eb81170fe5a31d0c4156745a2b8866aac56f Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 29 Apr 2024 19:41:11 +0100 Subject: [PATCH 001/108] wrap fixes --- autoarray/abstract_ndarray.py | 4 ++++ autoarray/dataset/abstract/dataset.py | 5 ++++- autoarray/inversion/pixelization/mappers/mapper_util.py | 2 +- autoarray/inversion/pixelization/mesh/mesh_util.py | 2 +- autoarray/structures/decorators/abstract.py | 3 --- autoarray/structures/decorators/relocate_radial.py | 6 +++++- 6 files changed, 15 insertions(+), 7 deletions(-) diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index 65df94209..ac663e163 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -6,6 +6,10 @@ from abc import abstractmethod import numpy as np +import os +if os.environ.get("USE_JAX") == "1": + from jax import numpy as np + from autoarray.numpy_wrapper import numpy as npw, register_pytree_node, Array from typing import TYPE_CHECKING diff --git a/autoarray/dataset/abstract/dataset.py b/autoarray/dataset/abstract/dataset.py index 3c7c3654e..388eb7af3 100644 --- a/autoarray/dataset/abstract/dataset.py +++ b/autoarray/dataset/abstract/dataset.py @@ -242,7 +242,10 @@ def apply_over_sampling( if over_sampling is not None: self.over_sampling = over_sampling - del self.__dict__["grid"] + try: + del self.__dict__["grid"] + except KeyError: + pass if over_sampling_pixelization is not None: self.over_sampling_pixelization = over_sampling_pixelization diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index a10865aae..288361956 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -425,7 +425,7 @@ def pix_size_weights_voronoi_nn_from( "In order to use the VoronoiNN pixelization you must install the " "Natural Neighbor Interpolation c package.\n\n" "" - "See: https://github.com/Jammy2211/PyAutoArray/tree/master/autoarray/util/nn" + "See: https://github.com/Jammy2211/PyAutoArray/tree/main/autoarray/util/nn" ) from e max_nneighbours = conf.instance["general"]["pixelization"][ diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index c84d2da03..419740cfa 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -601,7 +601,7 @@ def voronoi_nn_interpolated_array_from( "In order to use the VoronoiNN pixelization you must install the " "Natural Neighbor Interpolation c package.\n\n" "" - "See: https://github.com/Jammy2211/PyAutoArray/tree/master/autoarray/util/nn" + "See: https://github.com/Jammy2211/PyAutoArray/tree/main/autoarray/util/nn" ) from e pixel_points = voronoi.points diff --git a/autoarray/structures/decorators/abstract.py b/autoarray/structures/decorators/abstract.py index 1401d1e76..9844ced99 100644 --- a/autoarray/structures/decorators/abstract.py +++ b/autoarray/structures/decorators/abstract.py @@ -3,9 +3,6 @@ from autoarray.mask.mask_1d import Mask1D from autoarray.mask.mask_2d import Mask2D from autoarray.operators.over_sampling.abstract import AbstractOverSampling -from autoarray.structures.arrays.irregular import ArrayIrregular -from autoarray.structures.arrays.uniform_1d import Array1D -from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_1d import Grid1D from autoarray.structures.grids.irregular_2d import Grid2DIrregular from autoarray.structures.grids.uniform_2d import Grid2D diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py index db8fd25c5..c25747d87 100644 --- a/autoarray/structures/decorators/relocate_radial.py +++ b/autoarray/structures/decorators/relocate_radial.py @@ -1,4 +1,6 @@ -import numpy as np +import os + +from autofit.jax_wrapper import numpy as np from functools import wraps from typing import Union @@ -57,6 +59,8 @@ def wrapper( ------- The grid_like object whose coordinates are radially moved from (0.0, 0.0). """ + if os.environ.get("USE_JAX", "0") == "1": + return grid try: grid_radial_minimum = conf.instance["grids"]["radial_minimum"][ From 43e6fef0d25001530bf8373ba8a8413c77230fcb Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 30 Apr 2024 16:05:44 +0100 Subject: [PATCH 002/108] abstract ndarray --- autoarray/abstract_ndarray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index ac663e163..bcd7a64d4 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -7,6 +7,7 @@ import numpy as np import os + if os.environ.get("USE_JAX") == "1": from jax import numpy as np From e833a7028b7b2d14030cb60e937bd39e6af14966 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 13 May 2024 10:46:27 +0100 Subject: [PATCH 003/108] decorator hack fix --- autoarray/structures/decorators/abstract.py | 1 - autoarray/structures/decorators/relocate_radial.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/autoarray/structures/decorators/abstract.py b/autoarray/structures/decorators/abstract.py index 9844ced99..a0b850367 100644 --- a/autoarray/structures/decorators/abstract.py +++ b/autoarray/structures/decorators/abstract.py @@ -92,7 +92,6 @@ def evaluate_func(self): if isinstance(self.grid, Grid1D): grid = self.grid.grid_2d_radial_projected_from() return self.func(self.obj, grid, *self.args, **self.kwargs) - return self.func(self.obj, self.grid, *self.args, **self.kwargs) @property diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py index c25747d87..ea65de2bd 100644 --- a/autoarray/structures/decorators/relocate_radial.py +++ b/autoarray/structures/decorators/relocate_radial.py @@ -60,7 +60,7 @@ def wrapper( The grid_like object whose coordinates are radially moved from (0.0, 0.0). """ if os.environ.get("USE_JAX", "0") == "1": - return grid + return func(obj, grid, *args, **kwargs) try: grid_radial_minimum = conf.instance["grids"]["radial_minimum"][ From 84920e54111ff815806c05c39f22be1dc727c4e4 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 10 Jun 2024 16:14:46 +0100 Subject: [PATCH 004/108] skip casting to float as jax does not like it --- autoarray/fit/fit_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index 61cbd0143..6a550d5a8 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -85,7 +85,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return float(np.sum(chi_squared_map)) + return np.sum(chi_squared_map) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: From bafc6da5faf3fc216652e1168ed755a23a1a16cb Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 10 Jun 2024 16:46:15 +0100 Subject: [PATCH 005/108] replicating strange jax issue --- test_autoarray/test_jax_changes.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test_autoarray/test_jax_changes.py b/test_autoarray/test_jax_changes.py index d977abfa2..95e220b35 100644 --- a/test_autoarray/test_jax_changes.py +++ b/test_autoarray/test_jax_changes.py @@ -1,6 +1,9 @@ import autoarray as aa import pytest +from autoarray import Grid2D, Mask2D +from autofit.jax_wrapper import numpy as np + @pytest.fixture(name="array") def make_array(): @@ -23,3 +26,10 @@ def test_in_place_multiply(array): array[0] *= 2.0 assert array[0] == 2.0 + + +def test_boolean_issue(): + grid = Grid2D.from_mask( + mask=Mask2D.all_false((10, 10), pixel_scales=1.0), + ) + print(np.array(grid)) From 6a6e4d9c5204812f16c379b0b33ffc41e0ad3a95 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 10 Jun 2024 16:56:29 +0100 Subject: [PATCH 006/108] fix weird array issue --- autoarray/abstract_ndarray.py | 2 +- test_autoarray/test_jax_changes.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index bcd7a64d4..b714fd4aa 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -117,7 +117,7 @@ def instance_unflatten(cls, aux_data, children): Unflatten a tuple of attributes (i.e. a pytree) into an instance of an autoarray class """ instance = cls.__new__(cls) - for key, value in zip(aux_data, children[1:]): + for key, value in zip(aux_data, children): setattr(instance, key, value) return instance diff --git a/test_autoarray/test_jax_changes.py b/test_autoarray/test_jax_changes.py index 95e220b35..f5104a942 100644 --- a/test_autoarray/test_jax_changes.py +++ b/test_autoarray/test_jax_changes.py @@ -32,4 +32,5 @@ def test_boolean_issue(): grid = Grid2D.from_mask( mask=Mask2D.all_false((10, 10), pixel_scales=1.0), ) - print(np.array(grid)) + values, keys = Grid2D.instance_flatten(grid) + np.array(Grid2D.instance_unflatten(keys, values)) From 23dd826cf781c1dcdd10e95d11c9211de13e8682 Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Thu, 27 Jun 2024 11:46:58 +0100 Subject: [PATCH 007/108] Add JAX path for `convolve_image` These changes allow the PSF convolution to be calculated in a JAX friendly way that works with `jax.grad`. The method will take `image` and `blurring_image`, convert them from their 1D slim versions back into their 2D naive arrays (with masked points set to zero) and use `jax.scipy.signal.convolve` to do the convolution. A new keyword has been added to the function `jax_method` that can be either `direct` or `fft` to control how the convolution is done in JAX. Typically if the PSF kernel is more than about 5x5 the FFT method will be faster. Because of this `fft` is the default. Note: with the FFT some more speed might be gained by pre-caching the FFT of the PSF kernel as that is fixed and does not need to be re-computed every time, but this is an optimization that can be put in place at a later time once we know the current function works as expected. --- autoarray/operators/convolver.py | 83 +++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 13 deletions(-) diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index c963311a2..7c05dbe5d 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -6,6 +6,14 @@ from autoarray import exc from autoarray.mask import mask_2d_util +from os import environ + +use_jax = environ.get("USE_JAX", "0") == "1" + +if use_jax: + import jax + import jax.numpy as jnp + class Convolver: def __init__(self, mask, kernel): @@ -309,7 +317,33 @@ def frame_at_coordinates_jit(coordinates, mask, mask_index_array, kernel_2d): return frame, kernel_frame - def convolve_image(self, image, blurring_image): + def jax_convolve(self, image, blurring_image, method='auto'): + slim_to_2D_index_image = jnp.nonzero( + np.logical_not(self.mask.array), + size=image.shape[0] + ) + slim_to_2D_index_blurring = jnp.nonzero( + np.logical_not(self.blurring_mask), + size=blurring_image.shape[0] + ) + expanded_image_native = jnp.zeros(self.mask.shape) + expanded_image_native = expanded_image_native.at[ + slim_to_2D_index_image + ].set(image) + expanded_image_native = expanded_image_native.at[ + slim_to_2D_index_blurring + ].set(blurring_image) + kernel = np.array(self.kernel.native.array) + convolve_native = jax.scipy.signal.convolve( + expanded_image_native, + kernel, + mode='same', + method=method + ) + convolve_slim = convolve_native[slim_to_2D_index_image] + return convolve_slim + + def convolve_image(self, image, blurring_image, jax_method='fft'): """ For a given 1D array and blurring array, convolve the two using this convolver. @@ -319,26 +353,49 @@ def convolve_image(self, image, blurring_image): 1D array of the values which are to be blurred with the convolver's PSF. blurring_image 1D array of the blurring values which blur into the array after PSF convolution. + jax_method + If JAX is enabled this keyword will indicate what method is used for the PSF + convolution. Can be either `direct` to calculate it in real space or `fft` + to calculated it via a fast Fourier transform. `fft` is typically faster for + kernels that are more than about 5x5. Default is `fft`. """ - if self.blurring_mask is None: + def exception_message(): raise exc.KernelException( "You cannot use the convolve_image function of a Convolver if the Convolver was" "not created with a blurring_mask." ) - convolved_image = self.convolve_jit( - image_1d_array=np.array(image.slim), - image_frame_1d_indexes=self.image_frame_1d_indexes, - image_frame_1d_kernels=self.image_frame_1d_kernels, - image_frame_1d_lengths=self.image_frame_1d_lengths, - blurring_1d_array=np.array(blurring_image.slim), - blurring_frame_1d_indexes=self.blurring_frame_1d_indexes, - blurring_frame_1d_kernels=self.blurring_frame_1d_kernels, - blurring_frame_1d_lengths=self.blurring_frame_1d_lengths, - ) + if use_jax: + jax.lax.cond( + self.blurring_mask is None, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None + ) - return Array2D(values=convolved_image, mask=self.mask) + return self.jax_convolve( + image, + blurring_image, + method=jax_method + ) + + else: + if self.blurring_mask is None: + exception_message() + + convolved_image = self.convolve_jit( + image_1d_array=np.array(image.slim), + image_frame_1d_indexes=self.image_frame_1d_indexes, + image_frame_1d_kernels=self.image_frame_1d_kernels, + image_frame_1d_lengths=self.image_frame_1d_lengths, + blurring_1d_array=np.array(blurring_image.slim), + blurring_frame_1d_indexes=self.blurring_frame_1d_indexes, + blurring_frame_1d_kernels=self.blurring_frame_1d_kernels, + blurring_frame_1d_lengths=self.blurring_frame_1d_lengths, + ) + + return Array2D(values=convolved_image, mask=self.mask) @staticmethod @numba_util.jit() From 7aa1c1c1a96f2d108fd3adbae473ec556cd76097 Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Thu, 27 Jun 2024 13:59:42 +0100 Subject: [PATCH 008/108] Make sure `.array` is called on inputs Jax needs the base arrays passed in to work. --- autoarray/operators/convolver.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index 7c05dbe5d..21ce28f0f 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -319,20 +319,20 @@ def frame_at_coordinates_jit(coordinates, mask, mask_index_array, kernel_2d): def jax_convolve(self, image, blurring_image, method='auto'): slim_to_2D_index_image = jnp.nonzero( - np.logical_not(self.mask.array), + jnp.logical_not(self.mask.array), size=image.shape[0] ) slim_to_2D_index_blurring = jnp.nonzero( - np.logical_not(self.blurring_mask), + jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] ) expanded_image_native = jnp.zeros(self.mask.shape) expanded_image_native = expanded_image_native.at[ slim_to_2D_index_image - ].set(image) + ].set(image.array) expanded_image_native = expanded_image_native.at[ slim_to_2D_index_blurring - ].set(blurring_image) + ].set(blurring_image.array) kernel = np.array(self.kernel.native.array) convolve_native = jax.scipy.signal.convolve( expanded_image_native, From 5580992eb7661d4d15c86c08d93993f39895f992 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 24 Jul 2024 11:17:15 +0100 Subject: [PATCH 009/108] remove print --- autoarray/inversion/linear_obj/linear_obj.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/autoarray/inversion/linear_obj/linear_obj.py b/autoarray/inversion/linear_obj/linear_obj.py index 89add2e92..3bcc1ae57 100644 --- a/autoarray/inversion/linear_obj/linear_obj.py +++ b/autoarray/inversion/linear_obj/linear_obj.py @@ -156,9 +156,6 @@ def regularization_matrix(self) -> np.ndarray: regularization it is bypassed. """ - print(type(self)) - print(self.regularization) - if self.regularization is None: return np.zeros((self.params, self.params)) From 65f3bff892ecf2aed65f76d578711bb3223b5f1a Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 12 Aug 2024 15:47:14 +0100 Subject: [PATCH 010/108] use jax_wrapper in util --- autoarray/geometry/geometry_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 44c71a977..1686ced29 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,5 +1,5 @@ from typing import Tuple, Union -import numpy as np +from autoarray.numpy_wrapper import np from autoarray import numba_util from autoarray import type as ty From df7bd38dd76ba220c7ab5e6afa8255e6fc115bcf Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Wed, 18 Sep 2024 09:57:04 +0100 Subject: [PATCH 011/108] Adjust numpy to jax.numpy imports for files used for the MGE Sersic These files are all used by the Sersic profile's MGE deflection angles. This wholesale replacement of numpy for jax.numpy solves many of the issues around the array wrappers (e.g. "Grid2d"). The `numpy_wrapper` is just a simple file that can be imported to provide the correct `np` and an optional `use_jax` bool in case further `if` blocks are needed. Thankfully, with the full replacement much of the control flow `if` blocks seem to work without modification! --- autoarray/abstract_ndarray.py | 10 +--- autoarray/fit/fit_util.py | 6 +- autoarray/geometry/geometry_util.py | 17 +++--- autoarray/mask/abstract_mask.py | 5 +- autoarray/mask/derive/indexes_2d.py | 4 +- autoarray/numpy_wrapper.py | 57 +----------------- .../over_sampling/over_sample_util.py | 31 ++++++---- autoarray/operators/over_sampling/uniform.py | 6 +- .../structures/decorators/relocate_radial.py | 4 +- autoarray/structures/decorators/to_grid.py | 2 +- .../structures/decorators/to_vector_yx.py | 2 +- autoarray/structures/decorators/transform.py | 3 +- autoarray/structures/grids/grid_2d_util.py | 58 +++++++++++++++---- autoarray/structures/grids/uniform_2d.py | 16 +++-- 14 files changed, 109 insertions(+), 112 deletions(-) diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index b714fd4aa..0c72dfd35 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -4,14 +4,8 @@ from abc import ABC from abc import abstractmethod -import numpy as np -import os - -if os.environ.get("USE_JAX") == "1": - from jax import numpy as np - -from autoarray.numpy_wrapper import numpy as npw, register_pytree_node, Array +from autoarray.numpy_wrapper import np, register_pytree_node, Array from typing import TYPE_CHECKING @@ -337,7 +331,7 @@ def __getitem__(self, item): def __setitem__(self, key, value): if isinstance(key, (np.ndarray, AbstractNDArray, Array)): - self._array = npw.where(key, value, self._array) + self._array = np.where(key, value, self._array) else: self._array[key] = value diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index 6a550d5a8..e5d34675b 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -1,8 +1,6 @@ from functools import wraps -import numpy as np - -from autoarray.numpy_wrapper import numpy as npw +from autoarray.numpy_wrapper import np from autoarray.mask.abstract_mask import Mask from autoarray import type as ty @@ -99,7 +97,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return npw.sum(npw.log(2 * np.pi * noise_map**2.0)) + return np.sum(np.log(2 * np.pi * noise_map**2.0)) def normalized_residual_map_complex_from( diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 1686ced29..71fb9b1f1 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,5 +1,5 @@ from typing import Tuple, Union -from autoarray.numpy_wrapper import np +from autoarray.numpy_wrapper import np, use_jax from autoarray import numba_util from autoarray import type as ty @@ -382,15 +382,18 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - shifted_grid_2d = grid_2d - np.array(centre) - radius = np.sqrt(np.sum(shifted_grid_2d**2.0, 1)) + if use_jax: + shifted_grid_2d = grid_2d.array - np.array(centre) + else: + shifted_grid_2d = grid_2d - np.array(centre) + radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1)) theta_coordinate_to_profile = np.arctan2( shifted_grid_2d[:, 0], shifted_grid_2d[:, 1] ) - np.radians(angle) - return np.vstack( - radius - * (np.sin(theta_coordinate_to_profile), np.cos(theta_coordinate_to_profile)) - ).T + return np.vstack([ + radius * np.sin(theta_coordinate_to_profile), + radius * np.cos(theta_coordinate_to_profile) + ]).T def transform_grid_2d_from_reference_frame( diff --git a/autoarray/mask/abstract_mask.py b/autoarray/mask/abstract_mask.py index 76994b3ee..5401bcd09 100644 --- a/autoarray/mask/abstract_mask.py +++ b/autoarray/mask/abstract_mask.py @@ -2,7 +2,8 @@ from abc import ABC import logging -import numpy as np + +from autoarray.numpy_wrapper import np, use_jax from pathlib import Path from typing import Dict, Union @@ -116,7 +117,7 @@ def pixels_in_mask(self) -> int: """ The total number of unmasked pixels (values are `False`) in the mask. """ - return int(np.size(self._array) - np.sum(self._array)) + return (np.size(self._array) - np.sum(self._array)).astype(int) @property def is_all_true(self) -> bool: diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 9d38a2140..d8cf61a6e 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -import numpy as np -from autoarray.numpy_wrapper import register_pytree_node_class + +from autoarray.numpy_wrapper import np, register_pytree_node_class from typing import TYPE_CHECKING if TYPE_CHECKING: diff --git a/autoarray/numpy_wrapper.py b/autoarray/numpy_wrapper.py index 60758e977..54edb6c40 100644 --- a/autoarray/numpy_wrapper.py +++ b/autoarray/numpy_wrapper.py @@ -1,65 +1,14 @@ import logging -import numpy as np from os import environ - -def unwrap_arrays(args): - from autoarray.abstract_ndarray import AbstractNDArray - - for arg in args: - if isinstance(arg, AbstractNDArray): - yield arg.array - elif isinstance(arg, (list, tuple)): - yield type(arg)(unwrap_arrays(arg)) - else: - yield arg - - -class Callable: - def __init__(self, func): - self.func = func - - def __call__(self, *args, **kwargs): - from autoarray.abstract_ndarray import AbstractNDArray - - try: - first_argument = args[0] - except IndexError: - first_argument = None - - args = unwrap_arrays(args) - result = self.func(*args, **kwargs) - if isinstance(first_argument, AbstractNDArray) and not isinstance( - result, float - ): - return first_argument.with_new_array(result) - return result - - -class Numpy: - def __init__(self, jnp): - self.jnp = jnp - - def __getattr__(self, item): - try: - attribute = getattr(self.jnp, item) - except AttributeError as e: - logging.debug(e) - attribute = getattr(np, item) - if callable(attribute): - return Callable(attribute) - return attribute - - use_jax = environ.get("USE_JAX", "0") == "1" if use_jax: try: - import jax.numpy as jnp - from jax import jit + import jax + from jax import numpy as np, jit - numpy = Numpy(jnp) print("JAX mode enabled") except ImportError: @@ -67,7 +16,7 @@ def __getattr__(self, item): "JAX is not installed. Please install it with `pip install jax`." ) else: - numpy = Numpy(np) + import numpy as np def jit(function, *_, **__): return function diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index 3adc92077..1a0ae4290 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -1,7 +1,5 @@ -from autoarray.numpy_wrapper import register_pytree_node_class +from autoarray.numpy_wrapper import np, register_pytree_node_class, use_jax - -import numpy as np from typing import List, Tuple from autoarray.mask.mask_2d import Mask2D @@ -324,7 +322,11 @@ def sub_size_radial_bins_from( for i in range(radial_grid.shape[0]): for j in range(len(radial_list)): if radial_grid[i] < radial_list[j]: - sub_size[i] = sub_size_list[j] + if use_jax: + # while this makes it run, it is very, very slow + sub_size = sub_size.at[i].set(sub_size_list[j]) + else: + sub_size[i] = sub_size_list[j] break return sub_size @@ -403,12 +405,21 @@ def grid_2d_slim_over_sampled_via_mask_from( for y1 in range(sub): for x1 in range(sub): - grid_slim[sub_index, 0] = -( - y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) - ) - grid_slim[sub_index, 1] = ( - x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0) - ) + if use_jax: + # while this makes it run, it is very, very slow + grid_slim = grid_slim.at[sub_index, 0].set(-( + y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) + )) + grid_slim = grid_slim.at[sub_index, 1].set( + x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0) + ) + else: + grid_slim[sub_index, 0] = -( + y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) + ) + grid_slim[sub_index, 1] = ( + x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0) + ) sub_index += 1 index += 1 diff --git a/autoarray/operators/over_sampling/uniform.py b/autoarray/operators/over_sampling/uniform.py index a31c5d50b..e615df2ea 100644 --- a/autoarray/operators/over_sampling/uniform.py +++ b/autoarray/operators/over_sampling/uniform.py @@ -1,4 +1,4 @@ -import numpy as np +from autoarray.numpy_wrapper import np from typing import List, Tuple, Union from autoconf import conf @@ -188,7 +188,7 @@ def from_radial_bins( sub_size = np.zeros(grid.shape_slim) for centre in centre_list: - radial_grid = grid.distances_to_coordinate_from(coordinate=centre) + radial_grid = grid.distances_to_coordinate_from(coordinate=centre).array sub_size_of_centre = over_sample_util.sub_size_radial_bins_from( radial_grid=np.array(radial_grid), @@ -379,7 +379,7 @@ def over_sampled_grid(self) -> Grid2DIrregular: grid = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=np.array(self.mask), pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.sub_size).astype("int"), + sub_size=np.array(self.sub_size.array).astype("int"), origin=self.mask.origin, ) diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py index 7063e1e09..59af4386b 100644 --- a/autoarray/structures/decorators/relocate_radial.py +++ b/autoarray/structures/decorators/relocate_radial.py @@ -1,6 +1,6 @@ import os -from autofit.jax_wrapper import numpy as np +from autofit.jax_wrapper import numpy as np, use_jax from functools import wraps from typing import Union @@ -59,7 +59,7 @@ def wrapper( ------- The grid_like object whose coordinates are radially moved from (0.0, 0.0). """ - if os.environ.get("USE_JAX", "0") == "1": + if use_jax: return func(obj, grid, *args, **kwargs) try: diff --git a/autoarray/structures/decorators/to_grid.py b/autoarray/structures/decorators/to_grid.py index 43a025d3d..5f5de3b9b 100644 --- a/autoarray/structures/decorators/to_grid.py +++ b/autoarray/structures/decorators/to_grid.py @@ -1,4 +1,4 @@ -import numpy as np +from autoarray.numpy_wrapper import np from functools import wraps from typing import List, Union diff --git a/autoarray/structures/decorators/to_vector_yx.py b/autoarray/structures/decorators/to_vector_yx.py index 1eea320f3..1cf23346d 100644 --- a/autoarray/structures/decorators/to_vector_yx.py +++ b/autoarray/structures/decorators/to_vector_yx.py @@ -1,4 +1,4 @@ -import numpy as np +from autoarray.numpy_wrapper import np from functools import wraps from typing import List, Union diff --git a/autoarray/structures/decorators/transform.py b/autoarray/structures/decorators/transform.py index ab2424884..947af5d01 100644 --- a/autoarray/structures/decorators/transform.py +++ b/autoarray/structures/decorators/transform.py @@ -1,6 +1,5 @@ -import numpy as np +from autoarray.numpy_wrapper import np from functools import wraps - from typing import Union from autoarray.structures.grids.uniform_1d import Grid1D diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 764f749b0..e3eb09c71 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,5 +1,9 @@ from __future__ import annotations -import numpy as np +from autoarray.numpy_wrapper import np, use_jax + +if use_jax: + import jax + from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: @@ -53,7 +57,7 @@ def check_grid_2d(grid_2d: np.ndarray): def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): if len(grid_2d.shape) == 2: - if grid_2d.shape[0] != mask_2d.pixels_in_mask: + def exception_message(): raise exc.GridException( f""" The input 2D grid does not have the same number of values as pixels in @@ -64,9 +68,18 @@ def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): The mask number of pixels is {mask_2d.pixels_in_mask}. """ ) + if use_jax: + jax.lax.cond( + grid_2d.shape[0] != mask_2d.pixels_in_mask, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None + ) + elif grid_2d.shape[0] != mask_2d.pixels_in_mask: + exception_message() elif len(grid_2d.shape) == 3: - if (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: + def exception_message(): raise exc.GridException( f""" The input 2D grid is not the same dimensions as the mask @@ -76,6 +89,15 @@ def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): The mask shape_native is {mask_2d.shape_native}. """ ) + if use_jax: + jax.lax.cond( + (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None + ) + elif (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: + exception_message() def convert_grid_2d( @@ -111,8 +133,12 @@ def convert_grid_2d( is_native = len(grid_2d.shape) == 3 if is_native: - grid_2d[:, :, 0] *= np.invert(mask_2d) - grid_2d[:, :, 1] *= np.invert(mask_2d) + if use_jax: + grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) + grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) + else: + grid_2d[:, :, 0] *= np.invert(mask_2d) + grid_2d[:, :, 1] *= np.invert(mask_2d) if is_native == store_native: return grid_2d @@ -121,10 +147,16 @@ def convert_grid_2d( grid_2d_native=np.array(grid_2d), mask=np.array(mask_2d), ) - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d), - mask_2d=np.array(mask_2d), - ) + if use_jax: + return grid_2d_native_from( + grid_2d_slim=np.array(grid_2d.array), + mask_2d=np.array(mask_2d), + ) + else: + return grid_2d_native_from( + grid_2d_slim=np.array(grid_2d), + mask_2d=np.array(mask_2d), + ) def convert_grid_2d_to_slim( @@ -250,8 +282,12 @@ def grid_2d_slim_via_mask_from( for y in range(mask_2d.shape[0]): for x in range(mask_2d.shape[1]): if not mask_2d[y, x]: - grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] - grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] + if use_jax: + grid_slim = grid_slim.at[index, 0].set(-(y - centres_scaled[0]) * pixel_scales[0]) + grid_slim = grid_slim.at[index, 1].set((x - centres_scaled[1]) * pixel_scales[1]) + else: + grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] + grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] index += 1 return grid_slim diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index dfcec468b..3ff0f5c80 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +from autoarray.numpy_wrapper import np, use_jax from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -824,9 +824,14 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ - squared_distances = np.square(self[:, 0] - coordinate[0]) + np.square( - self[:, 1] - coordinate[1] - ) + if use_jax: + squared_distances = np.square(self.array[:, 0] - coordinate[0]) + np.square( + self.array[:, 1] - coordinate[1] + ) + else: + squared_distances = np.square(self[:, 0] - coordinate[0]) + np.square( + self[:, 1] - coordinate[1] + ) return Array2D(values=squared_distances, mask=self.mask) def distances_to_coordinate_from( @@ -840,8 +845,9 @@ def distances_to_coordinate_from( coordinate The (y,x) coordinate from which the distance of every grid (y,x) coordinate is computed. """ + squared_distance = self.squared_distances_to_coordinate_from(coordinate=coordinate) distances = np.sqrt( - self.squared_distances_to_coordinate_from(coordinate=coordinate) + squared_distance.array ) return Array2D(values=distances, mask=self.mask) From c1a16eebd5b578a163ea90d687557cc14ceed04c Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 20 Sep 2024 13:48:59 +0100 Subject: [PATCH 012/108] specify sizes --- autoarray/structures/triangles/jax_array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/autoarray/structures/triangles/jax_array.py b/autoarray/structures/triangles/jax_array.py index 184dd3f8a..d7fd00be0 100644 --- a/autoarray/structures/triangles/jax_array.py +++ b/autoarray/structures/triangles/jax_array.py @@ -101,6 +101,7 @@ def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": axis=0, return_inverse=True, equal_nan=True, + size=selected_indices.shape[0] * 3, ) nan_mask = np.isnan(unique_vertices).any(axis=1) @@ -113,6 +114,7 @@ def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": unique_triangles_indices = np.unique( new_indices_sorted, axis=0, + size=new_indices_sorted.shape[0], ) return ArrayTriangles( From 2aca124e06682e058b0d02bb5598cd4aabaa5fa1 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 20 Sep 2024 14:01:26 +0100 Subject: [PATCH 013/108] use jax wrapper --- autoarray/structures/grids/irregular_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index 619ffef11..25bfd09fb 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -1,4 +1,4 @@ -import numpy as np +from autoarray.numpy_wrapper import np from typing import List, Optional, Tuple, Union from autoarray.abstract_ndarray import AbstractNDArray From 76258e7be231581afa8f41132799b15e125429f9 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 20 Sep 2024 14:02:47 +0100 Subject: [PATCH 014/108] fix import --- test_autoarray/structures/triangles/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index 41a82f31e..ddd3dca4a 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -1,4 +1,4 @@ -from autoarray.numpy_wrapper import numpy as np +from autoarray.numpy_wrapper import np from autoarray.structures.triangles.array import ArrayTriangles From dd3c62920a508fb0191a2b73de3e4432656a1909 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 30 Sep 2024 08:19:12 +0100 Subject: [PATCH 015/108] unused import --- autoarray/structures/triangles/array.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py index 0d129dac2..05eb8a3b0 100644 --- a/autoarray/structures/triangles/array.py +++ b/autoarray/structures/triangles/array.py @@ -1,5 +1,3 @@ -from typing import Tuple - import numpy as np from autoarray.structures.triangles.abstract import AbstractTriangles From 11cf911db5785abb8af0964105417c6a1a6fb213 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 30 Sep 2024 08:30:29 +0100 Subject: [PATCH 016/108] use nansum to account of area of jax triangles --- autoarray/structures/triangles/abstract.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/autoarray/structures/triangles/abstract.py b/autoarray/structures/triangles/abstract.py index f622ad255..963bf8284 100644 --- a/autoarray/structures/triangles/abstract.py +++ b/autoarray/structures/triangles/abstract.py @@ -34,13 +34,12 @@ def area(self) -> float: The total area covered by the triangles. """ triangles = self.triangles - return ( - 0.5 - * np.abs( + return 0.5 * np.nansum( + np.abs( (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) - ).sum() + ) ) @property From a435c681607e620d576bc57cbae67e0d06e7ea28 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 30 Sep 2024 11:17:09 +0100 Subject: [PATCH 017/108] kwargs to fix interface --- autoarray/structures/triangles/abstract.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autoarray/structures/triangles/abstract.py b/autoarray/structures/triangles/abstract.py index 963bf8284..c6b2b9a2f 100644 --- a/autoarray/structures/triangles/abstract.py +++ b/autoarray/structures/triangles/abstract.py @@ -13,6 +13,7 @@ def __init__( self, indices, vertices, + **kwargs, ): """ Represents a set of triangles in efficient NumPy arrays. From 4416648e78ee35296aa4da88170c9857be0fed48 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 30 Sep 2024 11:57:01 +0100 Subject: [PATCH 018/108] ensure max array size is passed a --- autoarray/structures/triangles/jax_array.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/autoarray/structures/triangles/jax_array.py b/autoarray/structures/triangles/jax_array.py index a5ad4507c..9142330ff 100644 --- a/autoarray/structures/triangles/jax_array.py +++ b/autoarray/structures/triangles/jax_array.py @@ -122,6 +122,7 @@ def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": return ArrayTriangles( indices=unique_triangles_indices, vertices=unique_vertices, + max_containing_size=self.max_containing_size, ) def up_sample(self) -> "ArrayTriangles": @@ -135,6 +136,7 @@ def up_sample(self) -> "ArrayTriangles": return ArrayTriangles( indices=new_indices, vertices=unique_vertices, + max_containing_size=self.max_containing_size, ) def neighborhood(self) -> "ArrayTriangles": @@ -148,6 +150,7 @@ def neighborhood(self) -> "ArrayTriangles": return ArrayTriangles( indices=new_indices, vertices=unique_vertices, + max_containing_size=self.max_containing_size, ) def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": @@ -166,6 +169,7 @@ def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": return ArrayTriangles( indices=self.indices, vertices=vertices, + max_containing_size=self.max_containing_size, ) def __iter__(self): From 81774a0d95012e6ad4b0472406b3a7dd94073abf Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 30 Sep 2024 16:31:35 +0100 Subject: [PATCH 019/108] unused import[ --- autoarray/structures/decorators/relocate_radial.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py index 59af4386b..04fef4221 100644 --- a/autoarray/structures/decorators/relocate_radial.py +++ b/autoarray/structures/decorators/relocate_radial.py @@ -1,5 +1,3 @@ -import os - from autofit.jax_wrapper import numpy as np, use_jax from functools import wraps From 6c2e9a1cf9769619f74b0383d9be7be20d87da84 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 30 Sep 2024 16:32:16 +0100 Subject: [PATCH 020/108] use array version of jax wrapper to avoid build error --- autoarray/structures/decorators/relocate_radial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py index 04fef4221..58411714f 100644 --- a/autoarray/structures/decorators/relocate_radial.py +++ b/autoarray/structures/decorators/relocate_radial.py @@ -1,4 +1,4 @@ -from autofit.jax_wrapper import numpy as np, use_jax +from autoarray.numpy_wrapper import np, use_jax from functools import wraps from typing import Union From 478317a2b38b34485d48dd801b1111ff55e399bb Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Thu, 3 Oct 2024 10:34:12 +0100 Subject: [PATCH 021/108] Small update --- autoarray/structures/decorators/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/structures/decorators/transform.py b/autoarray/structures/decorators/transform.py index 947af5d01..bd837a399 100644 --- a/autoarray/structures/decorators/transform.py +++ b/autoarray/structures/decorators/transform.py @@ -50,7 +50,7 @@ def wrapper( """ if not kwargs.get("is_transformed"): - kwargs = {"is_transformed": True} + kwargs["is_transformed"] = True transformed_grid = obj.transformed_to_reference_frame_grid_from( grid, **kwargs From f53269061bab14a1b98703b106c575403936dcd5 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Oct 2024 12:45:52 +0100 Subject: [PATCH 022/108] preloads removed --- autoarray/config/visualize/include.yaml | 2 +- autoarray/inversion/inversion/abstract.py | 10 - .../inversion/inversion/imaging/abstract.py | 19 - .../inversion/inversion/imaging/mapping.py | 21 +- .../inversion/inversion/imaging/w_tilde.py | 70 +-- .../inversion/pixelization/mesh/abstract.py | 9 +- autoarray/preloads.py | 482 ------------------ test_autoarray/config/general.yaml | 5 +- test_autoarray/config/visualize.yaml | 4 +- .../inversion/inversion/test_abstract.py | 54 -- test_autoarray/test_preloads.py | 434 ---------------- 11 files changed, 23 insertions(+), 1087 deletions(-) diff --git a/autoarray/config/visualize/include.yaml b/autoarray/config/visualize/include.yaml index 270b6b2f4..f010d9f8b 100644 --- a/autoarray/config/visualize/include.yaml +++ b/autoarray/config/visualize/include.yaml @@ -7,7 +7,7 @@ include_1d: mask: false # Include a Mask ? origin: false # Include the (x,) origin of the data's coordinate system ? include_2d: - border: true # Include the border of the mask (all pixels on the outside of the mask) ? + border: false # Include the border of the mask (all pixels on the outside of the mask) ? grid: false # Include the data's 2D grid of (y,x) coordinates ? mapper_image_plane_mesh_grid: false # For an Inversion, include the pixel centres computed in the image-plane / data frame? mapper_source_plane_data_grid: false # For an Inversion, include the centres of the image-plane grid mapped to the source-plane / frame in source-plane figures? diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index d4810219c..253381339 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -322,10 +322,6 @@ def operated_mapping_matrix(self) -> np.ndarray: If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous linear equations are solved simultaneously. """ - - if self.preloads.operated_mapping_matrix is not None: - return self.preloads.operated_mapping_matrix - return np.hstack(self.operated_mapping_matrix_list) @cached_property @@ -356,9 +352,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]: If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion are regularized so high their value is forced to zero. """ - if self.preloads.regularization_matrix is not None: - return self.preloads.regularization_matrix - return block_diag( *[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list] ) @@ -735,9 +728,6 @@ def log_det_regularization_matrix_term(self) -> float: if not self.has(cls=AbstractRegularization): return 0.0 - if self.preloads.log_det_regularization_matrix_term is not None: - return self.preloads.log_det_regularization_matrix_term - try: lu = splu(csc_matrix(self.regularization_matrix_reduced)) diagL = lu.L.diagonal() diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 5eccfa7d1..234ed3be1 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -140,12 +140,6 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: A dictionary mapping every linear function object to its operated mapping matrix. """ - if self.preloads.linear_func_operated_mapping_matrix_dict is not None: - return self._updated_cls_key_dict_from( - cls=AbstractLinearObjFuncList, - preload_dict=self.preloads.linear_func_operated_mapping_matrix_dict, - ) - linear_func_operated_mapping_matrix_dict = {} for linear_func in self.cls_list_from(cls=AbstractLinearObjFuncList): @@ -192,12 +186,6 @@ def data_linear_func_matrix_dict(self): A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum of the values of a linear object function convolved with the PSF kernel at the data pixel. """ - if self.preloads.data_linear_func_matrix_dict is not None: - return self._updated_cls_key_dict_from( - cls=AbstractLinearObjFuncList, - preload_dict=self.preloads.data_linear_func_matrix_dict, - ) - linear_func_list = self.cls_list_from(cls=AbstractLinearObjFuncList) data_linear_func_matrix_dict = {} @@ -237,13 +225,6 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: ------- A dictionary mapping every mapper object to its operated mapping matrix. """ - - if self.preloads.mapper_operated_mapping_matrix_dict is not None: - return self._updated_cls_key_dict_from( - cls=AbstractMapper, - preload_dict=self.preloads.mapper_operated_mapping_matrix_dict, - ) - mapper_operated_mapping_matrix_dict = {} for mapper in self.cls_list_from(cls=AbstractMapper): diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 96f5d36d8..e078125ef 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -70,9 +70,6 @@ def _data_vector_mapper(self) -> np.ndarray: in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`. """ - if self.preloads.data_vector_mapper is not None: - return self.preloads.data_vector_mapper - if not self.has(cls=AbstractMapper): return None @@ -117,16 +114,8 @@ def data_vector(self) -> np.ndarray: The calculation is described in more detail in `inversion_util.data_vector_via_blurred_mapping_matrix_from`. """ - if self.preloads.data_vector_mapper is not None: - return self.preloads.data_vector_mapper - - if self.preloads.operated_mapping_matrix is not None: - operated_mapping_matrix = self.preloads.operated_mapping_matrix - else: - operated_mapping_matrix = self.operated_mapping_matrix - return inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=operated_mapping_matrix, + blurred_mapping_matrix=self.operated_mapping_matrix, image=np.array(self.data), noise_map=np.array(self.noise_map), ) @@ -143,9 +132,6 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: other calculations to enable preloading of this calculation. """ - if self.preloads.curvature_matrix_mapper_diag is not None: - return self.preloads.curvature_matrix_mapper_diag - if not self.has(cls=AbstractMapper): return None @@ -201,11 +187,6 @@ def curvature_matrix(self): array of memory. """ - if self.preloads.curvature_matrix is not None: - # Need to copy because of how curvature_reg_matirx overwrites memory. - - return copy.copy(self.preloads.curvature_matrix) - return inversion_util.curvature_matrix_via_mapping_matrix_from( mapping_matrix=self.operated_mapping_matrix, noise_map=self.noise_map, diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 4540f71df..6a9e2d545 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -94,9 +94,6 @@ def _data_vector_mapper(self) -> np.ndarray: in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`. """ - if self.preloads.data_vector_mapper is not None: - return self.preloads.data_vector_mapper - if not self.has(cls=AbstractMapper): return None @@ -153,9 +150,6 @@ def _data_vector_x1_mapper(self) -> np.ndarray: which circumvents `np.concatenate` for speed up. """ - if self.preloads.data_vector_mapper is not None: - return self.preloads.data_vector_mapper - linear_obj = self.linear_obj_list[0] return inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( @@ -177,9 +171,6 @@ def _data_vector_multi_mapper(self) -> np.ndarray: which computes the `data_vector` of each object and concatenates them. """ - if self.preloads.data_vector_mapper is not None: - return self.preloads.data_vector_mapper - return np.concatenate( [ inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( @@ -255,11 +246,6 @@ def curvature_matrix(self) -> np.ndarray: array of memory. """ - if self.preloads.curvature_matrix is not None: - # Need to copy because of how curvature_reg_matirx overwrites memory. - - return copy.copy(self.preloads.curvature_matrix) - if self.has(cls=AbstractLinearObjFuncList): curvature_matrix = self._curvature_matrix_func_list_and_mapper elif self.total(cls=AbstractMapper) == 1: @@ -292,9 +278,6 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: other calculations to enable preloading of this calculation. """ - if self.preloads.curvature_matrix_mapper_diag is not None: - return self.preloads.curvature_matrix_mapper_diag - if not self.has(cls=AbstractMapper): return None @@ -451,44 +434,21 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: for func_index, linear_func in enumerate(linear_func_list): linear_func_param_range = linear_func_param_range_list[func_index] - if self.preloads.data_linear_func_matrix_dict is not None: - off_diag = inversion_imaging_util.curvature_matrix_off_diags_via_data_linear_func_matrix_from( - data_linear_func_matrix=self.data_linear_func_matrix_dict[ - linear_func - ], - data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique, - data_weights=mapper.unique_mappings.data_weights, - pix_lengths=mapper.unique_mappings.pix_lengths, - pix_pixels=mapper.params, - ) - - elif self.preloads.mapper_operated_mapping_matrix_dict is not None: - operated_mapping_matrix = self.mapper_operated_mapping_matrix_dict[ - mapper - ] - - curvature_weights = ( - self.linear_func_operated_mapping_matrix_dict[linear_func] - ) / self.noise_map[:, None] ** 2 - - off_diag = np.dot(operated_mapping_matrix.T, curvature_weights) - - else: - curvature_weights = ( - self.linear_func_operated_mapping_matrix_dict[linear_func] - / self.noise_map[:, None] ** 2 - ) - - off_diag = inversion_imaging_util.curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( - data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique, - data_weights=mapper.unique_mappings.data_weights, - pix_lengths=mapper.unique_mappings.pix_lengths, - pix_pixels=mapper.params, - curvature_weights=curvature_weights, - image_frame_1d_lengths=self.convolver.image_frame_1d_lengths, - image_frame_1d_indexes=self.convolver.image_frame_1d_indexes, - image_frame_1d_kernels=self.convolver.image_frame_1d_kernels, - ) + curvature_weights = ( + self.linear_func_operated_mapping_matrix_dict[linear_func] + / self.noise_map[:, None] ** 2 + ) + + off_diag = inversion_imaging_util.curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( + data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique, + data_weights=mapper.unique_mappings.data_weights, + pix_lengths=mapper.unique_mappings.pix_lengths, + pix_pixels=mapper.params, + curvature_weights=curvature_weights, + image_frame_1d_lengths=self.convolver.image_frame_1d_lengths, + image_frame_1d_indexes=self.convolver.image_frame_1d_indexes, + image_frame_1d_kernels=self.convolver.image_frame_1d_kernels, + ) curvature_matrix[ mapper_param_range[0] : mapper_param_range[1], diff --git a/autoarray/inversion/pixelization/mesh/abstract.py b/autoarray/inversion/pixelization/mesh/abstract.py index dc1807238..472a28889 100644 --- a/autoarray/inversion/pixelization/mesh/abstract.py +++ b/autoarray/inversion/pixelization/mesh/abstract.py @@ -47,12 +47,9 @@ def relocated_grid_from( Contains quantities which may already be computed and can be preloaded to speed up calculations, in this case the relocated grid. """ - if preloads.relocated_grid is None: - if border_relocator is not None: - return border_relocator.relocated_grid_from(grid=source_plane_data_grid) - return source_plane_data_grid - - return preloads.relocated_grid + if border_relocator is not None: + return border_relocator.relocated_grid_from(grid=source_plane_data_grid) + return source_plane_data_grid @profile_func def relocated_mesh_grid_from( diff --git a/autoarray/preloads.py b/autoarray/preloads.py index bb9c7d282..ba6fcfcc9 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -21,46 +21,10 @@ def __init__( self, w_tilde=None, use_w_tilde=None, - image_plane_mesh_grid_pg_list=None, - relocated_grid=None, - mapper_list=None, - operated_mapping_matrix=None, - linear_func_operated_mapping_matrix_dict=None, - data_linear_func_matrix_dict=None, - mapper_operated_mapping_matrix_dict=None, - curvature_matrix=None, - data_vector_mapper=None, - curvature_matrix_mapper_diag=None, - regularization_matrix=None, - log_det_regularization_matrix_term=None, - traced_mesh_grids_list_of_planes=None, - image_plane_mesh_grid_list=None, ): self.w_tilde = w_tilde self.use_w_tilde = use_w_tilde - self.image_plane_mesh_grid_pg_list = image_plane_mesh_grid_pg_list - self.relocated_grid = relocated_grid - self.mapper_list = mapper_list - self.operated_mapping_matrix = operated_mapping_matrix - self.linear_func_operated_mapping_matrix_dict = ( - linear_func_operated_mapping_matrix_dict - ) - self.data_linear_func_matrix_dict = data_linear_func_matrix_dict - self.mapper_operated_mapping_matrix_dict = mapper_operated_mapping_matrix_dict - self.curvature_matrix = curvature_matrix - self.data_vector_mapper = data_vector_mapper - self.curvature_matrix_mapper_diag = curvature_matrix_mapper_diag - self.regularization_matrix = regularization_matrix - self.log_det_regularization_matrix_term = log_det_regularization_matrix_term - - self.traced_mesh_grids_list_of_planes = traced_mesh_grids_list_of_planes - self.image_plane_mesh_grid_list = image_plane_mesh_grid_list - - @property - def check_threshold(self): - return conf.instance["general"]["test"]["preloads_check_threshold"] - def set_w_tilde_imaging(self, fit_0, fit_1): """ The w-tilde linear algebra formalism speeds up inversions by computing beforehand quantities that enable @@ -117,449 +81,3 @@ def set_w_tilde_imaging(self, fit_0, fit_1): self.use_w_tilde = True logger.info("PRELOADS - W-Tilde preloaded for this model-fit.") - - def set_relocated_grid(self, fit_0, fit_1): - """ - If the `MassProfile`'s in a model are fixed their traced grids (which may have had coordinates relocated at - the border) does not change during the model=fit and can therefore be preloaded. - - This function compares the relocated grids of the mappers of two fit corresponding to two model instances, and - preloads the grid if the grids of both fits are the same. - - The preload is typically used in adapt searches, where the mass model is fixed and the parameters are - varied. - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.relocated_grid = None - - if fit_0.inversion is None: - return - - if ( - fit_0.inversion.total(cls=AbstractMapper) > 1 - or fit_0.inversion.total(cls=AbstractMapper) == 0 - ): - return - - mapper_0 = fit_0.inversion.cls_list_from(cls=AbstractMapper)[0] - mapper_1 = fit_1.inversion.cls_list_from(cls=AbstractMapper)[0] - - if ( - mapper_0.source_plane_data_grid.shape[0] - == mapper_1.source_plane_data_grid.shape[0] - ): - if ( - np.max( - abs( - mapper_0.source_plane_data_grid - - mapper_1.source_plane_data_grid - ) - ) - < 1.0e-8 - ): - self.relocated_grid = mapper_0.source_plane_data_grid - - logger.info( - "PRELOADS - Relocated grid of pxielization preloaded for this model-fit." - ) - - def set_mapper_list(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and the list of `Mapper`'s containing this information can - be preloaded. This includes preloading the `mapping_matrix`. - - This function compares the mapping matrix of two fit's corresponding to two model instances, and preloads the - list of mappers if the mapping matrix of both fits are the same. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.mapper_list = None - - if fit_0.inversion is None: - return - - if fit_0.inversion.total(cls=AbstractMapper) == 0: - return - - from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, - ) - - if isinstance(fit_0.inversion, InversionInterferometerMappingPyLops): - return - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0.mapping_matrix.shape[1] == inversion_1.mapping_matrix.shape[1]: - if np.allclose(inversion_0.mapping_matrix, inversion_1.mapping_matrix): - self.mapper_list = inversion_0.cls_list_from(cls=AbstractMapper) - - logger.info( - "PRELOADS - Mappers of planes preloaded for this model-fit." - ) - - def set_operated_mapping_matrix_with_preloads(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and matrices used to perform the linear algebra in an - inversion can be preloaded, which help efficiently construct the curvature matrix. - - This function compares the operated mapping matrix of two fit's corresponding to two model instances, and - preloads the mapper if the mapping matrix of both fits are the same. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.operated_mapping_matrix = None - - from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, - ) - - if isinstance(fit_0.inversion, InversionInterferometerMappingPyLops): - return - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - if ( - inversion_0.operated_mapping_matrix.shape[1] - == inversion_1.operated_mapping_matrix.shape[1] - ): - if ( - np.max( - abs( - inversion_0.operated_mapping_matrix - - inversion_1.operated_mapping_matrix - ) - ) - < 1e-8 - ): - self.operated_mapping_matrix = inversion_0.operated_mapping_matrix - - logger.info( - "PRELOADS - Inversion linear algebra quantities preloaded for this model-fit." - ) - - def set_linear_func_inversion_dicts(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and matrices used to perform the linear algebra in an - inversion can be preloaded, which help efficiently construct the curvature matrix. - - This function compares the operated mapping matrix of two fit's corresponding to two model instances, and - preloads the mapper if the mapping matrix of both fits are the same. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - from autoarray.inversion.pixelization.pixelization import Pixelization - - self.linear_func_operated_mapping_matrix_dict = None - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - if not inversion_0.has(cls=AbstractMapper): - return - - if not inversion_0.has(cls=AbstractLinearObjFuncList): - return - - try: - inversion_0.linear_func_operated_mapping_matrix_dict - except NotImplementedError: - return - - if not hasattr(inversion_0, "linear_func_operated_mapping_matrix_dict"): - return - - should_preload = False - - for operated_mapping_matrix_0, operated_mapping_matrix_1 in zip( - inversion_0.linear_func_operated_mapping_matrix_dict.values(), - inversion_1.linear_func_operated_mapping_matrix_dict.values(), - ): - if ( - np.max(abs(operated_mapping_matrix_0 - operated_mapping_matrix_1)) - < 1e-8 - ): - should_preload = True - - if should_preload: - self.linear_func_operated_mapping_matrix_dict = ( - inversion_0.linear_func_operated_mapping_matrix_dict - ) - self.data_linear_func_matrix_dict = inversion_0.data_linear_func_matrix_dict - - logger.info( - "PRELOADS - Inversion linear light profile operated mapping matrix / data linear func matrix preloaded for this model-fit." - ) - - def set_curvature_matrix(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and therefore its associated curvature matrix is also - fixed, meaning the curvature matrix preloaded. - - If linear ``LightProfiles``'s are included, the regions of the curvature matrix associatd with these - objects vary but the diagonals of the mapper do not change. In this case, the `curvature_matrix_mapper_diag` - is preloaded. - - This function compares the curvature matrix of two fit's corresponding to two model instances, and preloads - this value if it is the same for both fits. - - The preload is typically used in **PyAutoGalaxy** inversions using a `Rectangular` pixelization. - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.curvature_matrix = None - self.data_vector_mapper = None - self.curvature_matrix_mapper_diag = None - self.mapper_operated_mapping_matrix_dict = None - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - try: - inversion_0._curvature_matrix_mapper_diag - except NotImplementedError: - return - - if inversion_0.curvature_matrix.shape == inversion_1.curvature_matrix.shape: - if ( - np.max(abs(inversion_0.curvature_matrix - inversion_1.curvature_matrix)) - < 1e-8 - ): - self.curvature_matrix = inversion_0.curvature_matrix - - logger.info( - "PRELOADS - Inversion Curvature Matrix preloaded for this model-fit." - ) - - return - - if inversion_0._curvature_matrix_mapper_diag is not None: - if ( - np.max( - abs( - inversion_0._curvature_matrix_mapper_diag - - inversion_1._curvature_matrix_mapper_diag - ) - ) - < 1e-8 - ): - self.mapper_operated_mapping_matrix_dict = ( - inversion_0.mapper_operated_mapping_matrix_dict - ) - self.data_vector_mapper = inversion_0._data_vector_mapper - self.curvature_matrix_mapper_diag = ( - inversion_0._curvature_matrix_mapper_diag - ) - - logger.info( - "PRELOADS - Inversion Curvature Matrix Mapper Diag preloaded for this model-fit." - ) - - def set_regularization_matrix_and_term(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and therefore its associated regularization matrices are - also fixed, meaning the log determinant of the regularization matrix term of the Bayesian evidence can be - preloaded. - - This function compares the value of the log determinant of the regularization matrix of two fit's corresponding - to two model instances, and preloads this value if it is the same for both fits. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - self.regularization_matrix = None - self.log_det_regularization_matrix_term = None - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - if inversion_0.total(cls=AbstractMapper) == 0: - return - - if ( - abs( - inversion_0.log_det_regularization_matrix_term - - inversion_1.log_det_regularization_matrix_term - ) - < 1e-8 - ): - self.regularization_matrix = inversion_0.regularization_matrix - self.log_det_regularization_matrix_term = ( - inversion_0.log_det_regularization_matrix_term - ) - - logger.info( - "PRELOADS - Inversion Log Det Regularization Matrix Term preloaded for this model-fit." - ) - - def check_via_fit(self, fit): - import copy - - settings_inversion = copy.deepcopy(fit.settings_inversion) - - fit_with_preloads = fit.refit_with_new_preloads( - preloads=self, settings_inversion=settings_inversion - ) - - fit_without_preloads = fit.refit_with_new_preloads( - preloads=self.__class__(use_w_tilde=False), - settings_inversion=settings_inversion, - ) - - if os.environ.get("PYAUTOFIT_TEST_MODE") == "1": - return - - try: - if ( - abs( - fit_with_preloads.figure_of_merit - - fit_without_preloads.figure_of_merit - ) - > self.check_threshold - ): - raise exc.PreloadsException( - f""" - The log likelihood of fits using and not using preloads are not consistent by a value larger than - the preloads check threshold of {self.check_threshold}, indicating preloading has gone wrong. - - The likelihood values are: - - With Preloads: {fit_with_preloads.figure_of_merit} - Without Preloads: {fit_without_preloads.figure_of_merit} - - Double check that the model-fit is set up correctly and that the preloads are being used correctly. - - This exception can be turned off by setting the general.yaml -> test -> check_preloads to False - in the config files. However, care should be taken when doing this. - """ - ) - - except exc.InversionException: - data_vector_difference = np.max( - np.abs( - fit_with_preloads.inversion.data_vector - - fit_without_preloads.inversion.data_vector - ) - ) - - if data_vector_difference > 1.0e-4: - raise exc.PreloadsException( - f""" - The data vectors of fits using and not using preloads are not consistent, indicating - preloading has gone wrong. - - The maximum value a data vector absolute value difference is: {data_vector_difference} - """ - ) - - curvature_reg_matrix_difference = np.max( - np.abs( - fit_with_preloads.inversion.curvature_reg_matrix - - fit_without_preloads.inversion.curvature_reg_matrix - ) - ) - - if curvature_reg_matrix_difference > 1.0e-4: - raise exc.PreloadsException( - f""" - The curvature matrices of fits using and not using preloads are not consistent, indicating - preloading has gone wrong. - - The maximum value of a curvature matrix absolute value difference is: {curvature_reg_matrix_difference} - """ - ) - - @property - def info(self) -> List[str]: - """ - The information on what has or has not been preloaded, which is written to the file `preloads.summary`. - - Returns - ------- - A list of strings containing statements on what has or has not been preloaded. - """ - line = [f"W Tilde = {self.w_tilde is not None}\n"] - line += [f"Relocated Grid = {self.relocated_grid is not None}\n"] - line += [f"Mapper = {self.mapper_list is not None}\n"] - line += [ - f"Blurred Mapping Matrix = {self.operated_mapping_matrix is not None}\n" - ] - line += [ - f"Inversion Linear Func (Linear Light Profile) Dicts = {self.linear_func_operated_mapping_matrix_dict is not None}\n" - ] - line += [f"Curvature Matrix = {self.curvature_matrix is not None}\n"] - line += [ - f"Curvature Matrix Mapper Diag = {self.curvature_matrix_mapper_diag is not None}\n" - ] - line += [f"Regularization Matrix = {self.regularization_matrix is not None}\n"] - line += [ - f"Log Det Regularization Matrix Term = {self.log_det_regularization_matrix_term is not None}\n" - ] - - return line diff --git a/test_autoarray/config/general.yaml b/test_autoarray/config/general.yaml index 73b76904f..6f331d141 100644 --- a/test_autoarray/config/general.yaml +++ b/test_autoarray/config/general.yaml @@ -1,6 +1,5 @@ analysis: n_cores: 1 - preload_attempts: 250 fits: flip_for_ds9: false grid: @@ -36,6 +35,4 @@ structures: native_binned_only: false # If True, data structures are only stored in their native and binned format. This is used to reduce memory usage in autocti. test: check_likelihood_function: true # if True, when a search is resumed the likelihood of a previous sample is recalculated to ensure it is consistent with the previous run. - check_preloads: false - exception_override: false - preloads_check_threshold: 1.0 # If the figure of merit of a fit with and without preloads is greater than this threshold, the check preload test fails and an exception raised for a model-fit. + exception_override: false \ No newline at end of file diff --git a/test_autoarray/config/visualize.yaml b/test_autoarray/config/visualize.yaml index 568a11349..8934bb465 100644 --- a/test_autoarray/config/visualize.yaml +++ b/test_autoarray/config/visualize.yaml @@ -5,7 +5,7 @@ general: zoom_around_mask: true disable_zoom_for_fits: true # If True, the zoom-in around the masked region is disabled when outputting .fits files, which is useful to retain the same dimensions as the input data. include_2d: - border: true + border: false mapper_image_plane_mesh_grid: false mapper_source_plane_data_grid: false mapper_source_plane_mesh_grid: false @@ -33,7 +33,7 @@ include: mask: false origin: false include_2d: - border: true + border: false mapper_image_plane_mesh_grid: false mapper_source_plane_data_grid: false mapper_source_plane_mesh_grid: false diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 73e8d35fe..7dde7f051 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -319,60 +319,6 @@ def test__regularization_matrix(): assert inversion.regularization_matrix == pytest.approx(regularization_matrix) -def test__preloads__operated_mapping_matrix(): - operated_mapping_matrix = 2.0 * np.ones((9, 3)) - - preloads = aa.Preloads( - operated_mapping_matrix=operated_mapping_matrix, - ) - - # noinspection PyTypeChecker - inversion = aa.m.MockInversionImaging( - noise_map=np.ones(9), linear_obj_list=aa.m.MockMapper(), preloads=preloads - ) - - assert inversion.operated_mapping_matrix[0, 0] == 2.0 - - -def test__linear_func_operated_mapping_matrix_dict(): - dict_0 = {"key0": np.array([1.0, 2.0])} - - preloads = aa.Preloads(linear_func_operated_mapping_matrix_dict=dict_0) - - # noinspection PyTypeChecker - inversion = aa.m.MockInversionImagingWTilde( - noise_map=np.ones(9), - linear_obj_list=[aa.m.MockLinearObjFuncList()], - preloads=preloads, - ) - - assert list(inversion.linear_func_operated_mapping_matrix_dict.values())[ - 0 - ] == pytest.approx(dict_0["key0"], 1.0e-4) - - -def test__curvature_matrix_mapper_diag_preload(): - curvature_matrix_mapper_diag = 2.0 * np.ones((9, 3)) - - preloads = aa.Preloads(curvature_matrix_mapper_diag=curvature_matrix_mapper_diag) - - # noinspection PyTypeChecker - inversion = aa.m.MockInversionImagingWTilde( - noise_map=np.ones(9), linear_obj_list=aa.m.MockMapper(), preloads=preloads - ) - - assert inversion._curvature_matrix_mapper_diag == pytest.approx( - curvature_matrix_mapper_diag, 1.0e-4 - ) - - -def test__preload_of_regularization_matrix__overwrites_calculation(): - inversion = aa.m.MockInversion( - preloads=aa.Preloads(regularization_matrix=np.ones((2, 2))) - ) - - assert (inversion.regularization_matrix == np.ones((2, 2))).all() - def test__reconstruction_reduced(): linear_obj_list = [ diff --git a/test_autoarray/test_preloads.py b/test_autoarray/test_preloads.py index 9a0ddee1c..e58c32b93 100644 --- a/test_autoarray/test_preloads.py +++ b/test_autoarray/test_preloads.py @@ -82,437 +82,3 @@ def test__set_w_tilde(): assert preloads.use_w_tilde == True -def test__set_relocated_grid(): - # Inversion is None so there is no mapper, thus preload mapper to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads(relocated_grid=1) - preloads.set_relocated_grid(fit_0=fit_0, fit_1=fit_1) - - assert preloads.relocated_grid is None - - # Mapper's mapping matrices are different, thus preload mapper to None. - - inversion_0 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(source_plane_data_grid=np.ones((3, 2)))] - ) - inversion_1 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(source_plane_data_grid=2.0 * np.ones((3, 2)))] - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads(relocated_grid=1) - preloads.set_relocated_grid(fit_0=fit_0, fit_1=fit_1) - - assert preloads.relocated_grid is None - - # Mapper's mapping matrices are the same, thus preload mapper. - - inversion_0 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(source_plane_data_grid=np.ones((3, 2)))] - ) - inversion_1 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(source_plane_data_grid=np.ones((3, 2)))] - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads(relocated_grid=1) - preloads.set_relocated_grid(fit_0=fit_0, fit_1=fit_1) - - assert (preloads.relocated_grid == np.ones((3, 2))).all() - - -def test__set_mapper_list(): - # Inversion is None so there is no mapper, thus preload mapper to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads(mapper_list=1) - preloads.set_mapper_list(fit_0=fit_0, fit_1=fit_1) - - assert preloads.mapper_list is None - - # Mapper's mapping matrices are different, thus preload mapper to None. - - inversion_0 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(mapping_matrix=np.ones((3, 2)))] - ) - inversion_1 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(mapping_matrix=2.0 * np.ones((3, 2)))] - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads(mapper_list=1) - preloads.set_mapper_list(fit_0=fit_0, fit_1=fit_1) - - assert preloads.mapper_list is None - - # Mapper's mapping matrices are the same, thus preload mapper. - - inversion_0 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(mapping_matrix=np.ones((3, 2)))] - ) - inversion_1 = aa.m.MockInversion( - linear_obj_list=[aa.m.MockMapper(mapping_matrix=np.ones((3, 2)))] - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads(mapper_list=1) - preloads.set_mapper_list(fit_0=fit_0, fit_1=fit_1) - - assert (preloads.mapper_list[0].mapping_matrix == np.ones((3, 2))).all() - - # Multiple mappers pre inversion still preloads full mapper list. - - inversion_0 = aa.m.MockInversion( - linear_obj_list=[ - aa.m.MockMapper(mapping_matrix=np.ones((3, 2))), - aa.m.MockMapper(mapping_matrix=np.ones((3, 2))), - ] - ) - inversion_1 = aa.m.MockInversion( - linear_obj_list=[ - aa.m.MockMapper(mapping_matrix=np.ones((3, 2))), - aa.m.MockMapper(mapping_matrix=np.ones((3, 2))), - ] - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads(mapper_list=1) - preloads.set_mapper_list(fit_0=fit_0, fit_1=fit_1) - - assert (preloads.mapper_list[0].mapping_matrix == np.ones((3, 2))).all() - assert (preloads.mapper_list[1].mapping_matrix == np.ones((3, 2))).all() - - -def test__set_operated_mapping_matrix_with_preloads(): - # Inversion is None thus preload it to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads( - operated_mapping_matrix=1, - ) - preloads.set_operated_mapping_matrix_with_preloads(fit_0=fit_0, fit_1=fit_1) - - assert preloads.operated_mapping_matrix is None - - # Inversion's blurred mapping matrices are different thus no preloading. - - operated_mapping_matrix_0 = np.array( - [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]] - ) - - operated_mapping_matrix_1 = np.array( - [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]] - ) - - inversion_0 = aa.m.MockInversionImaging( - operated_mapping_matrix=operated_mapping_matrix_0 - ) - inversion_1 = aa.m.MockInversionImaging( - operated_mapping_matrix=operated_mapping_matrix_1 - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads( - operated_mapping_matrix=1, - ) - preloads.set_operated_mapping_matrix_with_preloads(fit_0=fit_0, fit_1=fit_1) - - assert preloads.operated_mapping_matrix is None - - # Inversion's blurred mapping matrices are the same therefore preload it and the curvature sparse terms. - - inversion_0 = aa.m.MockInversionImaging( - operated_mapping_matrix=operated_mapping_matrix_0, - ) - inversion_1 = aa.m.MockInversionImaging( - operated_mapping_matrix=operated_mapping_matrix_0, - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads( - operated_mapping_matrix=1, - ) - preloads.set_operated_mapping_matrix_with_preloads(fit_0=fit_0, fit_1=fit_1) - - assert (preloads.operated_mapping_matrix == operated_mapping_matrix_0).all() - - -def test__set_linear_func_operated_mapping_matrix_dict(): - # Inversion is None thus preload it to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads( - linear_func_operated_mapping_matrix_dict=0, - ) - preloads.set_linear_func_inversion_dicts(fit_0=fit_0, fit_1=fit_1) - - assert preloads.linear_func_operated_mapping_matrix_dict is None - assert preloads.data_linear_func_matrix_dict is None - - # Inversion's blurred mapping matrices are different thus no preloading. - - dict_0 = {"key0": np.array([1.0, 2.0])} - dict_1 = {"key1": np.array([1.0, 3.0])} - - inversion_0 = aa.m.MockInversionImaging( - linear_obj_list=[aa.m.MockLinearObjFuncList()], - linear_func_operated_mapping_matrix_dict=dict_0, - ) - inversion_1 = aa.m.MockInversionImaging( - linear_obj_list=[aa.m.MockLinearObjFuncList()], - linear_func_operated_mapping_matrix_dict=dict_1, - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads() - preloads.set_linear_func_inversion_dicts(fit_0=fit_0, fit_1=fit_1) - - assert preloads.linear_func_operated_mapping_matrix_dict is None - assert preloads.data_linear_func_matrix_dict is None - - # Inversion's blurred mapping matrices are the same therefore preload it and the curvature sparse terms. - - inversion_0 = aa.m.MockInversionImaging( - linear_obj_list=[aa.m.MockLinearObjFuncList(), aa.m.MockMapper()], - linear_func_operated_mapping_matrix_dict=dict_0, - data_linear_func_matrix_dict=dict_0, - ) - inversion_1 = aa.m.MockInversionImaging( - linear_obj_list=[aa.m.MockLinearObjFuncList(), aa.m.MockMapper()], - linear_func_operated_mapping_matrix_dict=dict_0, - data_linear_func_matrix_dict=dict_0, - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads() - preloads.set_linear_func_inversion_dicts(fit_0=fit_0, fit_1=fit_1) - - assert ( - preloads.linear_func_operated_mapping_matrix_dict["key0"] == dict_0["key0"] - ).all() - assert (preloads.data_linear_func_matrix_dict["key0"] == dict_0["key0"]).all() - - -def test__set_curvature_matrix(): - # Inversion is None thus preload curvature_matrix to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads( - curvature_matrix=1, data_vector_mapper=1, curvature_matrix_mapper_diag=1 - ) - preloads.set_curvature_matrix(fit_0=fit_0, fit_1=fit_1) - - assert preloads.curvature_matrix is None - - # Inversion's curvature matrices are different thus no preloading. - - curvature_matrix_0 = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - - curvature_matrix_1 = np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - - fit_0 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_0, - data_vector_mapper=1, - curvature_matrix_mapper_diag=1, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - fit_1 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_1, - data_vector_mapper=1, - curvature_matrix_mapper_diag=1, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - - preloads = aa.Preloads(curvature_matrix=1) - preloads.set_curvature_matrix(fit_0=fit_0, fit_1=fit_1) - - assert preloads.curvature_matrix is None - - # Inversion's curvature matrices are the same therefore preload it and the curvature sparse terms. - - preloads = aa.Preloads(curvature_matrix=2) - - fit_0 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_0, - data_vector_mapper=1, - curvature_matrix_mapper_diag=1, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - fit_1 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_0, - data_vector_mapper=1, - curvature_matrix_mapper_diag=1, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - - preloads.set_curvature_matrix(fit_0=fit_0, fit_1=fit_1) - - assert (preloads.curvature_matrix == curvature_matrix_0).all() - - -def test__set_curvature_matrix__curvature_matrix_mapper_diag(): - # Inversion is None thus preload curvature_matrix to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads(data_vector_mapper=0, curvature_matrix_mapper_diag=1) - preloads.set_curvature_matrix(fit_0=fit_0, fit_1=fit_1) - - assert preloads.data_vector_mapper is None - assert preloads.curvature_matrix_mapper_diag is None - assert preloads.mapper_operated_mapping_matrix_dict is None - - # Inversion's curvature matrices are different thus no preloading. - - curvature_matrix_0 = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - - curvature_matrix_1 = np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - - fit_0 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_0, - curvature_matrix_mapper_diag=curvature_matrix_0, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - fit_1 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_1, - curvature_matrix_mapper_diag=curvature_matrix_1, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - - preloads = aa.Preloads( - data_vector_mapper=0, - curvature_matrix_mapper_diag=1, - mapper_operated_mapping_matrix_dict=2, - ) - preloads.set_curvature_matrix(fit_0=fit_0, fit_1=fit_1) - - assert preloads.data_vector_mapper is None - assert preloads.curvature_matrix_mapper_diag is None - assert preloads.mapper_operated_mapping_matrix_dict is None - - # Inversion's curvature matrices are the same therefore preload it and the curvature sparse terms. - - preloads = aa.Preloads(data_vector_mapper=10, curvature_matrix_mapper_diag=2) - - fit_0 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_0, - data_vector_mapper=0, - curvature_matrix_mapper_diag=curvature_matrix_0, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - fit_1 = aa.m.MockFitImaging( - inversion=aa.m.MockInversion( - curvature_matrix=curvature_matrix_1, - data_vector_mapper=0, - curvature_matrix_mapper_diag=curvature_matrix_0, - mapper_operated_mapping_matrix_dict={"key0": 1}, - ) - ) - - preloads.set_curvature_matrix(fit_0=fit_0, fit_1=fit_1) - - assert preloads.data_vector_mapper == 0 - assert (preloads.curvature_matrix_mapper_diag == curvature_matrix_0).all() - assert preloads.mapper_operated_mapping_matrix_dict == {"key0": 1} - - -def test__set_regularization_matrix_and_term(): - regularization = aa.m.MockRegularization(regularization_matrix=np.eye(2)) - - # Inversion is None thus preload log_det_regularization_matrix_term to None. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads(log_det_regularization_matrix_term=1) - preloads.set_regularization_matrix_and_term(fit_0=fit_0, fit_1=fit_1) - - assert preloads.regularization_matrix is None - assert preloads.log_det_regularization_matrix_term is None - - # Inversion's log_det_regularization_matrix_term are different thus no preloading. - - inversion_0 = aa.m.MockInversion( - log_det_regularization_matrix_term=0, - linear_obj_list=[aa.m.MockLinearObj(regularization=regularization)], - ) - - inversion_1 = aa.m.MockInversion( - log_det_regularization_matrix_term=1, - linear_obj_list=[aa.m.MockLinearObj(regularization=regularization)], - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads = aa.Preloads(log_det_regularization_matrix_term=1) - preloads.set_regularization_matrix_and_term(fit_0=fit_0, fit_1=fit_1) - - assert preloads.regularization_matrix is None - assert preloads.log_det_regularization_matrix_term is None - - # Inversion's regularization matrix terms are the same therefore preload it and the regularization matrix. - - preloads = aa.Preloads(log_det_regularization_matrix_term=2) - - inversion_0 = aa.m.MockInversion( - log_det_regularization_matrix_term=1, - linear_obj_list=[aa.m.MockMapper(regularization=regularization)], - ) - - inversion_1 = aa.m.MockInversion( - log_det_regularization_matrix_term=1, - linear_obj_list=[aa.m.MockMapper(regularization=regularization)], - ) - - fit_0 = aa.m.MockFitImaging(inversion=inversion_0) - fit_1 = aa.m.MockFitImaging(inversion=inversion_1) - - preloads.set_regularization_matrix_and_term(fit_0=fit_0, fit_1=fit_1) - - assert (preloads.regularization_matrix == np.eye(2)).all() - assert preloads.log_det_regularization_matrix_term == 1 From 5f9b6ff350f334b0cd0c2ac90f4d6ad0ab28d4f4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Oct 2024 13:04:47 +0100 Subject: [PATCH 023/108] complete removal of preloads --- autoarray/__init__.py | 1 - autoarray/config/general.yaml | 3 - autoarray/exc.py | 10 -- autoarray/inversion/inversion/abstract.py | 9 -- autoarray/inversion/inversion/factory.py | 31 +--- .../inversion/inversion/imaging/abstract.py | 9 -- .../inversion/inversion/imaging/mapping.py | 2 - .../inversion/inversion/imaging/w_tilde.py | 3 - .../inversion/interferometer/abstract.py | 3 - .../inversion/interferometer/mapping.py | 3 - .../inversion/interferometer/w_tilde.py | 3 - autoarray/inversion/mock/mock_inversion.py | 3 - .../inversion/mock/mock_inversion_imaging.py | 5 - .../mock/mock_inversion_interferometer.py | 3 - autoarray/inversion/mock/mock_mesh.py | 3 - autoarray/inversion/mock/mock_pixelization.py | 1 - .../pixelization/mappers/mapper_grids.py | 10 -- .../inversion/pixelization/mesh/abstract.py | 6 - .../pixelization/mesh/rectangular.py | 8 +- .../pixelization/mesh/triangulation.py | 7 - autoarray/preloads.py | 83 ----------- .../inversion/inversion/test_abstract.py | 10 -- .../inversion/inversion/test_factory.py | 137 ------------------ .../pixelization/mesh/test_rectangular.py | 21 --- .../pixelization/mesh/test_triangulation.py | 21 --- test_autoarray/test_preloads.py | 84 ----------- 26 files changed, 3 insertions(+), 476 deletions(-) delete mode 100644 autoarray/preloads.py delete mode 100644 test_autoarray/inversion/pixelization/mesh/test_rectangular.py delete mode 100644 test_autoarray/inversion/pixelization/mesh/test_triangulation.py delete mode 100644 test_autoarray/test_preloads.py diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 3b1a8c079..34d1931b0 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -4,7 +4,6 @@ from . import fixtures from . import mock as m from .numba_util import profile_func -from .preloads import Preloads from .dataset import preprocess from .dataset.abstract.dataset import AbstractDataset from .dataset.abstract.w_tilde import AbstractWTilde diff --git a/autoarray/config/general.yaml b/autoarray/config/general.yaml index 16a8cd608..a6cb8d5dc 100644 --- a/autoarray/config/general.yaml +++ b/autoarray/config/general.yaml @@ -16,6 +16,3 @@ pixelization: voronoi_nn_max_interpolation_neighbors: 300 structures: native_binned_only: false # If True, data structures are only stored in their native and binned format. This is used to reduce memory usage in autocti. -test: - preloads_check_threshold: 1.0 # If the figure of merit of a fit with and without preloads is greater than this threshold, the check preload test fails and an exception raised for a model-fit. - diff --git a/autoarray/exc.py b/autoarray/exc.py index 657ec8324..cfe79eb76 100644 --- a/autoarray/exc.py +++ b/autoarray/exc.py @@ -106,16 +106,6 @@ class PlottingException(Exception): pass -class PreloadsException(Exception): - """ - Raises exceptions associated with the `preloads.py` module and `Preloads` class. - - For example if the preloaded quantities lead to a change in figure of merit of a fit compared to a fit without - preloading. - """ - - pass - class ProfilingException(Exception): """ diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index 253381339..390d2f6fb 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -31,7 +31,6 @@ def __init__( dataset: Union[Imaging, Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Optional["Preloads"] = None, run_time_dict: Optional[Dict] = None, ): """ @@ -70,17 +69,10 @@ def __init__( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - preloads - Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation, - for example certain matrices used by the linear algebra could be preloaded. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. """ - from autoarray.preloads import Preloads - - preloads = preloads or Preloads() - # try: # import numba # except ModuleNotFoundError: @@ -98,7 +90,6 @@ def __init__( self.settings = settings - self.preloads = preloads self.run_time_dict = run_time_dict @property diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index f88bc36e8..349d0c168 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -18,14 +18,12 @@ from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.preloads import Preloads def inversion_from( dataset: Union[Imaging, Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -51,9 +49,6 @@ def inversion_from( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - preloads - Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation, - for example certain matrices used by the linear algebra could be preloaded. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. @@ -66,7 +61,6 @@ def inversion_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) @@ -82,7 +76,6 @@ def inversion_imaging_from( dataset, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -112,9 +105,6 @@ def inversion_imaging_from( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - preloads - Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation, - for example certain matrices used by the linear algebra could be preloaded. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. @@ -127,8 +117,6 @@ def inversion_imaging_from( for linear_obj in linear_obj_list ): use_w_tilde = False - elif preloads.use_w_tilde is not None: - use_w_tilde = preloads.use_w_tilde else: use_w_tilde = settings.use_w_tilde @@ -136,17 +124,13 @@ def inversion_imaging_from( use_w_tilde = False if use_w_tilde: - if preloads.w_tilde is not None: - w_tilde = preloads.w_tilde - else: - w_tilde = dataset.w_tilde + w_tilde = dataset.w_tilde return InversionImagingWTilde( dataset=dataset, w_tilde=w_tilde, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) @@ -154,7 +138,6 @@ def inversion_imaging_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) @@ -163,7 +146,6 @@ def inversion_interferometer_from( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -197,9 +179,6 @@ def inversion_interferometer_from( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - preloads - Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation, - for example certain matrices used by the linear algebra could be preloaded. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. @@ -222,17 +201,13 @@ def inversion_interferometer_from( if not settings.use_linear_operators: if use_w_tilde: - if preloads.w_tilde is not None: - w_tilde = preloads.w_tilde - else: - w_tilde = dataset.w_tilde + w_tilde = dataset.w_tilde return InversionInterferometerWTilde( dataset=dataset, w_tilde=w_tilde, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) @@ -241,7 +216,6 @@ def inversion_interferometer_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) @@ -250,6 +224,5 @@ def inversion_interferometer_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 234ed3be1..c1d2a7842 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -22,7 +22,6 @@ def __init__( dataset: Union[Imaging, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads=None, run_time_dict: Optional[Dict] = None, ): """ @@ -65,22 +64,14 @@ def __init__( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - preloads - Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation, - for example certain matrices used by the linear algebra could be preloaded. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. """ - from autoarray.preloads import Preloads - - preloads = preloads or Preloads() - super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index e078125ef..5b3e40966 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -24,7 +24,6 @@ def __init__( dataset: Union[Imaging, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads=None, run_time_dict: Optional[Dict] = None, ): """ @@ -55,7 +54,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 6a9e2d545..5f791873c 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -14,7 +14,6 @@ from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper -from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.inversion.inversion import inversion_util @@ -28,7 +27,6 @@ def __init__( w_tilde: WTildeImaging, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -63,7 +61,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/inversion/interferometer/abstract.py b/autoarray/inversion/inversion/interferometer/abstract.py index 1d27d34f5..47e1c84bf 100644 --- a/autoarray/inversion/inversion/interferometer/abstract.py +++ b/autoarray/inversion/inversion/interferometer/abstract.py @@ -7,7 +7,6 @@ from autoarray.mask.mask_2d import Mask2D from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion -from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.inversion.inversion import inversion_util @@ -21,7 +20,6 @@ def __init__( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -51,7 +49,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index 6a1719996..9cde492e9 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -10,7 +10,6 @@ ) from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion -from autoarray.preloads import Preloads from autoarray.structures.visibilities import Visibilities from autoarray.inversion.inversion.interferometer import inversion_interferometer_util @@ -25,7 +24,6 @@ def __init__( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -58,7 +56,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/w_tilde.py index 46ef6794e..c3fd233bb 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/w_tilde.py @@ -12,7 +12,6 @@ from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper -from autoarray.preloads import Preloads from autoarray.structures.visibilities import Visibilities from autoarray.inversion.inversion import inversion_util @@ -27,7 +26,6 @@ def __init__( w_tilde: WTildeInterferometer, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ): """ @@ -66,7 +64,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/mock/mock_inversion.py b/autoarray/inversion/mock/mock_inversion.py index c94551352..b221163e3 100644 --- a/autoarray/inversion/mock/mock_inversion.py +++ b/autoarray/inversion/mock/mock_inversion.py @@ -4,7 +4,6 @@ from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.inversion.abstract import AbstractInversion from autoarray.inversion.inversion.settings import SettingsInversion -from autoarray.preloads import Preloads class MockInversion(AbstractInversion): @@ -31,7 +30,6 @@ def __init__( log_det_curvature_reg_matrix_term=None, log_det_regularization_matrix_term=None, settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), ): dataset = DatasetInterface( data=data, @@ -42,7 +40,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list or [], settings=settings, - preloads=preloads, ) self._operated_mapping_matrix = operated_mapping_matrix diff --git a/autoarray/inversion/mock/mock_inversion_imaging.py b/autoarray/inversion/mock/mock_inversion_imaging.py index 65f224ac1..64076db56 100644 --- a/autoarray/inversion/mock/mock_inversion_imaging.py +++ b/autoarray/inversion/mock/mock_inversion_imaging.py @@ -5,7 +5,6 @@ from autoarray.inversion.inversion.imaging.mapping import InversionImagingMapping from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde from autoarray.inversion.inversion.settings import SettingsInversion -from autoarray.preloads import Preloads class MockInversionImaging(InversionImagingMapping): @@ -19,7 +18,6 @@ def __init__( linear_func_operated_mapping_matrix_dict=None, data_linear_func_matrix_dict=None, settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), ): dataset = DatasetInterface( data=data, @@ -31,7 +29,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, ) self._operated_mapping_matrix = operated_mapping_matrix @@ -78,7 +75,6 @@ def __init__( linear_obj_list=None, curvature_matrix_mapper_diag=None, settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), ): dataset = DatasetInterface( data=data, @@ -91,7 +87,6 @@ def __init__( w_tilde=w_tilde or MockWTildeImaging(), linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, ) self.__curvature_matrix_mapper_diag = curvature_matrix_mapper_diag diff --git a/autoarray/inversion/mock/mock_inversion_interferometer.py b/autoarray/inversion/mock/mock_inversion_interferometer.py index a0c092505..58de71520 100644 --- a/autoarray/inversion/mock/mock_inversion_interferometer.py +++ b/autoarray/inversion/mock/mock_inversion_interferometer.py @@ -5,7 +5,6 @@ InversionInterferometerMapping, ) from autoarray.inversion.inversion.settings import SettingsInversion -from autoarray.preloads import Preloads class MockInversionInterferometer(InversionInterferometerMapping): @@ -17,7 +16,6 @@ def __init__( linear_obj_list=None, operated_mapping_matrix=None, settings: SettingsInversion = SettingsInversion(), - preloads: Preloads = Preloads(), ): dataset = DatasetInterface( data=data, @@ -29,7 +27,6 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - preloads=preloads, ) self._operated_mapping_matrix = operated_mapping_matrix diff --git a/autoarray/inversion/mock/mock_mesh.py b/autoarray/inversion/mock/mock_mesh.py index 04fca5c01..def02657a 100644 --- a/autoarray/inversion/mock/mock_mesh.py +++ b/autoarray/inversion/mock/mock_mesh.py @@ -7,7 +7,6 @@ from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.preloads import Preloads class MockMesh(AbstractMesh): @@ -24,7 +23,6 @@ def mapper_grids_from( source_plane_mesh_grid: Optional[Abstract2DMesh] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: Optional[np.ndarray] = None, - preloads: Optional[Preloads] = None, run_time_dict: Optional[Dict] = None, ) -> MapperGrids: return MapperGrids( @@ -34,7 +32,6 @@ def mapper_grids_from( source_plane_mesh_grid=source_plane_mesh_grid, image_plane_mesh_grid=self.image_plane_mesh_grid, adapt_data=adapt_data, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/mock/mock_pixelization.py b/autoarray/inversion/mock/mock_pixelization.py index dc06118a7..72ced89e5 100644 --- a/autoarray/inversion/mock/mock_pixelization.py +++ b/autoarray/inversion/mock/mock_pixelization.py @@ -29,7 +29,6 @@ def mapper_grids_from( image_plane_mesh_grid=None, adapt_data=None, settings=None, - preloads=None, run_time_dict=None, ): return self.mapper diff --git a/autoarray/inversion/pixelization/mappers/mapper_grids.py b/autoarray/inversion/pixelization/mappers/mapper_grids.py index 037c1a19e..9a12e2f95 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_grids.py +++ b/autoarray/inversion/pixelization/mappers/mapper_grids.py @@ -2,9 +2,6 @@ import numpy as np from typing import TYPE_CHECKING, Dict, Optional -if TYPE_CHECKING: - from autoarray import Preloads - from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D @@ -22,7 +19,6 @@ def __init__( source_plane_mesh_grid: Optional[Abstract2DMesh] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: Optional[np.ndarray] = None, - preloads: Optional[Preloads] = None, run_time_dict: Optional[Dict] = None, ): """ @@ -59,21 +55,15 @@ def __init__( adapt_data An image which is used to determine the `image_plane_mesh_grid` and therefore adapt the distribution of pixels of the Delaunay grid to the data it discretizes. - preloads - Preloads in memory certain arrays which may be known beforehand in order to speed up the calculation, - for example the `source_plane_mesh_grid` could be preloaded. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. """ - from autoarray.preloads import Preloads - self.mask = mask self.source_plane_data_grid = source_plane_data_grid self.source_plane_mesh_grid = source_plane_mesh_grid self.image_plane_mesh_grid = image_plane_mesh_grid self.adapt_data = adapt_data - self.preloads = preloads or Preloads() self.run_time_dict = run_time_dict @property diff --git a/autoarray/inversion/pixelization/mesh/abstract.py b/autoarray/inversion/pixelization/mesh/abstract.py index 472a28889..95d3d1ce3 100644 --- a/autoarray/inversion/pixelization/mesh/abstract.py +++ b/autoarray/inversion/pixelization/mesh/abstract.py @@ -5,7 +5,6 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.preloads import Preloads from autoarray.numba_util import profile_func @@ -19,7 +18,6 @@ def relocated_grid_from( self, border_relocator: BorderRelocator, source_plane_data_grid: Grid2D, - preloads: Preloads = Preloads(), ) -> Grid2D: """ Relocates all coordinates of the input `source_plane_data_grid` that are outside of a @@ -43,9 +41,6 @@ def relocated_grid_from( edge. source_plane_data_grid A 2D (y,x) grid of coordinates, whose coordinates outside the border are relocated to its edge. - preloads - Contains quantities which may already be computed and can be preloaded to speed up calculations, in this - case the relocated grid. """ if border_relocator is not None: return border_relocator.relocated_grid_from(grid=source_plane_data_grid) @@ -100,7 +95,6 @@ def mapper_grids_from( source_plane_mesh_grid: Optional[Grid2DIrregular] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: np.ndarray = None, - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ) -> MapperGrids: raise NotImplementedError diff --git a/autoarray/inversion/pixelization/mesh/rectangular.py b/autoarray/inversion/pixelization/mesh/rectangular.py index 0f4cc204a..514e5b6d4 100644 --- a/autoarray/inversion/pixelization/mesh/rectangular.py +++ b/autoarray/inversion/pixelization/mesh/rectangular.py @@ -4,7 +4,7 @@ from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular -from autoarray.preloads import Preloads + from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.inversion.pixelization.mesh.abstract import AbstractMesh from autoarray.inversion.pixelization.border_relocator import BorderRelocator @@ -62,7 +62,6 @@ def mapper_grids_from( source_plane_mesh_grid: Grid2D = None, image_plane_mesh_grid: Grid2D = None, adapt_data: np.ndarray = None, - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ) -> MapperGrids: """ @@ -95,9 +94,6 @@ def mapper_grids_from( Not used for a rectangular pixelization. adapt_data Not used for a rectangular pixelization. - preloads - Object which may contain preloaded arrays of quantities computed in the pixelization, which are passed via - this object speed up the calculation. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. """ @@ -107,7 +103,6 @@ def mapper_grids_from( relocated_grid = self.relocated_grid_from( border_relocator=border_relocator, source_plane_data_grid=source_plane_data_grid, - preloads=preloads, ) mesh_grid = self.mesh_grid_from(source_plane_data_grid=relocated_grid) @@ -117,7 +112,6 @@ def mapper_grids_from( source_plane_mesh_grid=mesh_grid, image_plane_mesh_grid=image_plane_mesh_grid, adapt_data=adapt_data, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/pixelization/mesh/triangulation.py b/autoarray/inversion/pixelization/mesh/triangulation.py index 91ee0363f..17b7536f8 100644 --- a/autoarray/inversion/pixelization/mesh/triangulation.py +++ b/autoarray/inversion/pixelization/mesh/triangulation.py @@ -3,7 +3,6 @@ from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.preloads import Preloads from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids from autoarray.inversion.pixelization.mesh.abstract import AbstractMesh from autoarray.inversion.pixelization.border_relocator import BorderRelocator @@ -18,7 +17,6 @@ def mapper_grids_from( source_plane_mesh_grid: Optional[Grid2DIrregular] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: np.ndarray = None, - preloads: Preloads = Preloads(), run_time_dict: Optional[Dict] = None, ) -> MapperGrids: """ @@ -61,9 +59,6 @@ def mapper_grids_from( transformation applied to it to create the `source_plane_mesh_grid`. adapt_data Not used for a rectangular mesh. - preloads - Object which may contain preloaded arrays of quantities computed in the mesh, which are passed via - this object speed up the calculation. run_time_dict A dictionary which contains timing of certain functions calls which is used for profiling. """ @@ -73,7 +68,6 @@ def mapper_grids_from( source_plane_data_grid = self.relocated_grid_from( border_relocator=border_relocator, source_plane_data_grid=source_plane_data_grid, - preloads=preloads, ) relocated_source_plane_mesh_grid = self.relocated_mesh_grid_from( @@ -96,6 +90,5 @@ def mapper_grids_from( source_plane_mesh_grid=source_plane_mesh_grid, image_plane_mesh_grid=image_plane_mesh_grid, adapt_data=adapt_data, - preloads=preloads, run_time_dict=run_time_dict, ) diff --git a/autoarray/preloads.py b/autoarray/preloads.py deleted file mode 100644 index ba6fcfcc9..000000000 --- a/autoarray/preloads.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -import numpy as np -import os -from typing import List - -from autoconf import conf - -from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList -from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper - -from autoarray import exc -from autoarray.inversion.inversion.imaging import inversion_imaging_util - -logger = logging.getLogger(__name__) - -logger.setLevel(level="INFO") - - -class Preloads: - def __init__( - self, - w_tilde=None, - use_w_tilde=None, - ): - self.w_tilde = w_tilde - self.use_w_tilde = use_w_tilde - - def set_w_tilde_imaging(self, fit_0, fit_1): - """ - The w-tilde linear algebra formalism speeds up inversions by computing beforehand quantities that enable - efficiently construction of the curvature matrix. These quantities can only be used if the noise-map is - fixed, therefore this function preloads these w-tilde quantities if the noise-map does not change. - - This function compares the noise map of two fit's corresponding to two model instances, and preloads wtilde - if the noise maps of both fits are the same. - - The preload is typically used through search chaining pipelines, as it is uncommon for the noise map to be - scaled during the model-fit (although it is common for a fixed but scaled noise map to be used). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.w_tilde = None - self.use_w_tilde = False - - if fit_0.inversion is None: - return - - if not fit_0.inversion.has(cls=AbstractMapper): - return - - if np.max(abs(fit_0.noise_map - fit_1.noise_map)) < 1e-8: - logger.info("PRELOADS - Computing W-Tilde... May take a moment.") - - from autoarray.dataset.imaging.w_tilde import WTildeImaging - - ( - preload, - indexes, - lengths, - ) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from( - noise_map_native=np.array(fit_0.noise_map.native), - kernel_native=np.array(fit_0.dataset.psf.native), - native_index_for_slim_index=np.array( - fit_0.dataset.mask.derive_indexes.native_for_slim - ), - ) - - self.w_tilde = WTildeImaging( - curvature_preload=preload, - indexes=indexes.astype("int"), - lengths=lengths.astype("int"), - noise_map_value=fit_0.noise_map[0], - ) - - self.use_w_tilde = True - - logger.info("PRELOADS - W-Tilde preloaded for this model-fit.") diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 7dde7f051..5d7fd622f 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -538,16 +538,6 @@ def test__regularization_term(): assert inversion.regularization_term == 34.0 -def test__preload_of_log_det_regularization_term_overwrites_calculation(): - inversion = aa.m.MockInversion( - linear_obj_list=[ - aa.m.MockLinearObj(parameters=3, regularization=aa.m.MockRegularization()) - ], - preloads=aa.Preloads(log_det_regularization_matrix_term=1.0), - ) - - assert inversion.log_det_regularization_matrix_term == 1.0 - def test__determinant_of_positive_definite_matrix_via_cholesky(): matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index c557adf5a..79ea56020 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -426,143 +426,6 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( ) -def test__inversion_imaging__linear_obj_func_with_w_tilde__include_preload_data_linear_func_matrix( - masked_imaging_7x7, - rectangular_mapper_7x7_3x3, - delaunay_mapper_9_3x3, -): - masked_imaging_7x7 = copy.copy(masked_imaging_7x7) - masked_imaging_7x7.data[4] = 2.0 - masked_imaging_7x7.noise_map[3] = 4.0 - masked_imaging_7x7.psf[0] = 0.1 - masked_imaging_7x7.psf[4] = 0.9 - - mask = masked_imaging_7x7.mask - - grid = aa.Grid2D.from_mask(mask=mask) - - mapping_matrix = np.full(fill_value=0.5, shape=(9, 2)) - mapping_matrix[0, 0] = 0.8 - mapping_matrix[1, 1] = 0.4 - - linear_obj = aa.m.MockLinearObjFuncList( - parameters=2, grid=grid, mapping_matrix=mapping_matrix - ) - - linear_obj_1 = aa.m.MockLinearObjFuncList( - parameters=2, grid=grid, mapping_matrix=mapping_matrix - ) - - linear_obj_2 = aa.m.MockLinearObjFuncList( - parameters=2, grid=grid, mapping_matrix=mapping_matrix - ) - - inversion_mapping = aa.Inversion( - dataset=masked_imaging_7x7, - linear_obj_list=[ - rectangular_mapper_7x7_3x3, - linear_obj, - delaunay_mapper_9_3x3, - linear_obj_1, - linear_obj_2, - ], - settings=aa.SettingsInversion(use_w_tilde=False), - ) - - preloads = aa.Preloads( - data_linear_func_matrix_dict=inversion_mapping.data_linear_func_matrix_dict - ) - - inversion_w_tilde = aa.Inversion( - dataset=masked_imaging_7x7, - linear_obj_list=[ - rectangular_mapper_7x7_3x3, - linear_obj, - delaunay_mapper_9_3x3, - linear_obj_1, - linear_obj_2, - ], - preloads=preloads, - settings=aa.SettingsInversion(use_w_tilde=True), - ) - - assert inversion_mapping.data_vector == pytest.approx( - inversion_w_tilde.data_vector, 1.0e-4 - ) - assert inversion_mapping.curvature_matrix == pytest.approx( - inversion_w_tilde.curvature_matrix, 1.0e-4 - ) - - -def test__inversion_imaging__linear_obj_func_with_w_tilde__include_preload_mapper_operated_mapping_matrix_dict( - masked_imaging_7x7, - rectangular_mapper_7x7_3x3, - delaunay_mapper_9_3x3, -): - masked_imaging_7x7 = copy.copy(masked_imaging_7x7) - masked_imaging_7x7.data[4] = 2.0 - masked_imaging_7x7.noise_map[3] = 4.0 - masked_imaging_7x7.psf[0] = 0.1 - masked_imaging_7x7.psf[4] = 0.9 - - mask = masked_imaging_7x7.mask - - grid = aa.Grid2D.from_mask(mask=mask) - - mapping_matrix = np.full(fill_value=0.5, shape=(9, 2)) - mapping_matrix[0, 0] = 0.8 - mapping_matrix[1, 1] = 0.4 - - linear_obj = aa.m.MockLinearObjFuncList( - parameters=2, grid=grid, mapping_matrix=mapping_matrix - ) - - linear_obj_1 = aa.m.MockLinearObjFuncList( - parameters=2, grid=grid, mapping_matrix=mapping_matrix - ) - - linear_obj_2 = aa.m.MockLinearObjFuncList( - parameters=2, grid=grid, mapping_matrix=mapping_matrix - ) - - inversion_mapping = aa.Inversion( - dataset=masked_imaging_7x7, - linear_obj_list=[ - rectangular_mapper_7x7_3x3, - linear_obj, - delaunay_mapper_9_3x3, - linear_obj_1, - linear_obj_2, - ], - settings=aa.SettingsInversion(use_w_tilde=False), - ) - - preloads = aa.Preloads( - mapper_operated_mapping_matrix_dict=inversion_mapping.mapper_operated_mapping_matrix_dict - ) - - inversion_w_tilde = aa.Inversion( - dataset=masked_imaging_7x7, - linear_obj_list=[ - rectangular_mapper_7x7_3x3, - linear_obj, - delaunay_mapper_9_3x3, - linear_obj_1, - linear_obj_2, - ], - preloads=preloads, - settings=aa.SettingsInversion(use_w_tilde=True), - ) - - assert inversion_mapping.data_vector == pytest.approx( - inversion_w_tilde.data_vector, 1.0e-4 - ) - - assert inversion_mapping.curvature_matrix == pytest.approx( - inversion_w_tilde.curvature_matrix, 1.0e-4 - ) - - def test__inversion_interferometer__via_mapper( interferometer_7_no_fft, rectangular_mapper_7x7_3x3, diff --git a/test_autoarray/inversion/pixelization/mesh/test_rectangular.py b/test_autoarray/inversion/pixelization/mesh/test_rectangular.py deleted file mode 100644 index aa18b5ec5..000000000 --- a/test_autoarray/inversion/pixelization/mesh/test_rectangular.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -import autoarray as aa - - -def test__preloads_used_for_relocated_grid(mask_2d_7x7): - mesh = aa.mesh.Rectangular(shape=(3, 3)) - - relocated_grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - - border_relocator = aa.BorderRelocator(mask=mask_2d_7x7, sub_size=1) - - mapper_grids = mesh.mapper_grids_from( - mask=mask_2d_7x7, - border_relocator=border_relocator, - source_plane_data_grid=relocated_grid, - source_plane_mesh_grid=None, - preloads=aa.Preloads(relocated_grid=relocated_grid), - ) - - assert mapper_grids.source_plane_data_grid == pytest.approx(relocated_grid, 1.0e-4) diff --git a/test_autoarray/inversion/pixelization/mesh/test_triangulation.py b/test_autoarray/inversion/pixelization/mesh/test_triangulation.py deleted file mode 100644 index eaf8ca5ce..000000000 --- a/test_autoarray/inversion/pixelization/mesh/test_triangulation.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -import autoarray as aa - - -def test___preloads_used_for_relocated_grid(mask_2d_7x7): - mesh = aa.mesh.Delaunay() - - relocated_grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0) - - border_relocator = aa.BorderRelocator(mask=mask_2d_7x7, sub_size=1) - - mapper_grids = mesh.mapper_grids_from( - mask=mask_2d_7x7, - border_relocator=border_relocator, - source_plane_data_grid=relocated_grid, - source_plane_mesh_grid=relocated_grid, - preloads=aa.Preloads(relocated_grid=relocated_grid), - ) - - assert mapper_grids.source_plane_data_grid == pytest.approx(relocated_grid, 1.0e-4) diff --git a/test_autoarray/test_preloads.py b/test_autoarray/test_preloads.py deleted file mode 100644 index e58c32b93..000000000 --- a/test_autoarray/test_preloads.py +++ /dev/null @@ -1,84 +0,0 @@ -import numpy as np - -import autoarray as aa - - -def test__set_w_tilde(): - # fit inversion is None, so no need to bother with w_tilde. - - fit_0 = aa.m.MockFitImaging(inversion=None) - fit_1 = aa.m.MockFitImaging(inversion=None) - - preloads = aa.Preloads(w_tilde=1, use_w_tilde=1) - preloads.set_w_tilde_imaging(fit_0=fit_0, fit_1=fit_1) - - assert preloads.w_tilde is None - assert preloads.use_w_tilde is False - - # Noise maps of fit are different but there is an inversion, so we should not preload w_tilde and use w_tilde. - - inversion = aa.m.MockInversion(linear_obj_list=[aa.m.MockMapper()]) - - fit_0 = aa.m.MockFitImaging( - inversion=inversion, - noise_map=aa.Array2D.zeros(shape_native=(3, 1), pixel_scales=0.1), - ) - fit_1 = aa.m.MockFitImaging( - inversion=inversion, - noise_map=aa.Array2D.ones(shape_native=(3, 1), pixel_scales=0.1), - ) - - preloads = aa.Preloads(w_tilde=1, use_w_tilde=1) - preloads.set_w_tilde_imaging(fit_0=fit_0, fit_1=fit_1) - - assert preloads.w_tilde is None - assert preloads.use_w_tilde is False - - # Noise maps of fits are the same so preload w_tilde and use it. - - noise_map = aa.Array2D.ones(shape_native=(5, 5), pixel_scales=0.1) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, True, True, True, True], - ] - ), - pixel_scales=noise_map.pixel_scales, - ) - - dataset = aa.m.MockDataset(psf=aa.Kernel2D.no_blur(pixel_scales=1.0), mask=mask) - - fit_0 = aa.m.MockFitImaging( - inversion=inversion, dataset=dataset, noise_map=noise_map - ) - fit_1 = aa.m.MockFitImaging( - inversion=inversion, dataset=dataset, noise_map=noise_map - ) - - preloads = aa.Preloads(w_tilde=1, use_w_tilde=1) - preloads.set_w_tilde_imaging(fit_0=fit_0, fit_1=fit_1) - - ( - curvature_preload, - indexes, - lengths, - ) = aa.util.inversion_imaging.w_tilde_curvature_preload_imaging_from( - noise_map_native=np.array(fit_0.noise_map.native), - kernel_native=np.array(fit_0.dataset.psf.native), - native_index_for_slim_index=np.array( - fit_0.dataset.mask.derive_indexes.native_for_slim - ), - ) - - assert preloads.w_tilde.curvature_preload[0] == curvature_preload[0] - assert preloads.w_tilde.indexes[0] == indexes[0] - assert preloads.w_tilde.lengths[0] == lengths[0] - assert preloads.w_tilde.noise_map_value == 1.0 - assert preloads.use_w_tilde == True - - From 280d57474cb7091b3876edd10a16779a5d1b1873 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Oct 2024 13:22:15 +0100 Subject: [PATCH 024/108] black --- autoarray/exc.py | 1 - autoarray/geometry/geometry_util.py | 10 +-- autoarray/inversion/inversion/abstract.py | 12 ++-- .../inversion/inversion/imaging/abstract.py | 16 +++-- .../pixelization/border_relocator.py | 12 ++-- .../inversion/plot/inversion_plotters.py | 7 +- autoarray/mask/mask_2d_util.py | 4 +- autoarray/numpy_wrapper.py | 1 - autoarray/operators/convolver.py | 65 ++++++++----------- .../over_sampling/over_sample_util.py | 32 ++++++--- autoarray/plot/wrap/base/colorbar.py | 6 +- autoarray/plot/wrap/base/title.py | 2 +- autoarray/structures/grids/grid_2d_util.py | 20 ++++-- autoarray/structures/grids/uniform_2d.py | 6 +- .../inversion/inversion/test_abstract.py | 2 - 15 files changed, 102 insertions(+), 94 deletions(-) diff --git a/autoarray/exc.py b/autoarray/exc.py index cfe79eb76..eed76b04e 100644 --- a/autoarray/exc.py +++ b/autoarray/exc.py @@ -106,7 +106,6 @@ class PlottingException(Exception): pass - class ProfilingException(Exception): """ Raises exceptions associated with in-built profiling tools (e.g. the `profile_func` decorator). diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 71fb9b1f1..5117252d9 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -390,10 +390,12 @@ def transform_grid_2d_to_reference_frame( theta_coordinate_to_profile = np.arctan2( shifted_grid_2d[:, 0], shifted_grid_2d[:, 1] ) - np.radians(angle) - return np.vstack([ - radius * np.sin(theta_coordinate_to_profile), - radius * np.cos(theta_coordinate_to_profile) - ]).T + return np.vstack( + [ + radius * np.sin(theta_coordinate_to_profile), + radius * np.cos(theta_coordinate_to_profile), + ] + ).T def transform_grid_2d_from_reference_frame( diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index 390d2f6fb..ca7e31861 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -493,12 +493,12 @@ def reconstruction(self) -> np.ndarray: solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0]) - solutions[ - values_to_solve - ] = inversion_util.reconstruction_positive_only_from( - data_vector=data_vector_input, - curvature_reg_matrix=curvature_reg_matrix_input, - settings=self.settings, + solutions[values_to_solve] = ( + inversion_util.reconstruction_positive_only_from( + data_vector=data_vector_input, + curvature_reg_matrix=curvature_reg_matrix_input, + settings=self.settings, + ) ) return solutions else: diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index c1d2a7842..1ea826f88 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -95,11 +95,13 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: """ return [ - self.convolver.convolve_mapping_matrix( - mapping_matrix=linear_obj.mapping_matrix + ( + self.convolver.convolve_mapping_matrix( + mapping_matrix=linear_obj.mapping_matrix + ) + if linear_obj.operated_mapping_matrix_override is None + else self.linear_func_operated_mapping_matrix_dict[linear_obj] ) - if linear_obj.operated_mapping_matrix_override is None - else self.linear_func_operated_mapping_matrix_dict[linear_obj] for linear_obj in self.linear_obj_list ] @@ -141,9 +143,9 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: mapping_matrix=linear_func.mapping_matrix ) - linear_func_operated_mapping_matrix_dict[ - linear_func - ] = operated_mapping_matrix + linear_func_operated_mapping_matrix_dict[linear_func] = ( + operated_mapping_matrix + ) return linear_func_operated_mapping_matrix_dict diff --git a/autoarray/inversion/pixelization/border_relocator.py b/autoarray/inversion/pixelization/border_relocator.py index da3cd86e1..8af010a7f 100644 --- a/autoarray/inversion/pixelization/border_relocator.py +++ b/autoarray/inversion/pixelization/border_relocator.py @@ -117,12 +117,12 @@ def sub_border_pixel_slim_indexes_from( int(border_pixel) ] - sub_border_pixels[ - border_1d_index - ] = grid_2d_util.furthest_grid_2d_slim_index_from( - grid_2d_slim=sub_grid_2d_slim, - slim_indexes=sub_border_pixels_of_border_pixel, - coordinate=mask_centre, + sub_border_pixels[border_1d_index] = ( + grid_2d_util.furthest_grid_2d_slim_index_from( + grid_2d_slim=sub_grid_2d_slim, + slim_indexes=sub_border_pixels_of_border_pixel, + coordinate=mask_centre, + ) ) return sub_border_pixels diff --git a/autoarray/inversion/plot/inversion_plotters.py b/autoarray/inversion/plot/inversion_plotters.py index c4f9846b4..3cee76a67 100644 --- a/autoarray/inversion/plot/inversion_plotters.py +++ b/autoarray/inversion/plot/inversion_plotters.py @@ -211,10 +211,9 @@ def figures_2d_of_pixelization( "inversion" ]["reconstruction_vmax_factor"] - self.mat_plot_2d.cmap.kwargs[ - "vmax" - ] = reconstruction_vmax_factor * np.max( - self.inversion.reconstruction + self.mat_plot_2d.cmap.kwargs["vmax"] = ( + reconstruction_vmax_factor + * np.max(self.inversion.reconstruction) ) vmax_custom = True diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index db2751b04..47db2413b 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -316,9 +316,7 @@ def elliptical_radius_from( y_scaled_elliptical = r_scaled * np.sin(theta_rotated) x_scaled_elliptical = r_scaled * np.cos(theta_rotated) - return np.sqrt( - x_scaled_elliptical**2.0 + (y_scaled_elliptical / axis_ratio) ** 2.0 - ) + return np.sqrt(x_scaled_elliptical**2.0 + (y_scaled_elliptical / axis_ratio) ** 2.0) @numba_util.jit() diff --git a/autoarray/numpy_wrapper.py b/autoarray/numpy_wrapper.py index 54edb6c40..3f534d995 100644 --- a/autoarray/numpy_wrapper.py +++ b/autoarray/numpy_wrapper.py @@ -9,7 +9,6 @@ import jax from jax import numpy as np, jit - print("JAX mode enabled") except ImportError: raise ImportError( diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index 21ce28f0f..5a52329b0 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -223,12 +223,12 @@ def __init__(self, mask, kernel): mask_index_array=self.mask_index_array, kernel_2d=np.array(self.kernel.native[:, :]), ) - self.image_frame_1d_indexes[ - mask_1d_index, : - ] = image_frame_1d_indexes - self.image_frame_1d_kernels[ - mask_1d_index, : - ] = image_frame_1d_kernels + self.image_frame_1d_indexes[mask_1d_index, :] = ( + image_frame_1d_indexes + ) + self.image_frame_1d_kernels[mask_1d_index, :] = ( + image_frame_1d_kernels + ) self.image_frame_1d_lengths[mask_1d_index] = image_frame_1d_indexes[ image_frame_1d_indexes >= 0 ].shape[0] @@ -265,15 +265,15 @@ def __init__(self, mask, kernel): mask_index_array=np.array(self.mask_index_array), kernel_2d=np.array(self.kernel.native), ) - self.blurring_frame_1d_indexes[ - mask_1d_index, : - ] = image_frame_1d_indexes - self.blurring_frame_1d_kernels[ - mask_1d_index, : - ] = image_frame_1d_kernels - self.blurring_frame_1d_lengths[ - mask_1d_index - ] = image_frame_1d_indexes[image_frame_1d_indexes >= 0].shape[0] + self.blurring_frame_1d_indexes[mask_1d_index, :] = ( + image_frame_1d_indexes + ) + self.blurring_frame_1d_kernels[mask_1d_index, :] = ( + image_frame_1d_kernels + ) + self.blurring_frame_1d_lengths[mask_1d_index] = ( + image_frame_1d_indexes[image_frame_1d_indexes >= 0].shape[0] + ) mask_1d_index += 1 @staticmethod @@ -317,33 +317,28 @@ def frame_at_coordinates_jit(coordinates, mask, mask_index_array, kernel_2d): return frame, kernel_frame - def jax_convolve(self, image, blurring_image, method='auto'): + def jax_convolve(self, image, blurring_image, method="auto"): slim_to_2D_index_image = jnp.nonzero( - jnp.logical_not(self.mask.array), - size=image.shape[0] + jnp.logical_not(self.mask.array), size=image.shape[0] ) slim_to_2D_index_blurring = jnp.nonzero( - jnp.logical_not(self.blurring_mask), - size=blurring_image.shape[0] + jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] ) expanded_image_native = jnp.zeros(self.mask.shape) - expanded_image_native = expanded_image_native.at[ - slim_to_2D_index_image - ].set(image.array) - expanded_image_native = expanded_image_native.at[ - slim_to_2D_index_blurring - ].set(blurring_image.array) + expanded_image_native = expanded_image_native.at[slim_to_2D_index_image].set( + image.array + ) + expanded_image_native = expanded_image_native.at[slim_to_2D_index_blurring].set( + blurring_image.array + ) kernel = np.array(self.kernel.native.array) convolve_native = jax.scipy.signal.convolve( - expanded_image_native, - kernel, - mode='same', - method=method + expanded_image_native, kernel, mode="same", method=method ) convolve_slim = convolve_native[slim_to_2D_index_image] return convolve_slim - def convolve_image(self, image, blurring_image, jax_method='fft'): + def convolve_image(self, image, blurring_image, jax_method="fft"): """ For a given 1D array and blurring array, convolve the two using this convolver. @@ -371,14 +366,10 @@ def exception_message(): self.blurring_mask is None, lambda _: jax.debug.callback(exception_message), lambda _: None, - None + None, ) - return self.jax_convolve( - image, - blurring_image, - method=jax_method - ) + return self.jax_convolve(image, blurring_image, method=jax_method) else: if self.blurring_mask is None: diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index 1a0ae4290..d539051a0 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -219,9 +219,9 @@ def sub_slim_index_for_sub_native_index_from(sub_mask_2d: np.ndarray): for sub_mask_y in range(sub_mask_2d.shape[0]): for sub_mask_x in range(sub_mask_2d.shape[1]): if sub_mask_2d[sub_mask_y, sub_mask_x] == False: - sub_slim_index_for_sub_native_index[ - sub_mask_y, sub_mask_x - ] = sub_mask_1d_index + sub_slim_index_for_sub_native_index[sub_mask_y, sub_mask_x] = ( + sub_mask_1d_index + ) sub_mask_1d_index += 1 return sub_slim_index_for_sub_native_index @@ -407,18 +407,32 @@ def grid_2d_slim_over_sampled_via_mask_from( for x1 in range(sub): if use_jax: # while this makes it run, it is very, very slow - grid_slim = grid_slim.at[sub_index, 0].set(-( - y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) - )) + grid_slim = grid_slim.at[sub_index, 0].set( + -( + y_scaled + - y_sub_half + + y1 * y_sub_step + + (y_sub_step / 2.0) + ) + ) grid_slim = grid_slim.at[sub_index, 1].set( - x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0) + x_scaled + - x_sub_half + + x1 * x_sub_step + + (x_sub_step / 2.0) ) else: grid_slim[sub_index, 0] = -( - y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) + y_scaled + - y_sub_half + + y1 * y_sub_step + + (y_sub_step / 2.0) ) grid_slim[sub_index, 1] = ( - x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0) + x_scaled + - x_sub_half + + x1 * x_sub_step + + (x_sub_step / 2.0) ) sub_index += 1 diff --git a/autoarray/plot/wrap/base/colorbar.py b/autoarray/plot/wrap/base/colorbar.py index 272d93fcd..b0650013a 100644 --- a/autoarray/plot/wrap/base/colorbar.py +++ b/autoarray/plot/wrap/base/colorbar.py @@ -130,9 +130,9 @@ def tick_labels_from( cb_unit = units.colorbar_label middle_index = (len(manual_tick_labels) - 1) // 2 - manual_tick_labels[ - middle_index - ] = rf"{manual_tick_labels[middle_index]}{cb_unit}" + manual_tick_labels[middle_index] = ( + rf"{manual_tick_labels[middle_index]}{cb_unit}" + ) return manual_tick_labels diff --git a/autoarray/plot/wrap/base/title.py b/autoarray/plot/wrap/base/title.py index 08ab5261a..8185a7184 100644 --- a/autoarray/plot/wrap/base/title.py +++ b/autoarray/plot/wrap/base/title.py @@ -4,7 +4,7 @@ class Title(AbstractMatWrap): - def __init__(self, prefix: str = None, disable_log10_label : bool = False, **kwargs): + def __init__(self, prefix: str = None, disable_log10_label: bool = False, **kwargs): """ The settings used to customize the figure's title. diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index e3eb09c71..d81c405b2 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -57,6 +57,7 @@ def check_grid_2d(grid_2d: np.ndarray): def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): if len(grid_2d.shape) == 2: + def exception_message(): raise exc.GridException( f""" @@ -68,17 +69,19 @@ def exception_message(): The mask number of pixels is {mask_2d.pixels_in_mask}. """ ) + if use_jax: jax.lax.cond( grid_2d.shape[0] != mask_2d.pixels_in_mask, lambda _: jax.debug.callback(exception_message), lambda _: None, - None + None, ) elif grid_2d.shape[0] != mask_2d.pixels_in_mask: exception_message() elif len(grid_2d.shape) == 3: + def exception_message(): raise exc.GridException( f""" @@ -89,12 +92,13 @@ def exception_message(): The mask shape_native is {mask_2d.shape_native}. """ ) + if use_jax: jax.lax.cond( (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, lambda _: jax.debug.callback(exception_message), lambda _: None, - None + None, ) elif (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: exception_message() @@ -283,8 +287,12 @@ def grid_2d_slim_via_mask_from( for x in range(mask_2d.shape[1]): if not mask_2d[y, x]: if use_jax: - grid_slim = grid_slim.at[index, 0].set(-(y - centres_scaled[0]) * pixel_scales[0]) - grid_slim = grid_slim.at[index, 1].set((x - centres_scaled[1]) * pixel_scales[1]) + grid_slim = grid_slim.at[index, 0].set( + -(y - centres_scaled[0]) * pixel_scales[0] + ) + grid_slim = grid_slim.at[index, 1].set( + (x - centres_scaled[1]) * pixel_scales[1] + ) else: grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] @@ -786,9 +794,7 @@ def grid_2d_slim_upscaled_from( The pixel scale of the uniform grid that laid over the irregular grid of (y,x) coordinates. """ - grid_2d_slim_upscaled = np.zeros( - shape=(grid_slim.shape[0] * upscale_factor**2, 2) - ) + grid_2d_slim_upscaled = np.zeros(shape=(grid_slim.shape[0] * upscale_factor**2, 2)) upscale_index = 0 diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 3ff0f5c80..9522326fb 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -845,10 +845,10 @@ def distances_to_coordinate_from( coordinate The (y,x) coordinate from which the distance of every grid (y,x) coordinate is computed. """ - squared_distance = self.squared_distances_to_coordinate_from(coordinate=coordinate) - distances = np.sqrt( - squared_distance.array + squared_distance = self.squared_distances_to_coordinate_from( + coordinate=coordinate ) + distances = np.sqrt(squared_distance.array) return Array2D(values=distances, mask=self.mask) def grid_2d_radial_projected_shape_slim_from( diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 5d7fd622f..a90c901eb 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -319,7 +319,6 @@ def test__regularization_matrix(): assert inversion.regularization_matrix == pytest.approx(regularization_matrix) - def test__reconstruction_reduced(): linear_obj_list = [ aa.m.MockLinearObj(parameters=2, regularization=aa.m.MockRegularization()), @@ -538,7 +537,6 @@ def test__regularization_term(): assert inversion.regularization_term == 34.0 - def test__determinant_of_positive_definite_matrix_via_cholesky(): matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) From e6419042fd707bb8286641ac9d89878ebc5fee69 Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Fri, 25 Oct 2024 13:14:02 +0100 Subject: [PATCH 025/108] Wrap class as PyTree Needed to make the `autolens.Tracer` example work. --- autoarray/operators/over_sampling/uniform.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/autoarray/operators/over_sampling/uniform.py b/autoarray/operators/over_sampling/uniform.py index cd0751782..5c8652b0f 100644 --- a/autoarray/operators/over_sampling/uniform.py +++ b/autoarray/operators/over_sampling/uniform.py @@ -14,7 +14,10 @@ from autoarray import exc from autoarray.operators.over_sampling import over_sample_util +from autofit.jax_wrapper import register_pytree_node_class + +@register_pytree_node_class class OverSamplingUniform(AbstractOverSampling): def __init__(self, sub_size: Union[int, Array2D]): """ @@ -319,6 +322,15 @@ def over_sampler_from(self, mask: Mask2D) -> "OverSamplerUniform": mask=mask, sub_size=self.sub_size, ) + + def tree_flatten(self): + children = (self.sub_size,) + aux_data = None + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) class OverSamplerUniform(AbstractOverSampler): From 6cbfa7d05b1b41f276cefc502b0c599ea04d333f Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Mon, 4 Nov 2024 10:13:36 +0000 Subject: [PATCH 026/108] Changes needed for critical curve calculations These changes are mostly needed to speed up the (and make jit'able) the various functions withing `deflections.py` (in PyAutoGalaxy). While not all these changes are needed for the final "jax" method, it does make the previous method compatible when jax array's are used. --- autoarray/geometry/geometry_2d.py | 4 +- autoarray/geometry/geometry_util.py | 134 ++++++++++------ autoarray/mask/abstract_mask.py | 21 ++- autoarray/mask/mask_2d_util.py | 38 +++-- autoarray/operators/contour.py | 22 ++- .../over_sampling/over_sample_util.py | 13 +- autoarray/operators/over_sampling/uniform.py | 4 +- autoarray/structures/arrays/array_2d_util.py | 149 ++++++++++++------ autoarray/structures/grids/grid_2d_util.py | 31 ++-- autoarray/structures/grids/uniform_2d.py | 28 +++- autoarray/structures/vectors/uniform.py | 9 +- 11 files changed, 292 insertions(+), 161 deletions(-) diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index f14c0eccf..29c604405 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -13,6 +13,8 @@ from autoarray import type as ty from autoarray.geometry import geometry_util +from autofit.jax_wrapper import use_jax + logging.basicConfig() logger = logging.getLogger(__name__) @@ -234,7 +236,7 @@ def grid_pixels_2d_from(self, grid_scaled_2d: Grid2D) -> Grid2D: from autoarray.structures.grids.uniform_2d import Grid2D grid_pixels_2d = geometry_util.grid_pixels_2d_slim_from( - grid_scaled_2d_slim=np.array(grid_scaled_2d), + grid_scaled_2d_slim=np.array(grid_scaled_2d.array), shape_native=self.shape_native, pixel_scales=self.pixel_scales, origin=self.origin, diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 5117252d9..bbe7bc601 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -477,23 +477,28 @@ def grid_pixels_2d_slim_from( pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = ( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = ( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) + if use_jax: + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 + else: + grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) + for slim_index in range(grid_scaled_2d_slim.shape[0]): + grid_pixels_2d_slim[slim_index, 0] = ( + (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) + + centres_scaled[0] + + 0.5 + ) + grid_pixels_2d_slim[slim_index, 1] = ( + (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) + + centres_scaled[1] + + 0.5 + ) return grid_pixels_2d_slim @@ -539,23 +544,32 @@ def grid_pixel_centres_2d_slim_from( pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = int( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = int( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) + if use_jax: + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + grid_pixels_2d_slim = ( + (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 + ).astype(int) + else: + grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) + + for slim_index in range(grid_scaled_2d_slim.shape[0]): + grid_pixels_2d_slim[slim_index, 0] = int( + (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) + + centres_scaled[0] + + 0.5 + ) + grid_pixels_2d_slim[slim_index, 1] = int( + (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) + + centres_scaled[1] + + 0.5 + ) return grid_pixels_2d_slim @@ -613,13 +627,18 @@ def grid_pixel_indexes_2d_slim_from( origin=origin, ) - grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) + if use_jax: + grid_pixel_indexes_2d_slim = ( + grid_pixels_2d_slim * np.array([shape_native[1], 1]) + ).sum(axis=1).astype(int) + else: + grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) - for slim_index in range(grid_pixels_2d_slim.shape[0]): - grid_pixel_indexes_2d_slim[slim_index] = int( - grid_pixels_2d_slim[slim_index, 0] * shape_native[1] - + grid_pixels_2d_slim[slim_index, 1] - ) + for slim_index in range(grid_pixels_2d_slim.shape[0]): + grid_pixel_indexes_2d_slim[slim_index] = int( + grid_pixels_2d_slim[slim_index, 0] * shape_native[1] + + grid_pixels_2d_slim[slim_index, 1] + ) return grid_pixel_indexes_2d_slim @@ -664,20 +683,25 @@ def grid_scaled_2d_slim_from( pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) + if use_jax: + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + grid_scaled_2d_slim = (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign + else: + grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_scaled_2d_slim[slim_index, 0] = ( - -(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5) - * pixel_scales[0] - ) - grid_scaled_2d_slim[slim_index, 1] = ( - grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5 - ) * pixel_scales[1] + for slim_index in range(grid_scaled_2d_slim.shape[0]): + grid_scaled_2d_slim[slim_index, 0] = ( + -(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5) + * pixel_scales[0] + ) + grid_scaled_2d_slim[slim_index, 1] = ( + grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5 + ) * pixel_scales[1] return grid_scaled_2d_slim @@ -723,20 +747,28 @@ def grid_pixel_centres_2d_from( pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - for y in range(grid_scaled_2d.shape[0]): - for x in range(grid_scaled_2d.shape[1]): - grid_pixels_2d[y, x, 0] = int( - (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + centres_scaled[0] + 0.5 - ) - grid_pixels_2d[y, x, 1] = int( - (grid_scaled_2d[y, x, 1] / pixel_scales[1]) + centres_scaled[1] + 0.5 - ) + if use_jax: + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + grid_pixels_2d = ( + (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 + ).astype(int) + else: + grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) + + for y in range(grid_scaled_2d.shape[0]): + for x in range(grid_scaled_2d.shape[1]): + grid_pixels_2d[y, x, 0] = int( + (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + centres_scaled[0] + 0.5 + ) + grid_pixels_2d[y, x, 1] = int( + (grid_scaled_2d[y, x, 1] / pixel_scales[1]) + centres_scaled[1] + 0.5 + ) return grid_pixels_2d diff --git a/autoarray/mask/abstract_mask.py b/autoarray/mask/abstract_mask.py index 5401bcd09..fabee9c96 100644 --- a/autoarray/mask/abstract_mask.py +++ b/autoarray/mask/abstract_mask.py @@ -4,6 +4,8 @@ import logging from autoarray.numpy_wrapper import np, use_jax +if use_jax: + import jax from pathlib import Path from typing import Dict, Union @@ -74,12 +76,23 @@ def pixel_scale(self) -> float: For a mask with dimensions two or above check that are pixel scales are the same, and if so return this single value as a float. """ + def exception_message(): + raise exc.MaskException( + "Cannot return a pixel_scale for a grid where each dimension has a " + "different pixel scale (e.g. pixel_scales[0] != pixel_scales[1])" + ) + for pixel_scale in self.pixel_scales: - if abs(pixel_scale - self.pixel_scales[0]) > 1.0e-8: - raise exc.MaskException( - "Cannot return a pixel_scale for a grid where each dimension has a " - "different pixel scale (e.g. pixel_scales[0] != pixel_scales[1])" + cond = abs(pixel_scale - self.pixel_scales[0]) > 1.0e-8 + if use_jax: + jax.lax.cond( + cond, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None ) + elif cond: + exception_message() return self.pixel_scales[0] diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 47db2413b..4103e89d9 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -5,6 +5,7 @@ from autoarray import exc from autoarray import numba_util from autoarray import type as ty +from autoarray.numpy_wrapper import use_jax, np as jnp @numba_util.jit() @@ -66,15 +67,18 @@ def total_pixels_2d_from(mask_2d: np.ndarray) -> int: total_regular_pixels = total_regular_pixels_from(mask=mask) """ + if use_jax: + return (~mask_2d.astype(bool)).sum() - total_regular_pixels = 0 + else: + total_regular_pixels = 0 - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - total_regular_pixels += 1 + for y in range(mask_2d.shape[0]): + for x in range(mask_2d.shape[1]): + if not mask_2d[y, x]: + total_regular_pixels += 1 - return total_regular_pixels + return total_regular_pixels @numba_util.jit() @@ -1052,15 +1056,17 @@ def native_index_for_slim_index_2d_from( native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d) """ + if use_jax: + return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T + else: + total_pixels = total_pixels_2d_from(mask_2d=mask_2d) + native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2)) + slim_index = 0 - total_pixels = total_pixels_2d_from(mask_2d=mask_2d) - native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2)) - slim_index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - native_index_for_slim_index_2d[slim_index, :] = y, x - slim_index += 1 + for y in range(mask_2d.shape[0]): + for x in range(mask_2d.shape[1]): + if not mask_2d[y, x]: + native_index_for_slim_index_2d[slim_index, :] = y, x + slim_index += 1 - return native_index_for_slim_index_2d + return native_index_for_slim_index_2d diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index 1d58fd2a4..a85a73105 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -1,5 +1,6 @@ from __future__ import annotations -import numpy as np +from autoarray.numpy_wrapper import np, use_jax +import numpy from skimage import measure from scipy.spatial import ConvexHull from scipy.spatial import QhullError @@ -47,13 +48,17 @@ def contour_array(self): ).astype("int") arr = np.zeros(self.shape_native) - arr[tuple(np.array(pixel_centres).T)] = 1 + if use_jax: + arr = arr.at[tuple(np.array(pixel_centres).T)].set(1) + else: + arr[tuple(np.array(pixel_centres).T)] = 1 return arr @property def contour_list(self): - contour_indices_list = measure.find_contours(np.array(self.contour_array), 0) + # make sure to use base numpy to convert JAX array back to a normal array + contour_indices_list = measure.find_contours(numpy.array(self.contour_array.array), 0) if len(contour_indices_list) == 0: return [] @@ -67,8 +72,8 @@ def contour_list(self): pixel_scales=self.pixel_scales, ) - grid_scaled_1d[:, 0] -= self.pixel_scales[0] / 2.0 - grid_scaled_1d[:, 1] += self.pixel_scales[1] / 2.0 + factor = 0.5 * np.array(self.pixel_scales) * np.array([-1.0, 1.0]) + grid_scaled_1d += factor contour_list.append(Grid2DIrregular(values=grid_scaled_1d)) @@ -81,10 +86,11 @@ def hull( if self.grid.shape[0] < 3: return None - grid_convex = np.zeros((len(self.grid), 2)) + # cast JAX arrays to base numpy arrays + grid_convex = numpy.zeros((len(self.grid), 2)) - grid_convex[:, 0] = self.grid[:, 1] - grid_convex[:, 1] = self.grid[:, 0] + grid_convex[:, 0] = numpy.array(self.grid[:, 1]) + grid_convex[:, 1] = numpy.array(self.grid[:, 0]) try: hull = ConvexHull(grid_convex) diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index d539051a0..0a8d43c43 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -1,4 +1,4 @@ -from autoarray.numpy_wrapper import np, register_pytree_node_class, use_jax +from autoarray.numpy_wrapper import np, register_pytree_node_class, use_jax, jit from typing import List, Tuple @@ -504,9 +504,14 @@ def binned_array_2d_from( for y1 in range(sub): for x1 in range(sub): - binned_array_2d_slim[index] += ( - array_2d[sub_index] * sub_fraction[index] - ) + if use_jax: + binned_array_2d_slim = binned_array_2d_slim.at[index].add( + array_2d[sub_index] * sub_fraction[index] + ) + else: + binned_array_2d_slim[index] += ( + array_2d[sub_index] * sub_fraction[index] + ) sub_index += 1 index += 1 diff --git a/autoarray/operators/over_sampling/uniform.py b/autoarray/operators/over_sampling/uniform.py index 5c8652b0f..b6c8d71a3 100644 --- a/autoarray/operators/over_sampling/uniform.py +++ b/autoarray/operators/over_sampling/uniform.py @@ -439,9 +439,9 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": pass binned_array_2d = over_sample_util.binned_array_2d_from( - array_2d=np.array(array), + array_2d=np.array(array.array), mask_2d=np.array(self.mask), - sub_size=np.array(self.sub_size).astype("int"), + sub_size=np.array(self.sub_size.array).astype("int"), ) return Array2D( diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index 72cb8437e..53e0b1890 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -1,6 +1,6 @@ from __future__ import annotations from astropy.io import fits -import numpy as np +# import numpy as np import os from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union @@ -13,6 +13,10 @@ from autoarray.mask import mask_2d_util from autoarray import exc +from autoarray.numpy_wrapper import use_jax, np, jit +from functools import partial +if use_jax: + import jax def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: @@ -24,18 +28,34 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - - if type(array) is list: + if use_jax: + array = jax.lax.cond( + type(array) is list, + lambda _: np.asarray(array), + lambda _: array, + None + ) + elif type(array) is list: array = np.asarray(array) return array def check_array_2d(array_2d: np.ndarray): - if len(array_2d.shape) != 1: + def exception_message(): raise exc.ArrayException( "An array input into the Array2D.__new__ method is not of shape 1." ) + cond = len(array_2d.shape) != 1 + if use_jax: + jax.lax.cond( + cond, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None + ) + elif cond: + exception_message() def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): @@ -54,38 +74,58 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): mask_2d The mask of the output Array2D. """ - if len(array_2d.shape) == 1: - if array_2d.shape[0] != mask_2d.pixels_in_mask: - raise exc.ArrayException( - f""" - The input array is a slim 1D array, but it does not have the same number of entries as pixels in - the mask. + def exception_message_1(): + raise exc.ArrayException( + f""" + The input array is a slim 1D array, but it does not have the same number of entries as pixels in + the mask. - This indicates that the number of unmaksed pixels in the mask is different to the input slim array - shape. + This indicates that the number of unmaksed pixels in the mask is different to the input slim array + shape. - The shapes of the two arrays (which this exception is raised because they are different) are as follows: + The shapes of the two arrays (which this exception is raised because they are different) are as follows: - Input array_2d_slim.shape = {array_2d.shape[0]} - Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} - Input mask_2d.shape_native = {mask_2d.shape_native} - """ - ) + Input array_2d_slim.shape = {array_2d.shape[0]} + Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} + Input mask_2d.shape_native = {mask_2d.shape_native} + """ + ) + cond_1 = (len(array_2d.shape) == 1) and (array_2d.shape[0] != mask_2d.pixels_in_mask) + + if use_jax: + jax.lax.cond( + cond_1, + lambda _: jax.debug.callback(exception_message_1), + lambda _: None, + None + ) + elif cond_1: + exception_message_1() - if len(array_2d.shape) == 2: - if array_2d.shape != mask_2d.shape_native: - raise exc.ArrayException( - f""" - The input array is 2D but not the same dimensions as the mask. - - This indicates the mask's shape is different to the input array shape. - - The shapes of the two arrays (which this exception is raised because they are different) are as follows: - - Input array_2d shape = {array_2d.shape} - Input mask_2d shape_native = {mask_2d.shape_native} - """ - ) + def exception_message_2(): + raise exc.ArrayException( + f""" + The input array is 2D but not the same dimensions as the mask. + + This indicates the mask's shape is different to the input array shape. + + The shapes of the two arrays (which this exception is raised because they are different) are as follows: + + Input array_2d shape = {array_2d.shape} + Input mask_2d shape_native = {mask_2d.shape_native} + """ + ) + cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) + + if use_jax: + jax.lax.cond( + cond_2, + lambda _: jax.debug.callback(exception_message_2), + lambda _: None, + None + ) + elif cond_2: + exception_message_2() def convert_array_2d( @@ -122,6 +162,8 @@ def convert_array_2d( check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 + if use_jax: + mask_2d = mask_2d.array if is_native and not skip_mask: array_2d *= np.invert(mask_2d) @@ -533,18 +575,21 @@ def array_2d_slim_from( array_2d_slim = array_2d_slim_from(mask=mask, array_2d=array_2d) """ - total_pixels = mask_2d_util.total_pixels_2d_from( - mask_2d=mask_2d, - ) + if use_jax: + array_2d_slim = array_2d_native[~mask_2d.astype(bool)] + else: + total_pixels = mask_2d_util.total_pixels_2d_from( + mask_2d=mask_2d, + ) - array_2d_slim = np.zeros(shape=total_pixels) - index = 0 + array_2d_slim = np.zeros(shape=total_pixels) + index = 0 - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - array_2d_slim[index] = array_2d_native[y, x] - index += 1 + for y in range(mask_2d.shape[0]): + for x in range(mask_2d.shape[1]): + if not mask_2d[y, x]: + array_2d_slim[index] = array_2d_native[y, x] + index += 1 return array_2d_slim @@ -601,6 +646,7 @@ def array_2d_native_from( ) +@partial(jit, static_argnums=(1,)) @numba_util.jit() def array_2d_via_indexes_from( array_2d_slim: np.ndarray, @@ -634,13 +680,18 @@ def array_2d_via_indexes_from( ndarray The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ - array_native_2d = np.zeros(shape) - - for slim_index in range(len(native_index_for_slim_index_2d)): - array_native_2d[ - native_index_for_slim_index_2d[slim_index, 0], - native_index_for_slim_index_2d[slim_index, 1], - ] = array_2d_slim[slim_index] + if use_jax: + array_native_2d = np.zeros(shape).at[ + tuple(native_index_for_slim_index_2d.T) + ].set(array_2d_slim) + else: + array_native_2d = np.zeros(shape) + + for slim_index in range(len(native_index_for_slim_index_2d)): + array_native_2d[ + native_index_for_slim_index_2d[slim_index, 0], + native_index_for_slim_index_2d[slim_index, 1], + ] = array_2d_slim[slim_index] return array_native_2d diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index d81c405b2..c659a032b 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -275,28 +275,27 @@ def grid_2d_slim_via_mask_from( total_pixels = mask_2d_util.total_pixels_2d_from(mask_2d) - grid_slim = np.zeros(shape=(total_pixels, 2)) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) - index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - if use_jax: - grid_slim = grid_slim.at[index, 0].set( - -(y - centres_scaled[0]) * pixel_scales[0] - ) - grid_slim = grid_slim.at[index, 1].set( - (x - centres_scaled[1]) * pixel_scales[1] - ) - else: + if use_jax: + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + grid_slim = ( + np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled + ) * sign * pixel_scales + else: + index = 0 + grid_slim = np.zeros(shape=(total_pixels, 2)) + + for y in range(mask_2d.shape[0]): + for x in range(mask_2d.shape[1]): + if not mask_2d[y, x]: grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] - index += 1 + index += 1 return grid_slim diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 9522326fb..b1c147af2 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1009,10 +1009,16 @@ def scaled_minima(self) -> Tuple: The (y,x) minimum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - return ( - np.amin(self[:, 0]).astype("float"), - np.amin(self[:, 1]).astype("float"), - ) + if use_jax: + return ( + np.amin(self.array[:, 0]).astype("float"), + np.amin(self.array[:, 1]).astype("float"), + ) + else: + return ( + np.amin(self[:, 0]).astype("float"), + np.amin(self[:, 1]).astype("float"), + ) @property def scaled_maxima(self) -> Tuple: @@ -1020,10 +1026,16 @@ def scaled_maxima(self) -> Tuple: The (y,x) maximum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - return ( - np.amax(self[:, 0]).astype("float"), - np.amax(self[:, 1]).astype("float"), - ) + if use_jax: + return ( + np.amax(self.array[:, 0]).astype("float"), + np.amax(self.array[:, 1]).astype("float"), + ) + else: + return ( + np.amax(self[:, 0]).astype("float"), + np.amax(self[:, 1]).astype("float"), + ) def extent_with_buffer_from(self, buffer: float = 1.0e-8) -> List[float]: """ diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index 12aa440b5..0a882926c 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -1,5 +1,6 @@ import logging -import numpy as np +# import numpy as np +from autofit.jax_wrapper import numpy as np, use_jax from typing import List, Optional, Tuple, Union from autoarray.structures.arrays.uniform_2d import Array2D @@ -394,8 +395,12 @@ def magnitudes(self) -> Array2D: """ Returns the magnitude of every vector which are computed as sqrt(y**2 + x**2). """ + if use_jax: + s = self.array + else: + s = self return Array2D( - values=np.sqrt(self[:, 0] ** 2.0 + self[:, 1] ** 2.0), mask=self.mask + values=np.sqrt(s[:, 0] ** 2.0 + s[:, 1] ** 2.0), mask=self.mask ) @property From 1e2cc672781e7d77c274746c46dce702fc7ce326 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 16 Dec 2024 16:27:53 +0000 Subject: [PATCH 027/108] fix --- autoarray/structures/grids/irregular_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index a2aa50bf9..6c426b2ab 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -1,7 +1,7 @@ import logging from typing import List, Optional, Tuple, Union -from autoarray.numpy_wrapper import numpy as np +from autoarray.numpy_wrapper import np from autoarray.abstract_ndarray import AbstractNDArray from autoarray.geometry.geometry_2d_irregular import Geometry2DIrregular from autoarray.mask.mask_2d import Mask2D From c4962c93f109c5f778267404f1d2ff33c4ca8f48 Mon Sep 17 00:00:00 2001 From: Kolen Cheung Date: Thu, 19 Dec 2024 16:53:44 +0000 Subject: [PATCH 028/108] fix jit on mask_2d_circular_from --- autoarray/mask/mask_2d_util.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 4103e89d9..f6399d828 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -81,7 +81,7 @@ def total_pixels_2d_from(mask_2d: np.ndarray) -> int: return total_regular_pixels -@numba_util.jit() +@numba_util.jit(static_argnums=0) def mask_2d_circular_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, @@ -114,24 +114,12 @@ def mask_2d_circular_from( mask = mask_circular_from( shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0)) """ - - mask_2d = np.full(shape_native, True) - - centres_scaled = mask_2d_centres_from( - shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre - ) - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - y_scaled = (y - centres_scaled[0]) * pixel_scales[0] - x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - - r_scaled = np.sqrt(x_scaled**2 + y_scaled**2) - - if r_scaled <= radius: - mask_2d[y, x] = False - - return mask_2d + centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) + ys, xs = np.mgrid[:shape_native[0], :shape_native[1]] + return (radius * radius) < ( + np.square((ys - centres_scaled[0]) * pixel_scales[0]) + + np.square((xs - centres_scaled[1]) * pixel_scales[1]) + ) @numba_util.jit() From ba31c1c66a35f6308bababdd5b944b93453641f2 Mon Sep 17 00:00:00 2001 From: Kolen Cheung Date: Thu, 19 Dec 2024 17:17:46 +0000 Subject: [PATCH 029/108] use indices instead of mgrid --- autoarray/mask/mask_2d_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index f6399d828..1981c317f 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -115,7 +115,7 @@ def mask_2d_circular_from( shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0)) """ centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - ys, xs = np.mgrid[:shape_native[0], :shape_native[1]] + ys, xs = np.indices(shape_native) return (radius * radius) < ( np.square((ys - centres_scaled[0]) * pixel_scales[0]) + np.square((xs - centres_scaled[1]) * pixel_scales[1]) From 64d16b604a6e730abd8009cfdd432983e973798e Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 13 Mar 2025 21:03:44 +0000 Subject: [PATCH 030/108] remove static args --- autoarray/mask/mask_2d_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 1981c317f..f79f7acff 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -81,7 +81,7 @@ def total_pixels_2d_from(mask_2d: np.ndarray) -> int: return total_regular_pixels -@numba_util.jit(static_argnums=0) +@numba_util.jit() def mask_2d_circular_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, From 7f869aa41b992c0e81fa6e8681cc5998a4e1dcb2 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 18:49:23 +0000 Subject: [PATCH 031/108] mask circular converts and some aspect simplified --- autoarray/geometry/geometry_util.py | 26 ++-- .../pixelization/border_relocator.py | 2 +- .../inversion/pixelization/mesh/mesh_util.py | 2 +- autoarray/mask/abstract_mask.py | 1 + autoarray/mask/mask_2d_util.py | 114 ++++++------------ autoarray/operators/contour.py | 4 +- .../over_sampling/over_sample_util.py | 4 +- .../operators/over_sampling/over_sampler.py | 1 + autoarray/plot/multi_plotters.py | 2 +- autoarray/structures/arrays/array_2d_util.py | 40 +++--- autoarray/structures/grids/grid_2d_util.py | 8 +- autoarray/structures/vectors/uniform.py | 5 +- test_autoarray/mask/test_mask_2d_util.py | 8 -- 13 files changed, 89 insertions(+), 128 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index bbe7bc601..a795d42ee 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -182,7 +182,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] @numba_util.jit() def central_pixel_coordinates_2d_from( - shape_native: Tuple[int, int] + shape_native: Tuple[int, int], ) -> Tuple[float, float]: """ Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) @@ -477,7 +477,6 @@ def grid_pixels_2d_slim_from( pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) @@ -544,7 +543,6 @@ def grid_pixel_centres_2d_slim_from( pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) @@ -629,8 +627,10 @@ def grid_pixel_indexes_2d_slim_from( if use_jax: grid_pixel_indexes_2d_slim = ( - grid_pixels_2d_slim * np.array([shape_native[1], 1]) - ).sum(axis=1).astype(int) + (grid_pixels_2d_slim * np.array([shape_native[1], 1])) + .sum(axis=1) + .astype(int) + ) else: grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) @@ -690,7 +690,9 @@ def grid_scaled_2d_slim_from( centres_scaled = np.array(centres_scaled) pixel_scales = np.array(pixel_scales) sign = np.array([-1, 1]) - grid_scaled_2d_slim = (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign + grid_scaled_2d_slim = ( + (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign + ) else: grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) @@ -755,7 +757,7 @@ def grid_pixel_centres_2d_from( centres_scaled = np.array(centres_scaled) pixel_scales = np.array(pixel_scales) sign = np.array([-1.0, 1.0]) - grid_pixels_2d = ( + grid_pixels_2d = ( (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 ).astype(int) else: @@ -764,17 +766,21 @@ def grid_pixel_centres_2d_from( for y in range(grid_scaled_2d.shape[0]): for x in range(grid_scaled_2d.shape[1]): grid_pixels_2d[y, x, 0] = int( - (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + centres_scaled[0] + 0.5 + (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + + centres_scaled[0] + + 0.5 ) grid_pixels_2d[y, x, 1] = int( - (grid_scaled_2d[y, x, 1] / pixel_scales[1]) + centres_scaled[1] + 0.5 + (grid_scaled_2d[y, x, 1] / pixel_scales[1]) + + centres_scaled[1] + + 0.5 ) return grid_pixels_2d def extent_symmetric_from( - extent: Tuple[float, float, float, float] + extent: Tuple[float, float, float, float], ) -> Tuple[float, float, float, float]: """ Given an input extent of the form (x_min, x_max, y_min, y_max), this function returns an extent which is diff --git a/autoarray/inversion/pixelization/border_relocator.py b/autoarray/inversion/pixelization/border_relocator.py index 73a8d0d28..737444b53 100644 --- a/autoarray/inversion/pixelization/border_relocator.py +++ b/autoarray/inversion/pixelization/border_relocator.py @@ -48,7 +48,7 @@ def sub_slim_indexes_for_slim_index_via_mask_2d_from( sub_mask_1d_indexes_for_mask_1d_index = sub_mask_1d_indexes_for_mask_1d_index_from(mask=mask, sub_size=2) """ - total_pixels = mask_2d_util.total_pixels_2d_from(mask_2d=mask_2d) + total_pixels = np.sum(~mask_2d) sub_slim_indexes_for_slim_index = [[] for _ in range(total_pixels)] diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index 78cb4a860..305b56b72 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -7,7 +7,7 @@ @numba_util.jit() def rectangular_neighbors_from( - shape_native: Tuple[int, int] + shape_native: Tuple[int, int], ) -> Tuple[np.ndarray, np.ndarray]: """ Returns the 4 (or less) adjacent neighbors of every pixel on a rectangular pixelization as an ndarray of shape diff --git a/autoarray/mask/abstract_mask.py b/autoarray/mask/abstract_mask.py index bc80d6a1b..5ffc0bf7f 100644 --- a/autoarray/mask/abstract_mask.py +++ b/autoarray/mask/abstract_mask.py @@ -4,6 +4,7 @@ import logging from autoarray.numpy_wrapper import np, use_jax + if use_jax: import jax from pathlib import Path diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index f79f7acff..6a628a5dc 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -8,118 +8,78 @@ from autoarray.numpy_wrapper import use_jax, np as jnp -@numba_util.jit() def mask_2d_centres_from( - shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, - centre: Tuple[float, float], -) -> Tuple[float, float]: + shape_native: tuple[int, int], + pixel_scales: tuple[float, float], + centre: tuple[float, float], +) -> tuple[float, float]: """ - Returns the (y,x) scaled central coordinates of a mask from its shape, pixel-scales and centre. + Compute the (y, x) scaled central coordinates of a mask given its shape, pixel-scales, and centre. - The coordinate system is defined such that the positive y axis is up and positive x axis is right. + The coordinate system is defined such that the positive y-axis is up and the positive x-axis is right. Parameters ---------- shape_native - The (y,x) shape of the 2D array the scaled centre is computed for. + The shape of the 2D array in pixels. pixel_scales - The (y,x) scaled units to pixel units conversion factor of the 2D array. - centre : (float, flloat) - The (y,x) centre of the 2D mask. - - Returns - ------- - tuple (float, float) - The (y,x) scaled central coordinates of the input array. - - Examples - -------- - centres_scaled = centres_from(shape=(5,5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0)) - """ - y_centre_scaled = (float(shape_native[0] - 1) / 2) - (centre[0] / pixel_scales[0]) - x_centre_scaled = (float(shape_native[1] - 1) / 2) + (centre[1] / pixel_scales[1]) - - return (y_centre_scaled, x_centre_scaled) - - -@numba_util.jit() -def total_pixels_2d_from(mask_2d: np.ndarray) -> int: - """ - Returns the total number of unmasked pixels in a mask. - - Parameters - ---------- - mask_2d - A 2D array of bools, where `False` values are unmasked and included when counting pixels. + The conversion factors from pixels to scaled units. + centre + The central coordinate of the mask in scaled units. Returns ------- - int - The total number of pixels that are unmasked. + The (y, x) scaled central coordinates of the input array. Examples -------- - - mask = np.array([[True, False, True], - [False, False, False] - [True, False, True]]) - - total_regular_pixels = total_regular_pixels_from(mask=mask) + centres_scaled = mask_2d_centres_from(shape_native=(5, 5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0)) """ - if use_jax: - return (~mask_2d.astype(bool)).sum() - - else: - total_regular_pixels = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - total_regular_pixels += 1 - - return total_regular_pixels + return ( + 0.5 * (shape_native[0] - 1) - (centre[0] / pixel_scales[0]), + 0.5 * (shape_native[1] - 1) + (centre[1] / pixel_scales[1]), + ) -@numba_util.jit() def mask_2d_circular_from( - shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, + shape_native: tuple[int, int], + pixel_scales: tuple[float, float], radius: float, - centre: Tuple[float, float] = (0.0, 0.0), + centre: tuple[float, float] = (0.0, 0.0), ) -> np.ndarray: """ - Returns a circular mask from the 2D mask array shape and radius of the circle. + Create a circular mask within a 2D array. - This creates a 2D array where all values within the mask radius are unmasked and therefore `False`. + This generates a 2D array where all values within the specified radius are unmasked (set to `False`). Parameters ---------- - shape_native: Tuple[int, int] - The (y,x) shape of the mask in units of pixels. + shape_native + The shape of the mask array in pixels. pixel_scales - The scaled units to pixel units conversion factor of each pixel. + The conversion factors from pixels to scaled units. radius - The radius (in scaled units) of the circle within which pixels unmasked. + The radius of the circular mask in scaled units. centre - The centre of the circle used to mask pixels. + The central coordinate of the circle in scaled units. Returns ------- - ndarray - The 2D mask array whose central pixels are masked as a circle. + The 2D mask array with the central region defined by the radius unmasked (False). Examples -------- - mask = mask_circular_from( - shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0)) + mask = mask_2d_circular_from(shape_native=(10, 10), pixel_scales=(0.1, 0.1), radius=0.5, centre=(0.0, 0.0)) """ centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - ys, xs = np.indices(shape_native) - return (radius * radius) < ( - np.square((ys - centres_scaled[0]) * pixel_scales[0]) + - np.square((xs - centres_scaled[1]) * pixel_scales[1]) - ) + + y, x = np.ogrid[: shape_native[0], : shape_native[1]] + y_scaled = (y - centres_scaled[0]) * pixel_scales[0] + x_scaled = (x - centres_scaled[1]) * pixel_scales[1] + + distances_squared = x_scaled**2 + y_scaled**2 + + return distances_squared >= radius**2 @numba_util.jit() @@ -1047,7 +1007,7 @@ def native_index_for_slim_index_2d_from( if use_jax: return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T else: - total_pixels = total_pixels_2d_from(mask_2d=mask_2d) + total_pixels = np.sum(~mask_2d) native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2)) slim_index = 0 diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index a85a73105..c7da5c7f1 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -58,7 +58,9 @@ def contour_array(self): @property def contour_list(self): # make sure to use base numpy to convert JAX array back to a normal array - contour_indices_list = measure.find_contours(numpy.array(self.contour_array.array), 0) + contour_indices_list = measure.find_contours( + numpy.array(self.contour_array.array), 0 + ) if len(contour_indices_list) == 0: return [] diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index b0df135a8..a98276896 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -528,9 +528,7 @@ def binned_array_2d_from( grid_slim = grid_2d_slim_over_sampled_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), sub_size=1, origin=(0.0, 0.0)) """ - total_pixels = mask_2d_util.total_pixels_2d_from( - mask_2d=mask_2d, - ) + total_pixels = np.sum(~mask_2d) sub_fraction = 1.0 / sub_size**2 diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index d7187a33d..6492c00b7 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -11,6 +11,7 @@ from autofit.jax_wrapper import register_pytree_node_class + @register_pytree_node_class class OverSampler: def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): diff --git a/autoarray/plot/multi_plotters.py b/autoarray/plot/multi_plotters.py index 5c2c5071e..a58d08c02 100644 --- a/autoarray/plot/multi_plotters.py +++ b/autoarray/plot/multi_plotters.py @@ -315,7 +315,7 @@ def output_to_fits( output_path = self.plotter_list[0].mat_plot_2d.output.output_path_from( format="fits_multi" ) - output_fits_file = Path(output_path)/ f"{filename}.fits" + output_fits_file = Path(output_path) / f"{filename}.fits" if remove_fits_first: output_fits_file.unlink(missing_ok=True) diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index f147baf41..b75dd859e 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -1,5 +1,6 @@ from __future__ import annotations from astropy.io import fits + # import numpy as np import os from pathlib import Path @@ -15,6 +16,7 @@ from autoarray import exc from autoarray.numpy_wrapper import use_jax, np, jit from functools import partial + if use_jax: import jax @@ -30,10 +32,7 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: """ if use_jax: array = jax.lax.cond( - type(array) is list, - lambda _: np.asarray(array), - lambda _: array, - None + type(array) is list, lambda _: np.asarray(array), lambda _: array, None ) elif type(array) is list: array = np.asarray(array) @@ -46,13 +45,11 @@ def exception_message(): raise exc.ArrayException( "An array input into the Array2D.__new__ method is not of shape 1." ) + cond = len(array_2d.shape) != 1 if use_jax: jax.lax.cond( - cond, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None + cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None ) elif cond: exception_message() @@ -74,6 +71,7 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): mask_2d The mask of the output Array2D. """ + def exception_message_1(): raise exc.ArrayException( f""" @@ -90,14 +88,17 @@ def exception_message_1(): Input mask_2d.shape_native = {mask_2d.shape_native} """ ) - cond_1 = (len(array_2d.shape) == 1) and (array_2d.shape[0] != mask_2d.pixels_in_mask) + + cond_1 = (len(array_2d.shape) == 1) and ( + array_2d.shape[0] != mask_2d.pixels_in_mask + ) if use_jax: jax.lax.cond( cond_1, lambda _: jax.debug.callback(exception_message_1), lambda _: None, - None + None, ) elif cond_1: exception_message_1() @@ -115,6 +116,7 @@ def exception_message_2(): Input mask_2d shape_native = {mask_2d.shape_native} """ ) + cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) if use_jax: @@ -122,7 +124,7 @@ def exception_message_2(): cond_2, lambda _: jax.debug.callback(exception_message_2), lambda _: None, - None + None, ) elif cond_2: exception_message_2() @@ -578,9 +580,7 @@ def array_2d_slim_from( if use_jax: array_2d_slim = array_2d_native[~mask_2d.astype(bool)] else: - total_pixels = mask_2d_util.total_pixels_2d_from( - mask_2d=mask_2d, - ) + total_pixels = np.sum(~mask_2d) array_2d_slim = np.zeros(shape=total_pixels) index = 0 @@ -681,9 +681,11 @@ def array_2d_via_indexes_from( The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ if use_jax: - array_native_2d = np.zeros(shape).at[ - tuple(native_index_for_slim_index_2d.T) - ].set(array_2d_slim) + array_native_2d = ( + np.zeros(shape) + .at[tuple(native_index_for_slim_index_2d.T)] + .set(array_2d_slim) + ) else: array_native_2d = np.zeros(shape) @@ -728,9 +730,7 @@ def array_2d_slim_complex_from( A 1D array of values mapped from the 2D array with dimensions (total_unmasked_pixels). """ - total_pixels = mask_2d_util.total_pixels_2d_from( - mask_2d=mask, - ) + total_pixels = np.sum(~mask_2d) array_1d = 0 + 0j * np.zeros(shape=total_pixels) index = 0 diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index c659a032b..44c75c7c5 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -273,7 +273,7 @@ def grid_2d_slim_via_mask_from( grid_slim = grid_2d_slim_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - total_pixels = mask_2d_util.total_pixels_2d_from(mask_2d) + total_pixels = np.sum(~mask_2d) centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin @@ -284,8 +284,10 @@ def grid_2d_slim_via_mask_from( pixel_scales = np.array(pixel_scales) sign = np.array([-1.0, 1.0]) grid_slim = ( - np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled - ) * sign * pixel_scales + (np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled) + * sign + * pixel_scales + ) else: index = 0 grid_slim = np.zeros(shape=(total_pixels, 2)) diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index 0a882926c..89d589139 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -1,4 +1,5 @@ import logging + # import numpy as np from autofit.jax_wrapper import numpy as np, use_jax from typing import List, Optional, Tuple, Union @@ -399,9 +400,7 @@ def magnitudes(self) -> Array2D: s = self.array else: s = self - return Array2D( - values=np.sqrt(s[:, 0] ** 2.0 + s[:, 1] ** 2.0), mask=self.mask - ) + return Array2D(values=np.sqrt(s[:, 0] ** 2.0 + s[:, 1] ** 2.0), mask=self.mask) @property def y(self) -> Array2D: diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index a3b05c894..ef2d3481a 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -5,14 +5,6 @@ import pytest -def test__total_pixels_2d_from(): - mask_2d = np.array( - [[True, False, True], [False, False, False], [True, False, True]] - ) - - assert util.mask_2d.total_pixels_2d_from(mask_2d=mask_2d) == 5 - - def test__total_edge_pixels_from_mask(): mask_2d = np.array( [ From 581aa75e76a752138c9b560fdf99660af8cbb878 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 18:58:59 +0000 Subject: [PATCH 032/108] mask_circular_annular_from converted --- autoarray/mask/mask_2d_util.py | 50 ++++++++++++++-------------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 6a628a5dc..3269f060f 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -82,60 +82,50 @@ def mask_2d_circular_from( return distances_squared >= radius**2 -@numba_util.jit() def mask_2d_circular_annular_from( - shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, + shape_native: tuple[int, int], + pixel_scales: tuple[float, float], inner_radius: float, outer_radius: float, - centre: Tuple[float, float] = (0.0, 0.0), + centre: tuple[float, float] = (0.0, 0.0), ) -> np.ndarray: """ - Returns an circular annular mask from an input inner and outer mask radius and shape. + Create a circular annular mask within a 2D array. - This creates a 2D array where all values within the inner and outer radii are unmasked and therefore `False`. + This generates a 2D array where all values within the specified inner and outer radii are unmasked (set to `False`). Parameters ---------- shape_native - The (y,x) shape of the mask in units of pixels. + The shape of the mask array in pixels. pixel_scales - The scaled units to pixel units conversion factor of each pixel. + The conversion factors from pixels to scaled units. inner_radius - The radius (in scaled units) of the inner circle outside of which pixels are unmasked. + The inner radius of the annular mask in scaled units. outer_radius - The radius (in scaled units) of the outer circle within which pixels are unmasked. + The outer radius of the annular mask in scaled units. centre - The centre of the annulus used to mask pixels. + The central coordinate of the annulus in scaled units. Returns ------- - ndarray - The 2D mask array whose central pixels are masked as a annulus. + The 2D mask array with the region between the inner and outer radii unmasked (False). Examples -------- - mask = mask_annnular_from( - shape=(10, 10), pixel_scales=0.1, inner_radius=0.5, outer_radius=1.5, centre=(0.0, 0.0)) - """ - - mask_2d = np.full(shape_native, True) - - centres_scaled = mask_2d_centres_from( - shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre + mask = mask_2d_circular_annular_from( + shape_native=(10, 10), pixel_scales=(0.1, 0.1), inner_radius=0.5, outer_radius=1.5, centre=(0.0, 0.0) ) + """ + centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - y_scaled = (y - centres_scaled[0]) * pixel_scales[0] - x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - - r_scaled = np.sqrt(x_scaled**2 + y_scaled**2) + y, x = np.ogrid[:shape_native[0], :shape_native[1]] + y_scaled = (y - centres_scaled[0]) * pixel_scales[0] + x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - if outer_radius >= r_scaled >= inner_radius: - mask_2d[y, x] = False + distances_squared = x_scaled**2 + y_scaled**2 - return mask_2d + return ~((distances_squared >= inner_radius**2) & (distances_squared <= outer_radius**2)) @numba_util.jit() From ed60fdac85df548baa8eca0ec36e817188610a64 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 19:00:04 +0000 Subject: [PATCH 033/108] update typing --- autoarray/mask/mask_2d_util.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 3269f060f..102ff771b 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -9,10 +9,10 @@ def mask_2d_centres_from( - shape_native: tuple[int, int], - pixel_scales: tuple[float, float], - centre: tuple[float, float], -) -> tuple[float, float]: + shape_native: Tuple[int, int], + pixel_scales: Tuple[float, float], + centre: Tuple[float, float], +) -> Tuple[float, float]: """ Compute the (y, x) scaled central coordinates of a mask given its shape, pixel-scales, and centre. @@ -42,10 +42,10 @@ def mask_2d_centres_from( def mask_2d_circular_from( - shape_native: tuple[int, int], - pixel_scales: tuple[float, float], + shape_native: Tuple[int, int], + pixel_scales: Tuple[float, float], radius: float, - centre: tuple[float, float] = (0.0, 0.0), + centre: Tuple[float, float] = (0.0, 0.0), ) -> np.ndarray: """ Create a circular mask within a 2D array. @@ -83,11 +83,11 @@ def mask_2d_circular_from( def mask_2d_circular_annular_from( - shape_native: tuple[int, int], - pixel_scales: tuple[float, float], + shape_native: Tuple[int, int], + pixel_scales: Tuple[float, float], inner_radius: float, outer_radius: float, - centre: tuple[float, float] = (0.0, 0.0), + centre: Tuple[float, float] = (0.0, 0.0), ) -> np.ndarray: """ Create a circular annular mask within a 2D array. From 14609842c65eb7713912a518e2042e440a54840c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 19:01:52 +0000 Subject: [PATCH 034/108] remove anti annular --- autoarray/mask/mask_2d.py | 59 --------------- autoarray/mask/mask_2d_util.py | 65 ----------------- test_autoarray/mask/test_mask_2d.py | 39 ---------- test_autoarray/mask/test_mask_2d_util.py | 91 ------------------------ 4 files changed, 254 deletions(-) diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 18a4c8fea..9cecf4b24 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -380,65 +380,6 @@ def circular_annular( invert=invert, ) - @classmethod - def circular_anti_annular( - cls, - shape_native: Tuple[int, int], - inner_radius: float, - outer_radius: float, - outer_radius_2: float, - pixel_scales: ty.PixelScales, - origin: Tuple[float, float] = (0.0, 0.0), - centre: Tuple[float, float] = (0.0, 0.0), - invert: bool = False, - ) -> "Mask2D": - """ - Returns a Mask2D (see *Mask2D.__new__*) where all `False` entries are within an inner circle and second - outer circle, forming an inverse annulus. - - The `inner_radius`, `outer_radius`, `outer_radius_2` and `centre` are all input in scaled units. - - Parameters - ---------- - shape_native - The (y,x) shape of the mask in units of pixels. - inner_radius - The inner radius in scaled units of the annulus within which pixels are `False` and unmasked. - outer_radius - The first outer radius in scaled units of the annulus within which pixels are `True` and masked. - outer_radius_2 - The second outer radius in scaled units of the annulus within which pixels are `False` and unmasked and - outside of which all entries are `True` and masked. - pixel_scales - The (y,x) scaled units to pixel units conversion factors of every pixel. If this is input as a `float`, - it is converted to a (float, float) structure. - origin - The (y,x) scaled units origin of the mask's coordinate system. - centre - The (y,x) scaled units centre of the anti-annulus used to mask pixels. - invert - If `True`, the `bool`'s of the input `mask` are inverted, for example `False`'s become `True` - and visa versa. - """ - - pixel_scales = geometry_util.convert_pixel_scales_2d(pixel_scales=pixel_scales) - - mask = mask_2d_util.mask_2d_circular_anti_annular_from( - shape_native=shape_native, - pixel_scales=pixel_scales, - inner_radius=inner_radius, - outer_radius=outer_radius, - outer_radius_2_scaled=outer_radius_2, - centre=centre, - ) - - return cls( - mask=mask, - pixel_scales=pixel_scales, - origin=origin, - invert=invert, - ) - @classmethod def elliptical( cls, diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 102ff771b..c8c0bf62f 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -128,71 +128,6 @@ def mask_2d_circular_annular_from( return ~((distances_squared >= inner_radius**2) & (distances_squared <= outer_radius**2)) -@numba_util.jit() -def mask_2d_circular_anti_annular_from( - shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, - inner_radius: float, - outer_radius: float, - outer_radius_2_scaled: float, - centre: Tuple[float, float] = (0.0, 0.0), -) -> np.ndarray: - """ - Returns an anti-annular mask from an input inner and outer mask radius and shape. The anti-annular is analogous to - the annular mask but inverted, whereby its unmasked values are those inside the annulus. - - This creates a 2D array where all values outside the inner and outer radii are unmasked and therefore `False`. - - Parameters - ---------- - shape_native - The (y,x) shape of the mask in units of pixels. - pixel_scales - The scaled units to pixel units conversion factor of each pixel. - inner_radius - The inner radius in scaled units of the annulus within which pixels are `False` and unmasked. - outer_radius - The first outer radius in scaled units of the annulus within which pixels are `True` and masked. - outer_radius_2 - The second outer radius in scaled units of the annulus within which pixels are `False` and unmasked and - outside of which all entries are `True` and masked. - centre - The centre of the annulus used to mask pixels. - - Returns - ------- - ndarray - The 2D mask array whose central pixels are masked as a annulus. - - Examples - -------- - mask = mask_annnular_from( - shape=(10, 10), pixel_scales=0.1, inner_radius=0.5, outer_radius=1.5, centre=(0.0, 0.0)) - - """ - - mask_2d = np.full(shape_native, True) - - centres_scaled = mask_2d_centres_from( - shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre - ) - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - y_scaled = (y - centres_scaled[0]) * pixel_scales[0] - x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - - r_scaled = np.sqrt(x_scaled**2 + y_scaled**2) - - if ( - inner_radius >= r_scaled - or outer_radius_2_scaled >= r_scaled >= outer_radius - ): - mask_2d[y, x] = False - - return mask_2d - - def mask_2d_via_pixel_coordinates_from( shape_native: Tuple[int, int], pixel_coordinates: [list], buffer: int = 0 ) -> np.ndarray: diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index 3b80030e6..2ad05a0a7 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -141,45 +141,6 @@ def test__circular_annular(): assert mask.origin == (0.0, 0.0) assert mask.mask_centre == (0.0, 0.0) - -def test__circular_anti_annular(): - mask_via_util = aa.util.mask_2d.mask_2d_circular_anti_annular_from( - shape_native=(9, 9), - pixel_scales=(1.2, 1.2), - inner_radius=0.8, - outer_radius=2.2, - outer_radius_2_scaled=3.0, - centre=(0.0, 0.0), - ) - - mask = aa.Mask2D.circular_anti_annular( - shape_native=(9, 9), - pixel_scales=(1.2, 1.2), - inner_radius=0.8, - outer_radius=2.2, - outer_radius_2=3.0, - centre=(0.0, 0.0), - ) - - assert (mask == mask_via_util).all() - assert mask.origin == (0.0, 0.0) - assert mask.mask_centre == (0.0, 0.0) - - mask = aa.Mask2D.circular_anti_annular( - shape_native=(9, 9), - pixel_scales=(1.2, 1.2), - inner_radius=0.8, - outer_radius=2.2, - outer_radius_2=3.0, - centre=(0.0, 0.0), - invert=True, - ) - - assert (mask == np.invert(mask_via_util)).all() - assert mask.origin == (0.0, 0.0) - assert mask.mask_centre == (0.0, 0.0) - - def test__elliptical(): mask_via_util = aa.util.mask_2d.mask_2d_elliptical_from( shape_native=(8, 5), diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index ef2d3481a..344f8a1c1 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -234,97 +234,6 @@ def test__mask_2d_circular_annular_from__input_centre(): ).all() -def test__mask_2d_circular_anti_annular_from(): - mask = util.mask_2d.mask_2d_circular_anti_annular_from( - shape_native=(5, 5), - pixel_scales=(1.0, 1.0), - inner_radius=0.5, - outer_radius=10.0, - outer_radius_2_scaled=20.0, - ) - - assert ( - mask - == np.array( - [ - [True, True, True, True, True], - [True, True, True, True, True], - [True, True, False, True, True], - [True, True, True, True, True], - [True, True, True, True, True], - ] - ) - ).all() - - mask = util.mask_2d.mask_2d_circular_anti_annular_from( - shape_native=(5, 5), - pixel_scales=(0.1, 1.0), - inner_radius=1.5, - outer_radius=10.0, - outer_radius_2_scaled=20.0, - ) - - assert ( - mask - == np.array( - [ - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - ] - ) - ).all() - - mask = util.mask_2d.mask_2d_circular_anti_annular_from( - shape_native=(5, 5), - pixel_scales=(1.0, 1.0), - inner_radius=0.5, - outer_radius=1.5, - outer_radius_2_scaled=20.0, - ) - - assert ( - mask - == np.array( - [ - [False, False, False, False, False], - [False, True, True, True, False], - [False, True, False, True, False], - [False, True, True, True, False], - [False, False, False, False, False], - ] - ) - ).all() - - -def test__mask_2d_circular_anti_annular_from__include_centre(): - mask = util.mask_2d.mask_2d_circular_anti_annular_from( - shape_native=(7, 7), - pixel_scales=(3.0, 3.0), - inner_radius=1.5, - outer_radius=4.5, - outer_radius_2_scaled=8.7, - centre=(-3.0, 3.0), - ) - - assert ( - mask - == np.array( - [ - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, False, False, False, False, False], - [True, True, False, True, True, True, False], - [True, True, False, True, False, True, False], - [True, True, False, True, True, True, False], - [True, True, False, False, False, False, False], - ] - ) - ).all() - - def test__mask_2d_elliptical_from(): mask = util.mask_2d.mask_2d_elliptical_from( shape_native=(3, 3), From c4fe49f2fef9f71afa55c0742ecfa053d95a1ed8 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 19:04:01 +0000 Subject: [PATCH 035/108] move from pixel coordinates --- autoarray/mask/mask_2d_util.py | 68 +++++++++++++++-------------- test_autoarray/mask/test_mask_2d.py | 1 + 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index c8c0bf62f..1f56b96b1 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -119,44 +119,15 @@ def mask_2d_circular_annular_from( """ centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - y, x = np.ogrid[:shape_native[0], :shape_native[1]] + y, x = np.ogrid[: shape_native[0], : shape_native[1]] y_scaled = (y - centres_scaled[0]) * pixel_scales[0] x_scaled = (x - centres_scaled[1]) * pixel_scales[1] distances_squared = x_scaled**2 + y_scaled**2 - return ~((distances_squared >= inner_radius**2) & (distances_squared <= outer_radius**2)) - - -def mask_2d_via_pixel_coordinates_from( - shape_native: Tuple[int, int], pixel_coordinates: [list], buffer: int = 0 -) -> np.ndarray: - """ - Returns a mask where all unmasked `False` entries are defined from an input list of list of pixel coordinates. - - These may be buffed via an input ``buffer``, whereby all entries in all 8 neighboring directions by this - amount. - - Parameters - ---------- - shape_native (int, int) - The (y,x) shape of the mask in units of pixels. - pixel_coordinates : [[int, int]] - The input lists of 2D pixel coordinates where `False` entries are created. - buffer - All input ``pixel_coordinates`` are buffed with `False` entries in all 8 neighboring directions by this - amount. - """ - - mask_2d = np.full(shape=shape_native, fill_value=True) - - for y, x in pixel_coordinates: - mask_2d[y, x] = False - - if buffer == 0: - return mask_2d - else: - return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) + return ~( + (distances_squared >= inner_radius**2) & (distances_squared <= outer_radius**2) + ) @numba_util.jit() @@ -342,6 +313,37 @@ def mask_2d_elliptical_annular_from( return mask_2d +def mask_2d_via_pixel_coordinates_from( + shape_native: Tuple[int, int], pixel_coordinates: [list], buffer: int = 0 +) -> np.ndarray: + """ + Returns a mask where all unmasked `False` entries are defined from an input list of list of pixel coordinates. + + These may be buffed via an input ``buffer``, whereby all entries in all 8 neighboring directions by this + amount. + + Parameters + ---------- + shape_native (int, int) + The (y,x) shape of the mask in units of pixels. + pixel_coordinates : [[int, int]] + The input lists of 2D pixel coordinates where `False` entries are created. + buffer + All input ``pixel_coordinates`` are buffed with `False` entries in all 8 neighboring directions by this + amount. + """ + + mask_2d = np.full(shape=shape_native, fill_value=True) + + for y, x in pixel_coordinates: + mask_2d[y, x] = False + + if buffer == 0: + return mask_2d + else: + return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) + + @numba_util.jit() def blurring_mask_2d_from( mask_2d: np.ndarray, kernel_shape_native: Tuple[int, int] diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index 2ad05a0a7..eafac1efc 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -141,6 +141,7 @@ def test__circular_annular(): assert mask.origin == (0.0, 0.0) assert mask.mask_centre == (0.0, 0.0) + def test__elliptical(): mask_via_util = aa.util.mask_2d.mask_2d_elliptical_from( shape_native=(8, 5), From a5544b4ff9c93892a9bd98d466381fab420028ff Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 19:14:14 +0000 Subject: [PATCH 036/108] mask_2d_elliptical_From --- autoarray/mask/mask_2d_util.py | 60 ++++++++++++++++------------------ 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 1f56b96b1..a73e6c909 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -167,67 +167,65 @@ def elliptical_radius_from( return np.sqrt(x_scaled_elliptical**2.0 + (y_scaled_elliptical / axis_ratio) ** 2.0) -@numba_util.jit() + def mask_2d_elliptical_from( shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, + pixel_scales: Tuple[float, float], major_axis_radius: float, axis_ratio: float, angle: float, centre: Tuple[float, float] = (0.0, 0.0), ) -> np.ndarray: """ - Returns an elliptical mask from an input major-axis mask radius, axis-ratio, rotational angle, shape and - centre. + Create an elliptical mask within a 2D array. - This creates a 2D array where all values within the ellipse are unmasked and therefore `False`. + This generates a 2D array where all values within the specified ellipse are unmasked (set to `False`). Parameters ---------- - shape_native: Tuple[int, int] - The (y,x) shape of the mask in units of pixels. + shape_native + The shape of the mask array in pixels. pixel_scales - The scaled units to pixel units conversion factor of each pixel. + The conversion factors from pixels to scaled units. major_axis_radius - The major-axis (in scaled units) of the ellipse within which pixels are unmasked. + The major axis radius of the elliptical mask in scaled units. axis_ratio - The axis-ratio of the ellipse within which pixels are unmasked. + The axis ratio of the ellipse (minor axis / major axis). angle - The rotation angle of the ellipse within which pixels are unmasked, (counter-clockwise from the positive - x-axis). + The rotation angle of the ellipse in degrees, counter-clockwise from the positive x-axis. centre - The centre of the ellipse used to mask pixels. + The central coordinate of the ellipse in scaled units. Returns ------- - ndarray - The 2D mask array whose central pixels are masked as an ellipse. + np.ndarray + The 2D mask array with the elliptical region defined by the major axis radius unmasked (False). Examples -------- - mask = mask_elliptical_from( - shape=(10, 10), pixel_scales=0.1, major_axis_radius=0.5, ell_comps=(0.333333, 0.0), centre=(0.0, 0.0)) + mask = mask_2d_elliptical_from( + shape_native=(10, 10), pixel_scales=(0.1, 0.1), major_axis_radius=0.5, axis_ratio=0.5, angle=45.0, centre=(0.0, 0.0) + ) """ + centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - mask_2d = np.full(shape_native, True) + y, x = np.ogrid[:shape_native[0], :shape_native[1]] + y_scaled = (y - centres_scaled[0]) * pixel_scales[0] + x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - centres_scaled = mask_2d_centres_from( - shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre - ) + # Rotate the coordinates by the angle (counterclockwise) - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - y_scaled = (y - centres_scaled[0]) * pixel_scales[0] - x_scaled = (x - centres_scaled[1]) * pixel_scales[1] + r_scaled = np.sqrt(x_scaled**2 + y_scaled**2) - r_scaled_elliptical = elliptical_radius_from( - y_scaled, x_scaled, angle, axis_ratio - ) + theta_rotated = np.arctan2(y_scaled, x_scaled) + np.radians(angle) - if r_scaled_elliptical <= major_axis_radius: - mask_2d[y, x] = False + y_scaled_elliptical = r_scaled * np.sin(theta_rotated) + x_scaled_elliptical = r_scaled * np.cos(theta_rotated) - return mask_2d + # Compute the elliptical radius + r_scaled_elliptical = np.sqrt(x_scaled_elliptical**2 + (y_scaled_elliptical / axis_ratio)**2) + + return ~(r_scaled_elliptical <= major_axis_radius) @numba_util.jit() From 36b6d3509a1ac326ea94644f21ef97944d91ee7f Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 19:16:52 +0000 Subject: [PATCH 037/108] mask_2d_elliptical_annular_from --- autoarray/mask/mask_2d_util.py | 104 ++++++++++++--------------------- 1 file changed, 37 insertions(+), 67 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index a73e6c909..732a13db1 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -130,44 +130,6 @@ def mask_2d_circular_annular_from( ) -@numba_util.jit() -def elliptical_radius_from( - y_scaled: float, x_scaled: float, angle: float, axis_ratio: float -) -> float: - """ - Returns the elliptical radius of an ellipse from its (y,x) scaled centre, rotation angle `angle` defined in degrees - counter-clockwise from the positive x-axis and its axis-ratio. - - This is used by the function `mask_elliptical_from` to determine the radius of every (y,x) coordinate in elliptical - units when deciding if it is within the mask. - - Parameters - ---------- - y_scaled - The scaled y coordinate in Cartesian coordinates which is converted to elliptical coordinates. - x_scaled - The scaled x coordinate in Cartesian coordinates which is converted to elliptical coordinates. - angle - The rotation angle in degrees counter-clockwise from the positive x-axis - axis_ratio - The axis-ratio of the ellipse (minor axis / major axis). - - Returns - ------- - float - The radius of the input scaled (y,x) coordinate on the ellipse's ellipitcal coordinate system. - """ - r_scaled = np.sqrt(x_scaled**2 + y_scaled**2) - - theta_rotated = np.arctan2(y_scaled, x_scaled) + np.radians(angle) - - y_scaled_elliptical = r_scaled * np.sin(theta_rotated) - x_scaled_elliptical = r_scaled * np.cos(theta_rotated) - - return np.sqrt(x_scaled_elliptical**2.0 + (y_scaled_elliptical / axis_ratio) ** 2.0) - - - def mask_2d_elliptical_from( shape_native: Tuple[int, int], pixel_scales: Tuple[float, float], @@ -209,7 +171,7 @@ def mask_2d_elliptical_from( """ centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - y, x = np.ogrid[:shape_native[0], :shape_native[1]] + y, x = np.ogrid[: shape_native[0], : shape_native[1]] y_scaled = (y - centres_scaled[0]) * pixel_scales[0] x_scaled = (x - centres_scaled[1]) * pixel_scales[1] @@ -223,15 +185,16 @@ def mask_2d_elliptical_from( x_scaled_elliptical = r_scaled * np.cos(theta_rotated) # Compute the elliptical radius - r_scaled_elliptical = np.sqrt(x_scaled_elliptical**2 + (y_scaled_elliptical / axis_ratio)**2) + r_scaled_elliptical = np.sqrt( + x_scaled_elliptical**2 + (y_scaled_elliptical / axis_ratio) ** 2 + ) return ~(r_scaled_elliptical <= major_axis_radius) -@numba_util.jit() def mask_2d_elliptical_annular_from( shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, + pixel_scales: Tuple[float, float], inner_major_axis_radius: float, inner_axis_ratio: float, inner_phi: float, @@ -277,38 +240,45 @@ def mask_2d_elliptical_annular_from( Examples -------- mask = mask_elliptical_annuli_from( - shape=(10, 10), pixel_scales=0.1, - inner_major_axis_radius=0.5, inner_axis_ratio=0.5, inner_phi=45.0, - outer_major_axis_radius=1.5, outer_axis_ratio=0.8, outer_phi=90.0, - centre=(0.0, 0.0)) + shape=(10, 10), pixel_scales=(0.1, 0.1), + inner_major_axis_radius=0.5, inner_axis_ratio=0.5, inner_phi=45.0, + outer_major_axis_radius=1.5, outer_axis_ratio=0.8, outer_phi=90.0, + centre=(0.0, 0.0)) """ + centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) - mask_2d = np.full(shape_native, True) - - centres_scaled = mask_2d_centres_from( - shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre - ) + y, x = np.ogrid[: shape_native[0], : shape_native[1]] + y_scaled = (y - centres_scaled[0]) * pixel_scales[0] + x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - y_scaled = (y - centres_scaled[0]) * pixel_scales[0] - x_scaled = (x - centres_scaled[1]) * pixel_scales[1] + # Rotate the coordinates for the inner annulus + r_scaled_inner = np.sqrt(x_scaled**2 + y_scaled**2) + theta_rotated_inner = np.arctan2(y_scaled, x_scaled) + np.radians(inner_phi) + y_scaled_elliptical_inner = r_scaled_inner * np.sin(theta_rotated_inner) + x_scaled_elliptical_inner = r_scaled_inner * np.cos(theta_rotated_inner) - inner_r_scaled_elliptical = elliptical_radius_from( - y_scaled, x_scaled, inner_phi, inner_axis_ratio - ) + # Compute the elliptical radius for the inner annulus + r_scaled_elliptical_inner = np.sqrt( + x_scaled_elliptical_inner**2 + + (y_scaled_elliptical_inner / inner_axis_ratio) ** 2 + ) - outer_r_scaled_elliptical = elliptical_radius_from( - y_scaled, x_scaled, outer_phi, outer_axis_ratio - ) + # Rotate the coordinates for the outer annulus + r_scaled_outer = np.sqrt(x_scaled**2 + y_scaled**2) + theta_rotated_outer = np.arctan2(y_scaled, x_scaled) + np.radians(outer_phi) + y_scaled_elliptical_outer = r_scaled_outer * np.sin(theta_rotated_outer) + x_scaled_elliptical_outer = r_scaled_outer * np.cos(theta_rotated_outer) - if ( - inner_r_scaled_elliptical >= inner_major_axis_radius - and outer_r_scaled_elliptical <= outer_major_axis_radius - ): - mask_2d[y, x] = False + # Compute the elliptical radius for the outer annulus + r_scaled_elliptical_outer = np.sqrt( + x_scaled_elliptical_outer**2 + + (y_scaled_elliptical_outer / outer_axis_ratio) ** 2 + ) - return mask_2d + return ~( + (r_scaled_elliptical_inner >= inner_major_axis_radius) + & (r_scaled_elliptical_outer <= outer_major_axis_radius) + ) def mask_2d_via_pixel_coordinates_from( From ec8ca607beb5d966d4bb9c04794e994ed79f4ee4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 19:18:43 +0000 Subject: [PATCH 038/108] simplify tests to not include centre --- test_autoarray/mask/test_mask_2d_util.py | 166 ++--------------------- 1 file changed, 14 insertions(+), 152 deletions(-) diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index 344f8a1c1..befc358be 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -5,20 +5,6 @@ import pytest -def test__total_edge_pixels_from_mask(): - mask_2d = np.array( - [ - [True, True, True, True, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, True, True, True, True], - ] - ) - - assert util.mask_2d.total_edge_pixels_from(mask_2d=mask_2d) == 8 - - def test__mask_2d_circular_from(): mask = util.mask_2d.mask_2d_circular_from( shape_native=(3, 3), pixel_scales=(1.0, 1.0), radius=0.5 @@ -71,26 +57,6 @@ def test__mask_2d_circular_from(): ) ).all() - -def test__mask_2d_circular_from__input_centre(): - mask = util.mask_2d.mask_2d_circular_from( - shape_native=(3, 3), pixel_scales=(3.0, 3.0), radius=0.5, centre=(-3, 0) - ) - - assert mask.shape == (3, 3) - assert ( - mask == np.array([[True, True, True], [True, True, True], [True, False, True]]) - ).all() - - mask = util.mask_2d.mask_2d_circular_from( - shape_native=(3, 3), pixel_scales=(3.0, 3.0), radius=0.5, centre=(0.0, 3.0) - ) - - assert mask.shape == (3, 3) - assert ( - mask == np.array([[True, True, True], [True, True, False], [True, True, True]]) - ).all() - mask = util.mask_2d.mask_2d_circular_from( shape_native=(3, 3), pixel_scales=(3.0, 3.0), radius=0.5, centre=(3, 3) ) @@ -183,40 +149,6 @@ def test__mask_2d_circular_annular_from(): ) ).all() - -def test__mask_2d_circular_annular_from__input_centre(): - mask = util.mask_2d.mask_2d_circular_annular_from( - shape_native=(3, 3), - pixel_scales=(3.0, 3.0), - inner_radius=0.5, - outer_radius=9.0, - centre=(3.0, 0.0), - ) - - assert mask.shape == (3, 3) - assert ( - mask - == np.array( - [[False, True, False], [False, False, False], [False, False, False]] - ) - ).all() - - mask = util.mask_2d.mask_2d_circular_annular_from( - shape_native=(3, 3), - pixel_scales=(3.0, 3.0), - inner_radius=0.5, - outer_radius=9.0, - centre=(0.0, 3.0), - ) - - assert mask.shape == (3, 3) - assert ( - mask - == np.array( - [[False, False, False], [False, False, True], [False, False, False]] - ) - ).all() - mask = util.mask_2d.mask_2d_circular_annular_from( shape_native=(3, 3), pixel_scales=(3.0, 3.0), @@ -333,34 +265,6 @@ def test__mask_2d_elliptical_from(): ) ).all() - -def test__mask_2d_elliptical_from__include_centre(): - mask = util.mask_2d.mask_2d_elliptical_from( - shape_native=(3, 3), - pixel_scales=(3.0, 3.0), - major_axis_radius=4.8, - axis_ratio=0.1, - angle=45.0, - centre=(-3.0, 0.0), - ) - - assert ( - mask == np.array([[True, True, True], [True, True, False], [True, False, True]]) - ).all() - - mask = util.mask_2d.mask_2d_elliptical_from( - shape_native=(3, 3), - pixel_scales=(3.0, 3.0), - major_axis_radius=4.8, - axis_ratio=0.1, - angle=45.0, - centre=(0.0, 3.0), - ) - - assert ( - mask == np.array([[True, True, True], [True, True, False], [True, False, True]]) - ).all() - mask = util.mask_2d.mask_2d_elliptical_from( shape_native=(3, 3), pixel_scales=(3.0, 3.0), @@ -551,62 +455,6 @@ def test__mask_2d_elliptical_annular_from(): ) ).all() - -def test__mask_2d_elliptical_annular_from__include_centre(): - mask = util.mask_2d.mask_2d_elliptical_annular_from( - shape_native=(7, 5), - pixel_scales=(1.0, 1.0), - inner_major_axis_radius=1.0, - inner_axis_ratio=0.1, - inner_phi=0.0, - outer_major_axis_radius=2.0, - outer_axis_ratio=0.1, - outer_phi=90.0, - centre=(-1.0, 0.0), - ) - - assert ( - mask - == np.array( - [ - [True, True, True, True, True], - [True, True, True, True, True], - [True, True, False, True, True], - [True, True, False, True, True], - [True, True, True, True, True], - [True, True, False, True, True], - [True, True, False, True, True], - ] - ) - ).all() - - mask = util.mask_2d.mask_2d_elliptical_annular_from( - shape_native=(7, 5), - pixel_scales=(1.0, 1.0), - inner_major_axis_radius=1.0, - inner_axis_ratio=0.1, - inner_phi=0.0, - outer_major_axis_radius=2.0, - outer_axis_ratio=0.1, - outer_phi=90.0, - centre=(0.0, 1.0), - ) - - assert ( - mask - == np.array( - [ - [True, True, True, True, True], - [True, True, True, False, True], - [True, True, True, False, True], - [True, True, True, True, True], - [True, True, True, False, True], - [True, True, True, False, True], - [True, True, True, True, True], - ] - ) - ).all() - mask = util.mask_2d.mask_2d_elliptical_annular_from( shape_native=(7, 5), pixel_scales=(1.0, 1.0), @@ -899,6 +747,20 @@ def test__mask_1d_indexes_from(): assert masked_slim[-1] == 48 +def test__total_edge_pixels_from_mask(): + mask_2d = np.array( + [ + [True, True, True, True, True], + [True, False, False, False, True], + [True, False, False, False, True], + [True, False, False, False, True], + [True, True, True, True, True], + ] + ) + + assert util.mask_2d.total_edge_pixels_from(mask_2d=mask_2d) == 8 + + def test__edge_1d_indexes_from(): mask = np.array( [ From 943ff407f32ecf12ddb01eb4ba19f29bc2bcbfce Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sat, 15 Mar 2025 20:10:31 +0000 Subject: [PATCH 039/108] blurring_mask_2d_from --- autoarray/mask/mask_2d_util.py | 42 +++++++++--------------- test_autoarray/mask/test_mask_2d_util.py | 14 -------- 2 files changed, 16 insertions(+), 40 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 732a13db1..9b0157309 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -312,7 +312,9 @@ def mask_2d_via_pixel_coordinates_from( return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) -@numba_util.jit() +from scipy.ndimage import convolve + + def blurring_mask_2d_from( mask_2d: np.ndarray, kernel_shape_native: Tuple[int, int] ) -> np.ndarray: @@ -348,32 +350,20 @@ def blurring_mask_2d_from( """ - blurring_mask_2d = np.full(mask_2d.shape, True) + # Create a (3, 3) kernel of ones + kernel = np.ones(kernel_shape_native, dtype=np.uint8) - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - for y1 in range( - (-kernel_shape_native[0] + 1) // 2, - (kernel_shape_native[0] + 1) // 2, - ): - for x1 in range( - (-kernel_shape_native[1] + 1) // 2, - (kernel_shape_native[1] + 1) // 2, - ): - if ( - 0 <= x + x1 <= mask_2d.shape[1] - 1 - and 0 <= y + y1 <= mask_2d.shape[0] - 1 - ): - if mask_2d[y + y1, x + x1]: - blurring_mask_2d[y + y1, x + x1] = False - else: - raise exc.MaskException( - "setup_blurring_mask extends beyond the edge " - "of the mask - pad the datas array before masking" - ) - - return blurring_mask_2d + # Convolve the mask with the kernel, applying logical AND to maintain 'True' regions + convolved_mask = convolve(mask_2d.astype(np.uint8), kernel, mode="reflect", cval=0) + + # We want to return the mask where the convolved value is the full kernel size (i.e., 9 for a 3x3 kernel) + result_mask = convolved_mask == np.prod(kernel_shape_native) + + blurring_mask = ~mask_2d + result_mask + + print(blurring_mask * convolved_mask) + + return blurring_mask @numba_util.jit() diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index befc358be..9e1718abd 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -522,20 +522,6 @@ def test__blurring_mask_2d_from(): ) ).all() - mask = np.array( - [ - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, True, False, True, True, True], - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - ] - ) - - blurring_mask = util.mask_2d.blurring_mask_2d_from(mask, kernel_shape_native=(3, 3)) - mask = np.array( [ [True, True, True, True, True, True, True], From 0290b84919e5240e049c045d167f8b29572c52e6 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 30 Mar 2025 15:32:27 +0100 Subject: [PATCH 040/108] check on blurring mask now works --- autoarray/mask/mask_2d_util.py | 70 ++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 9b0157309..f78e9b852 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -1,4 +1,5 @@ import numpy as np +from scipy.ndimage import convolve from typing import Tuple import warnings @@ -308,12 +309,58 @@ def mask_2d_via_pixel_coordinates_from( if buffer == 0: return mask_2d - else: - return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) + return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) -from scipy.ndimage import convolve +import numpy as np + + +def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]: + """ + Compute the minimum 1D distance in the y and x directions from any False value at the mask's extreme positions + (leftmost, rightmost, topmost, bottommost) to its closest edge. + Parameters + ---------- + mask + A 2D boolean array where False represents the unmasked region. + + Returns + ------- + The smallest distances of any extreme False value to the nearest edge in the vertical (y) and horizontal (x) + directions. + + Examples + -------- + >>> mask = np.array([ + ... [ True, True, True, True], + ... [ True, False, False, True], + ... [ True, False, True, True], + ... [ True, True, True, True] + ... ]) + >>> min_false_distance_to_edge(mask) + (1, 1) + """ + false_indices = np.column_stack(np.where(mask == False)) + + if false_indices.size == 0: + raise ValueError("No False values found in the mask.") + + leftmost = false_indices[np.argmin(false_indices[:, 1])] + rightmost = false_indices[np.argmax(false_indices[:, 1])] + topmost = false_indices[np.argmin(false_indices[:, 0])] + bottommost = false_indices[np.argmax(false_indices[:, 0])] + + height, width = mask.shape + + # Compute distances to respective edges + left_dist = leftmost[1] # Distance to left edge (column index) + right_dist = width - 1 - rightmost[1] # Distance to right edge + top_dist = topmost[0] # Distance to top edge (row index) + bottom_dist = height - 1 - bottommost[0] # Distance to bottom edge + + # Return the minimum distance to an edge + return min(top_dist, bottom_dist), min(left_dist, right_dist) def blurring_mask_2d_from( mask_2d: np.ndarray, kernel_shape_native: Tuple[int, int] @@ -350,19 +397,26 @@ def blurring_mask_2d_from( """ - # Create a (3, 3) kernel of ones + y_distance, x_distance = min_false_distance_to_edge(mask_2d) + + y_kernel_distance = (kernel_shape_native[0]) // 2 + x_kernel_distance = (kernel_shape_native[1]) // 2 + + if (y_distance < y_kernel_distance) or (x_distance < x_kernel_distance): + + raise exc.MaskException( + "The input mask is too small for the kernel shape. " + "Please pad the mask before computing the blurring mask." + ) + kernel = np.ones(kernel_shape_native, dtype=np.uint8) - # Convolve the mask with the kernel, applying logical AND to maintain 'True' regions convolved_mask = convolve(mask_2d.astype(np.uint8), kernel, mode="reflect", cval=0) - # We want to return the mask where the convolved value is the full kernel size (i.e., 9 for a 3x3 kernel) result_mask = convolved_mask == np.prod(kernel_shape_native) blurring_mask = ~mask_2d + result_mask - print(blurring_mask * convolved_mask) - return blurring_mask From 744e2ed96d56ebb5754fa503cea881c1aca2d8c6 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 30 Mar 2025 15:47:07 +0100 Subject: [PATCH 041/108] improve mask 2d util docs --- autoarray/mask/mask_2d_util.py | 211 ++++++++++++++++++++++----------- 1 file changed, 145 insertions(+), 66 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index f78e9b852..4eca6a758 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -34,12 +34,19 @@ def mask_2d_centres_from( Examples -------- - centres_scaled = mask_2d_centres_from(shape_native=(5, 5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0)) + >>> centres_scaled = mask_2d_centres_from(shape_native=(5, 5), pixel_scales=(0.5, 0.5), centre=(0.0, 0.0)) + >>> print(centres_scaled) + (0.0, 0.0) """ - return ( - 0.5 * (shape_native[0] - 1) - (centre[0] / pixel_scales[0]), - 0.5 * (shape_native[1] - 1) + (centre[1] / pixel_scales[1]), - ) + + # Calculate scaled y-coordinate by centering and adjusting for pixel scale + y_scaled = 0.5 * (shape_native[0] - 1) - (centre[0] / pixel_scales[0]) + + # Calculate scaled x-coordinate by centering and adjusting for pixel scale + x_scaled = 0.5 * (shape_native[1] - 1) + (centre[1] / pixel_scales[1]) + + # Return the scaled (y, x) coordinates + return (y_scaled, x_scaled) def mask_2d_circular_from( @@ -70,16 +77,23 @@ def mask_2d_circular_from( Examples -------- - mask = mask_2d_circular_from(shape_native=(10, 10), pixel_scales=(0.1, 0.1), radius=0.5, centre=(0.0, 0.0)) + >>> mask = mask_2d_circular_from(shape_native=(10, 10), pixel_scales=(0.1, 0.1), radius=0.5, centre=(0.0, 0.0)) """ + + # Get scaled coordinates of the mask center centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) + # Create a grid of y, x indices for the mask y, x = np.ogrid[: shape_native[0], : shape_native[1]] + + # Scale the y and x indices based on pixel scales y_scaled = (y - centres_scaled[0]) * pixel_scales[0] x_scaled = (x - centres_scaled[1]) * pixel_scales[1] + # Compute squared distances from the center for each pixel distances_squared = x_scaled**2 + y_scaled**2 + # Return a mask with True for pixels outside the circle and False for inside return distances_squared >= radius**2 @@ -114,18 +128,25 @@ def mask_2d_circular_annular_from( Examples -------- - mask = mask_2d_circular_annular_from( - shape_native=(10, 10), pixel_scales=(0.1, 0.1), inner_radius=0.5, outer_radius=1.5, centre=(0.0, 0.0) - ) + >>> mask = mask_2d_circular_annular_from( + >>> shape_native=(10, 10), pixel_scales=(0.1, 0.1), inner_radius=0.5, outer_radius=1.5, centre=(0.0, 0.0) + >>> ) """ + + # Get scaled coordinates of the mask center centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) + # Create grid of y, x indices for the mask y, x = np.ogrid[: shape_native[0], : shape_native[1]] + + # Scale the y and x indices based on pixel scales y_scaled = (y - centres_scaled[0]) * pixel_scales[0] x_scaled = (x - centres_scaled[1]) * pixel_scales[1] + # Compute squared distances from the center for each pixel distances_squared = x_scaled**2 + y_scaled**2 + # Return the mask where pixels are unmasked between inner and outer radii return ~( (distances_squared >= inner_radius**2) & (distances_squared <= outer_radius**2) ) @@ -166,30 +187,31 @@ def mask_2d_elliptical_from( Examples -------- - mask = mask_2d_elliptical_from( - shape_native=(10, 10), pixel_scales=(0.1, 0.1), major_axis_radius=0.5, axis_ratio=0.5, angle=45.0, centre=(0.0, 0.0) - ) + >>> mask = mask_2d_elliptical_from( + >>> shape_native=(10, 10), pixel_scales=(0.1, 0.1), major_axis_radius=0.5, axis_ratio=0.5, angle=45.0, centre=(0.0, 0.0) + >>> ) """ + + # Get scaled coordinates of the mask center centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) + # Create grid of y, x indices for the mask y, x = np.ogrid[: shape_native[0], : shape_native[1]] + + # Scale the y and x indices based on pixel scales y_scaled = (y - centres_scaled[0]) * pixel_scales[0] x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - # Rotate the coordinates by the angle (counterclockwise) - + # Compute the rotated coordinates and elliptical radius r_scaled = np.sqrt(x_scaled**2 + y_scaled**2) - theta_rotated = np.arctan2(y_scaled, x_scaled) + np.radians(angle) - y_scaled_elliptical = r_scaled * np.sin(theta_rotated) x_scaled_elliptical = r_scaled * np.cos(theta_rotated) - - # Compute the elliptical radius r_scaled_elliptical = np.sqrt( x_scaled_elliptical**2 + (y_scaled_elliptical / axis_ratio) ** 2 ) + # Return the mask where pixels are outside the elliptical region return ~(r_scaled_elliptical <= major_axis_radius) @@ -231,7 +253,7 @@ def mask_2d_elliptical_annular_from( The rotation angle of the outer ellipse within which pixels are unmasked, (counter-clockwise from the positive x-axis). centre - The centre of the elliptical annuli used to mask pixels. + The centre of the elliptical annuli used to mask pixels. Returns ------- @@ -240,42 +262,45 @@ def mask_2d_elliptical_annular_from( Examples -------- - mask = mask_elliptical_annuli_from( - shape=(10, 10), pixel_scales=(0.1, 0.1), - inner_major_axis_radius=0.5, inner_axis_ratio=0.5, inner_phi=45.0, - outer_major_axis_radius=1.5, outer_axis_ratio=0.8, outer_phi=90.0, - centre=(0.0, 0.0)) + >>> mask = mask_elliptical_annuli_from( + >>> shape=(10, 10), pixel_scales=(0.1, 0.1), + >>> inner_major_axis_radius=0.5, inner_axis_ratio=0.5, inner_phi=45.0, + >>> outer_major_axis_radius=1.5, outer_axis_ratio=0.8, outer_phi=90.0, + >>> centre=(0.0, 0.0) + >>> ) """ + + # Get scaled coordinates of the mask center centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre) + # Create grid of y, x indices for the mask y, x = np.ogrid[: shape_native[0], : shape_native[1]] + + # Scale the y and x indices based on pixel scales y_scaled = (y - centres_scaled[0]) * pixel_scales[0] x_scaled = (x - centres_scaled[1]) * pixel_scales[1] - # Rotate the coordinates for the inner annulus + # Compute and rotate coordinates for inner annulus r_scaled_inner = np.sqrt(x_scaled**2 + y_scaled**2) theta_rotated_inner = np.arctan2(y_scaled, x_scaled) + np.radians(inner_phi) y_scaled_elliptical_inner = r_scaled_inner * np.sin(theta_rotated_inner) x_scaled_elliptical_inner = r_scaled_inner * np.cos(theta_rotated_inner) - - # Compute the elliptical radius for the inner annulus r_scaled_elliptical_inner = np.sqrt( x_scaled_elliptical_inner**2 + (y_scaled_elliptical_inner / inner_axis_ratio) ** 2 ) - # Rotate the coordinates for the outer annulus + # Compute and rotate coordinates for outer annulus r_scaled_outer = np.sqrt(x_scaled**2 + y_scaled**2) theta_rotated_outer = np.arctan2(y_scaled, x_scaled) + np.radians(outer_phi) y_scaled_elliptical_outer = r_scaled_outer * np.sin(theta_rotated_outer) x_scaled_elliptical_outer = r_scaled_outer * np.cos(theta_rotated_outer) - - # Compute the elliptical radius for the outer annulus r_scaled_elliptical_outer = np.sqrt( x_scaled_elliptical_outer**2 + (y_scaled_elliptical_outer / outer_axis_ratio) ** 2 ) + # Return the mask where pixels are outside the inner and outer elliptical annuli return ~( (r_scaled_elliptical_inner >= inner_major_axis_radius) & (r_scaled_elliptical_outer <= outer_major_axis_radius) @@ -283,33 +308,53 @@ def mask_2d_elliptical_annular_from( def mask_2d_via_pixel_coordinates_from( - shape_native: Tuple[int, int], pixel_coordinates: [list], buffer: int = 0 + shape_native: Tuple[int, int], pixel_coordinates: list, buffer: int = 0 ) -> np.ndarray: """ - Returns a mask where all unmasked `False` entries are defined from an input list of list of pixel coordinates. + Returns a mask where all unmasked `False` entries are defined from an input list of 2D pixel coordinates. - These may be buffed via an input ``buffer``, whereby all entries in all 8 neighboring directions by this + These may be buffed via an input `buffer`, whereby all entries in all 8 neighboring directions are buffed by this amount. Parameters ---------- - shape_native (int, int) - The (y,x) shape of the mask in units of pixels. - pixel_coordinates : [[int, int]] - The input lists of 2D pixel coordinates where `False` entries are created. + shape_native + The (y, x) shape of the mask in units of pixels. + pixel_coordinates + The input list of 2D pixel coordinates where `False` entries are created. buffer - All input ``pixel_coordinates`` are buffed with `False` entries in all 8 neighboring directions by this + All input `pixel_coordinates` are buffed with `False` entries in all 8 neighboring directions by this amount. - """ - mask_2d = np.full(shape=shape_native, fill_value=True) + Returns + ------- + np.ndarray + The 2D mask array where all entries in the input pixel coordinates are set to `False`, with optional buffering + applied to the neighboring entries. - for y, x in pixel_coordinates: + Examples + -------- + mask = mask_2d_via_pixel_coordinates_from( + shape_native=(10, 10), + pixel_coordinates=[[1, 2], [3, 4], [5, 6]], + buffer=1 + ) + """ + mask_2d = np.full( + shape=shape_native, fill_value=True + ) # Initialize mask with all True values + + for ( + y, + x, + ) in ( + pixel_coordinates + ): # Loop over input coordinates to set corresponding mask entries to False mask_2d[y, x] = False - if buffer == 0: + if buffer == 0: # If no buffer is specified, return the mask directly return mask_2d - return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) + return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) # Apply buf import numpy as np @@ -317,18 +362,19 @@ def mask_2d_via_pixel_coordinates_from( def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]: """ - Compute the minimum 1D distance in the y and x directions from any False value at the mask's extreme positions + Compute the minimum 1D distance in the y and x directions from any `False` value at the mask's extreme positions (leftmost, rightmost, topmost, bottommost) to its closest edge. Parameters ---------- mask - A 2D boolean array where False represents the unmasked region. + A 2D boolean array where `False` represents the unmasked region. Returns ------- - The smallest distances of any extreme False value to the nearest edge in the vertical (y) and horizontal (x) - directions. + Tuple[int, int] + The smallest distances of any extreme `False` value to the nearest edge in the vertical (y) and horizontal (x) + directions. Examples -------- @@ -341,17 +387,29 @@ def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]: >>> min_false_distance_to_edge(mask) (1, 1) """ - false_indices = np.column_stack(np.where(mask == False)) + false_indices = np.column_stack( + np.where(mask == False) + ) # Find all coordinates where mask is False if false_indices.size == 0: - raise ValueError("No False values found in the mask.") - - leftmost = false_indices[np.argmin(false_indices[:, 1])] - rightmost = false_indices[np.argmax(false_indices[:, 1])] - topmost = false_indices[np.argmin(false_indices[:, 0])] - bottommost = false_indices[np.argmax(false_indices[:, 0])] - - height, width = mask.shape + raise ValueError( + "No False values found in the mask." + ) # Raise error if no False values + + leftmost = false_indices[ + np.argmin(false_indices[:, 1]) + ] # Find the leftmost False coordinate + rightmost = false_indices[ + np.argmax(false_indices[:, 1]) + ] # Find the rightmost False coordinate + topmost = false_indices[ + np.argmin(false_indices[:, 0]) + ] # Find the topmost False coordinate + bottommost = false_indices[ + np.argmax(false_indices[:, 0]) + ] # Find the bottommost False coordinate + + height, width = mask.shape # Get the height and width of the mask # Compute distances to respective edges left_dist = leftmost[1] # Distance to left edge (column index) @@ -359,9 +417,10 @@ def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]: top_dist = topmost[0] # Distance to top edge (row index) bottom_dist = height - 1 - bottommost[0] # Distance to bottom edge - # Return the minimum distance to an edge + # Return the minimum distance to both edges return min(top_dist, bottom_dist), min(left_dist, right_dist) + def blurring_mask_2d_from( mask_2d: np.ndarray, kernel_shape_native: Tuple[int, int] ) -> np.ndarray: @@ -397,27 +456,47 @@ def blurring_mask_2d_from( """ - y_distance, x_distance = min_false_distance_to_edge(mask_2d) + # Get the distance from False values to edges + y_distance, x_distance = min_false_distance_to_edge( + mask_2d + ) - y_kernel_distance = (kernel_shape_native[0]) // 2 - x_kernel_distance = (kernel_shape_native[1]) // 2 + # Compute kernel half-size in y and x direction + y_kernel_distance = ( + kernel_shape_native[0] + ) // 2 + x_kernel_distance = ( + kernel_shape_native[1] + ) // 2 + # Check if mask is too small for the kernel size if (y_distance < y_kernel_distance) or (x_distance < x_kernel_distance): - raise exc.MaskException( "The input mask is too small for the kernel shape. " "Please pad the mask before computing the blurring mask." ) - kernel = np.ones(kernel_shape_native, dtype=np.uint8) + # Create a kernel with the given PSF shape + kernel = np.ones( + kernel_shape_native, dtype=np.uint8 + ) - convolved_mask = convolve(mask_2d.astype(np.uint8), kernel, mode="reflect", cval=0) + # Convolve mask with kernel producing non-zero values around mask False values + convolved_mask = convolve( + mask_2d.astype(np.uint8), kernel, mode="reflect", cval=0 + ) - result_mask = convolved_mask == np.prod(kernel_shape_native) + # Identify pixels that are non-zero and fully covered by kernel + result_mask = convolved_mask == np.prod( + kernel_shape_native + ) - blurring_mask = ~mask_2d + result_mask + # Create the blurring mask by removing False values in original mask + blurring_mask = ( + ~mask_2d + result_mask + ) - return blurring_mask + return blurring_mask @numba_util.jit() From affea87ee786dd0c63ca5bd869a28f294406926a Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 30 Mar 2025 15:52:30 +0100 Subject: [PATCH 042/108] remove mask_2d_via_shape_native_and_native_for_slim --- autoarray/mask/mask_2d_util.py | 45 +----------------------- test_autoarray/mask/test_mask_2d_util.py | 38 -------------------- 2 files changed, 1 insertion(+), 82 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 4eca6a758..f7dd8e2af 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -496,50 +496,7 @@ def blurring_mask_2d_from( ~mask_2d + result_mask ) - return blurring_mask - - -@numba_util.jit() -def mask_2d_via_shape_native_and_native_for_slim( - shape_native: Tuple[int, int], native_for_slim: np.ndarray -) -> np.ndarray: - """ - For a slimmed set of data that was computed by mapping unmasked values from a native 2D array of shape - (total_y_pixels, total_x_pixels), map its slimmed indexes back to the original 2D array to create the - native 2D mask. - - This uses an array 'native_for_slim' of shape [total_masked_pixels[ where each index gives the native 2D pixel - indexes of the slimmed array's unmasked pixels, for example: - - - If native_for_slim[0] = [0,0], the first value of the slimmed array maps to the pixel [0,0] of the native 2D array. - - If native_for_slim[1] = [0,1], the second value of the slimmed array maps to the pixel [0,1] of the native 2D array. - - If native_for_slim[4] = [1,1], the fifth value of the slimmed array maps to the pixel [1,1] of the native 2D array. - - Parameters - ---------- - shape_native - The shape of the 2D array which the pixels are defined on. - native_for_slim - An array describing the native 2D array index that every slimmed array index maps too. - - Returns - ------- - ndarray - A 2D mask array where unmasked values are `False`. - - Examples - -------- - native_for_slim = np.array([[0,1], [1,0], [1,1], [1,2], [2,1]]) - - mask = mask_from(shape=(3,3), native_for_slim=native_for_slim) - """ - - mask = np.ones(shape_native) - - for index in range(len(native_for_slim)): - mask[native_for_slim[index, 0], native_for_slim[index, 1]] = False - - return mask + return blurring_mask @numba_util.jit() diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index 9e1718abd..94779c78d 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -668,44 +668,6 @@ def test__blurring_mask_2d_from__mask_extends_beyond_edge_so_raises_mask_excepti util.mask_2d.blurring_mask_2d_from(mask, kernel_shape_native=(5, 5)) -def test__mask_2d_via_shape_native_and_native_for_slim(): - slim_to_native = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) - shape = (2, 2) - - mask = util.mask_2d.mask_2d_via_shape_native_and_native_for_slim( - shape_native=shape, native_for_slim=slim_to_native - ) - - assert (mask == np.array([[False, False], [False, False]])).all() - - slim_to_native = np.array([[0, 0], [0, 1], [1, 0]]) - shape = (2, 2) - - mask = util.mask_2d.mask_2d_via_shape_native_and_native_for_slim( - shape_native=shape, native_for_slim=slim_to_native - ) - - assert (mask == np.array([[False, False], [False, True]])).all() - - slim_to_native = np.array([[0, 0], [0, 1], [1, 0], [2, 0], [2, 1], [2, 3]]) - shape = (3, 4) - - mask = util.mask_2d.mask_2d_via_shape_native_and_native_for_slim( - shape_native=shape, native_for_slim=slim_to_native - ) - - assert ( - mask - == np.array( - [ - [False, False, True, True], - [False, True, True, True], - [False, False, True, False], - ] - ) - ).all() - - def test__mask_1d_indexes_from(): mask = np.array( [ From 4cd2971bb647d304e85dab5e1b2c32961943fef7 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Sun, 30 Mar 2025 15:58:36 +0100 Subject: [PATCH 043/108] mask_slim_indexes_from --- autoarray/mask/mask_2d_util.py | 79 ++++++++++++---------------------- 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index f7dd8e2af..2e6b4b0e1 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -457,17 +457,11 @@ def blurring_mask_2d_from( """ # Get the distance from False values to edges - y_distance, x_distance = min_false_distance_to_edge( - mask_2d - ) + y_distance, x_distance = min_false_distance_to_edge(mask_2d) # Compute kernel half-size in y and x direction - y_kernel_distance = ( - kernel_shape_native[0] - ) // 2 - x_kernel_distance = ( - kernel_shape_native[1] - ) // 2 + y_kernel_distance = (kernel_shape_native[0]) // 2 + x_kernel_distance = (kernel_shape_native[1]) // 2 # Check if mask is too small for the kernel size if (y_distance < y_kernel_distance) or (x_distance < x_kernel_distance): @@ -477,29 +471,18 @@ def blurring_mask_2d_from( ) # Create a kernel with the given PSF shape - kernel = np.ones( - kernel_shape_native, dtype=np.uint8 - ) + kernel = np.ones(kernel_shape_native, dtype=np.uint8) # Convolve mask with kernel producing non-zero values around mask False values - convolved_mask = convolve( - mask_2d.astype(np.uint8), kernel, mode="reflect", cval=0 - ) + convolved_mask = convolve(mask_2d.astype(np.uint8), kernel, mode="reflect", cval=0) # Identify pixels that are non-zero and fully covered by kernel - result_mask = convolved_mask == np.prod( - kernel_shape_native - ) + result_mask = convolved_mask == np.prod(kernel_shape_native) # Create the blurring mask by removing False values in original mask - blurring_mask = ( - ~mask_2d + result_mask - ) + return ~mask_2d + result_mask - return blurring_mask - -@numba_util.jit() def mask_slim_indexes_from( mask_2d: np.ndarray, return_masked_indexes: bool = True ) -> np.ndarray: @@ -509,12 +492,12 @@ def mask_slim_indexes_from( For example, for the following ``Mask2D``: :: - [[True, True, True, True] + [[True, True, True, True], [True, False, False, True], [True, False, True, True], [True, True, True, True]] - This has three unmasked (``False`` values) which have the ``slim`` indexes, there ``unmasked_slim`` is: + This has three unmasked (``False`` values) which have the ``slim`` indexes, their ``unmasked_slim`` is: :: [0, 1, 2] @@ -522,36 +505,30 @@ def mask_slim_indexes_from( Parameters ---------- mask_2d - The mask for which the 1D unmasked pixel indexes are computed. + A 2D array representing the mask, where `True` indicates a masked pixel and `False` indicates an unmasked pixel. return_masked_indexes - Whether to return the masked index values (`value=True`) or the unmasked index values (`value=False`). + A boolean flag that determines whether to return indexes of masked (`True`) or unmasked (`False`) pixels. Returns ------- - np.ndarray - The 1D indexes of all unmasked pixels on the mask. - """ - - mask_pixel_total = 0 - - for y in range(0, mask_2d.shape[0]): - for x in range(0, mask_2d.shape[1]): - if mask_2d[y, x] == return_masked_indexes: - mask_pixel_total += 1 - - mask_pixels = np.zeros(mask_pixel_total) - mask_index = 0 - regular_index = 0 + A 1D array of indexes corresponding to either the masked or unmasked pixels in the mask. - for y in range(0, mask_2d.shape[0]): - for x in range(0, mask_2d.shape[1]): - if mask_2d[y, x] == return_masked_indexes: - mask_pixels[mask_index] = regular_index - mask_index += 1 - - regular_index += 1 - - return mask_pixels + Examples + -------- + >>> mask = np.array([[True, True, True, True], + ... [True, False, False, True], + ... [True, False, True, True], + ... [True, True, True, True]]) + >>> mask_slim_indexes_from(mask, return_masked_indexes=True) + array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> mask_slim_indexes_from(mask, return_masked_indexes=False) + array([10, 11]) + """ + # Flatten the mask and use np.where to get indexes of either True or False + mask_flat = mask_2d.flatten() + + # Get the indexes where the mask is equal to return_masked_indexes (True or False) + return np.where(mask_flat == return_masked_indexes)[0] @numba_util.jit() From 06a9ec3b94fad0ae247c61b2519087ecc75a258e Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 19:39:32 +0100 Subject: [PATCH 044/108] edge_1d_indexes_from --- autoarray/mask/mask_2d_util.py | 250 +++++++---------------- test_autoarray/mask/test_mask_2d_util.py | 15 +- 2 files changed, 81 insertions(+), 184 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 2e6b4b0e1..2ba06135c 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -8,6 +8,52 @@ from autoarray import type as ty from autoarray.numpy_wrapper import use_jax, np as jnp +def native_index_for_slim_index_2d_from( + mask_2d: np.ndarray, +) -> np.ndarray: + """ + Returns an array of shape [total_unmasked_pixels] that maps every unmasked pixel to its + corresponding native 2D pixel using its (y,x) pixel indexes. + + For example, for the following ``Mask2D``: + + :: + [[True, True, True, True] + [True, False, False, True], + [True, False, True, True], + [True, True, True, True]] + + This has three unmasked (``False`` values) which have the ``slim`` indexes: + + :: + [0, 1, 2] + + The array ``native_index_for_slim_index_2d`` is therefore: + + :: + [[1,1], [1,2], [2,1]] + + Parameters + ---------- + mask_2d + A 2D array of bools, where `False` values are unmasked. + + Returns + ------- + ndarray + An array that maps pixels from a slimmed array of shape [total_unmasked_pixels] to its native array + of shape [total_pixels, total_pixels]. + + Examples + -------- + mask_2d = np.array([[True, True, True], + [True, False, True] + [True, True, True]]) + + native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d) + """ + return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T + def mask_2d_centres_from( shape_native: Tuple[int, int], @@ -531,136 +577,56 @@ def mask_slim_indexes_from( return np.where(mask_flat == return_masked_indexes)[0] -@numba_util.jit() -def check_if_edge_pixel(mask_2d: np.ndarray, y: int, x: int) -> bool: - """ - Checks if an input [y,x] pixel on the input `mask` is an edge-pixel. - - An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least 1 of its 8 - direct neighbors is masked (is `True`). - - Parameters - ---------- - mask_2d - The mask for which the input pixel is checked if it is an edge pixel. - y - The y pixel coordinate on the mask that is checked for if it is an edge pixel. - x - The x pixel coordinate on the mask that is checked for if it is an edge pixel. - - Returns - ------- - bool - If `True` the pixel on the mask is an edge pixel, else a `False` is returned because it is not. - """ - - if ( - mask_2d[y + 1, x] - or mask_2d[y - 1, x] - or mask_2d[y, x + 1] - or mask_2d[y, x - 1] - or mask_2d[y + 1, x + 1] - or mask_2d[y + 1, x - 1] - or mask_2d[y - 1, x + 1] - or mask_2d[y - 1, x - 1] - ): - return True - else: - return False - - -@numba_util.jit() -def total_edge_pixels_from(mask_2d: np.ndarray) -> int: - """ - Returns the total number of edge-pixels in a mask. - - An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least 1 of its 8 - direct neighbors is masked (is `True`). - - Parameters - ---------- - mask_2d - The mask for which the total number of edge pixels is computed. - - Returns - ------- - int - The total number of edge pixels. - """ - - edge_pixel_total = 0 - - for y in range(1, mask_2d.shape[0] - 1): - for x in range(1, mask_2d.shape[1] - 1): - if not mask_2d[y, x]: - if check_if_edge_pixel(mask_2d=mask_2d, y=y, x=x): - edge_pixel_total += 1 - - return edge_pixel_total - - -@numba_util.jit() def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray: """ Returns a 1D array listing all edge pixel indexes in the mask. - An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least 1 of its 8 + An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least one of its 8 direct neighbors is masked (is `True`). - For example, for the following ``Mask2D``: - - :: - [[True, True, True, True, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, True, True, True, True]] - - The `edge_slim` indexes (given via ``mask_2d.derive_indexes.edge_slim``) is given by: - - :: - [0, 1, 2, 3, 5, 6, 7, 8] - - Note that index 4 is skipped, which corresponds to the ``False`` value in the centre of the mask, because it - does not neighbor a ``True`` value in any one of the eight neighboring directions and is therefore not at - an edge. - Parameters ---------- mask_2d - The mask for which the 1D edge pixel indexes are computed. + A 2D boolean array where `False` values indicate unmasked pixels. Returns ------- np.ndarray - The 1D indexes of all edge pixels on the mask. - """ - - edge_pixel_total = total_edge_pixels_from(mask_2d) + A 1D array of indexes of all edge pixels on the mask. - edge_pixels = np.zeros(edge_pixel_total) - edge_index = 0 - regular_index = 0 + Examples + -------- + >>> mask = np.array([ + ... [True, True, True, True, True], + ... [True, False, False, True, True], + ... [True, False, False, False, True], + ... [True, True, False, True, True], + ... [True, True, True, True, True] + ... ]) + >>> edge_1d_indexes_from(mask) + array([1, 2, 5, 7, 8, 9]) + """ + # Pad the mask to handle edge cases without index errors + padded_mask = np.pad(mask_2d, pad_width=1, mode='constant', constant_values=True) + + # Identify neighbors in 3x3 regions around each pixel + neighbors = ( + padded_mask[:-2, 1:-1] | padded_mask[2:, 1:-1] | # Up, Down + padded_mask[1:-1, :-2] | padded_mask[1:-1, 2:] | # Left, Right + padded_mask[:-2, :-2] | padded_mask[:-2, 2:] | # Top-left, Top-right + padded_mask[2:, :-2] | padded_mask[2:, 2:] # Bottom-left, Bottom-right + ) - for y in range(1, mask_2d.shape[0] - 1): - for x in range(1, mask_2d.shape[1] - 1): - if not mask_2d[y, x]: - if ( - mask_2d[y + 1, x] - or mask_2d[y - 1, x] - or mask_2d[y, x + 1] - or mask_2d[y, x - 1] - or mask_2d[y + 1, x + 1] - or mask_2d[y + 1, x - 1] - or mask_2d[y - 1, x + 1] - or mask_2d[y - 1, x - 1] - ): - edge_pixels[edge_index] = regular_index - edge_index += 1 + # Identify edge pixels: False values with at least one True neighbor + edge_mask = ~mask_2d & neighbors - regular_index += 1 + # Create an index array where False entries get sequential 1D indices + index_array = np.full(mask_2d.shape, fill_value=-1, dtype=int) + false_indices = np.flatnonzero(~mask_2d) + index_array[~mask_2d] = np.arange(len(false_indices)) - return edge_pixels + # Return the 1D indexes of the edge pixels + return index_array[edge_mask] @numba_util.jit() @@ -911,62 +877,4 @@ def rescaled_mask_2d_from(mask_2d: np.ndarray, rescale_factor: float) -> np.ndar return np.isclose(rescaled_mask_2d, 1) -@numba_util.jit() -def native_index_for_slim_index_2d_from( - mask_2d: np.ndarray, -) -> np.ndarray: - """ - Returns an array of shape [total_unmasked_pixels] that maps every unmasked pixel to its - corresponding native 2D pixel using its (y,x) pixel indexes. - - For example, for the following ``Mask2D``: - - :: - [[True, True, True, True] - [True, False, False, True], - [True, False, True, True], - [True, True, True, True]] - - This has three unmasked (``False`` values) which have the ``slim`` indexes: - - :: - [0, 1, 2] - - The array ``native_index_for_slim_index_2d`` is therefore: - - :: - [[1,1], [1,2], [2,1]] - - Parameters - ---------- - mask_2d - A 2D array of bools, where `False` values are unmasked. - - Returns - ------- - ndarray - An array that maps pixels from a slimmed array of shape [total_unmasked_pixels] to its native array - of shape [total_pixels, total_pixels]. - - Examples - -------- - mask_2d = np.array([[True, True, True], - [True, False, True] - [True, True, True]]) - - native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d) - """ - if use_jax: - return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - else: - total_pixels = np.sum(~mask_2d) - native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2)) - slim_index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - native_index_for_slim_index_2d[slim_index, :] = y, x - slim_index += 1 - return native_index_for_slim_index_2d diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index 94779c78d..240e91f7d 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -695,19 +695,6 @@ def test__mask_1d_indexes_from(): assert masked_slim[-1] == 48 -def test__total_edge_pixels_from_mask(): - mask_2d = np.array( - [ - [True, True, True, True, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, True, True, True, True], - ] - ) - - assert util.mask_2d.total_edge_pixels_from(mask_2d=mask_2d) == 8 - def test__edge_1d_indexes_from(): mask = np.array( @@ -724,6 +711,8 @@ def test__edge_1d_indexes_from(): edge_pixels = util.mask_2d.edge_1d_indexes_from(mask_2d=mask) + print(edge_pixels) + assert (edge_pixels == np.array([0])).all() mask = np.array( From 39fc17370adbce2012bc2e31f0a83da176b8a02d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 19:55:51 +0100 Subject: [PATCH 045/108] border_slim_indexes_from --- autoarray/mask/mask_2d_util.py | 185 +++++++++++---------------------- 1 file changed, 63 insertions(+), 122 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 2ba06135c..c8765b93d 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -584,6 +584,24 @@ def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray: An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least one of its 8 direct neighbors is masked (is `True`). + For example, for the following ``Mask2D``: + + :: + [[True, True, True, True, True], + [True, False, False, False, True], + [True, False, False, False, True], + [True, False, False, False, True], + [True, True, True, True, True]] + + The `edge_slim` indexes (given via ``mask_2d.derive_indexes.edge_slim``) is given by: + + :: + [0, 1, 2, 3, 5, 6, 7, 8] + + Note that index 4 is skipped, which corresponds to the ``False`` value in the centre of the mask, because it + does not neighbor a ``True`` value in any one of the eight neighboring directions and is therefore not at + an edge. + Parameters ---------- mask_2d @@ -591,20 +609,19 @@ def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray: Returns ------- - np.ndarray - A 1D array of indexes of all edge pixels on the mask. + A 1D array of indexes of all edge pixels on the mask. Examples -------- >>> mask = np.array([ ... [True, True, True, True, True], - ... [True, False, False, True, True], ... [True, False, False, False, True], - ... [True, True, False, True, True], + ... [True, False, False, False, True], + ... [True, False, False, False, True], ... [True, True, True, True, True] ... ]) >>> edge_1d_indexes_from(mask) - array([1, 2, 5, 7, 8, 9]) + array([0, 1, 2, 3, 5, 6, 7, 8]) """ # Pad the mask to handle edge cases without index errors padded_mask = np.pad(mask_2d, pad_width=1, mode='constant', constant_values=True) @@ -629,102 +646,12 @@ def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray: return index_array[edge_mask] -@numba_util.jit() -def check_if_border_pixel( - mask_2d: np.ndarray, edge_pixel_slim: int, native_to_slim: np.ndarray -) -> bool: - """ - Checks if an input [y,x] pixel on the input `mask` is a border-pixel. - - A borders pixel is a pixel which: - - 1) is not fully surrounding by `False` mask values. - 2) Can reach the edge of the array without hitting a masked pixel in one of four directions (upwards, downwards, - left, right). - - The borders pixels are thus pixels which are on the exterior edge of the mask. For example, the inner ring of edge - pixels in an annular mask are edge pixels but not borders pixels. - - Parameters - ---------- - mask_2d - The mask for which the input pixel is checked if it is a border pixel. - edge_pixel_slim - The edge pixel index in 1D that is checked if it is a border pixel (this 1D index is mapped to 2d via the - array `native_index_for_slim_index_2d`). - native_to_slim - An array describing the native 2D array index that every slimmed array index maps too. - - Returns - ------- - bool - If `True` the pixel on the mask is a border pixel, else a `False` is returned because it is not. - """ - edge_pixel_index = int(edge_pixel_slim) - - y = int(native_to_slim[edge_pixel_index, 0]) - x = int(native_to_slim[edge_pixel_index, 1]) - - if ( - np.sum(mask_2d[0:y, x]) == y - or np.sum(mask_2d[y, x : mask_2d.shape[1]]) == mask_2d.shape[1] - x - 1 - or np.sum(mask_2d[y : mask_2d.shape[0], x]) == mask_2d.shape[0] - y - 1 - or np.sum(mask_2d[y, 0:x]) == x - ): - return True - else: - return False - - -@numba_util.jit() -def total_border_pixels_from(mask_2d, edge_pixels, native_to_slim): - """ - Returns the total number of border-pixels in a mask. - - A borders pixel is a pixel which: - - 1) is not fully surrounding by `False` mask values. - 2) Can reach the edge of the array without hitting a masked pixel in one of four directions (upwards, downwards, - left, right). - - The borders pixels are thus pixels which are on the exterior edge of the mask. For example, the inner ring of edge - pixels in an annular mask are edge pixels but not borders pixels. - - Parameters - ---------- - mask_2d - The mask for which the total number of border pixels is computed. - edge_pixel_1d - The edge pixel index in 1D that is checked if it is a border pixel (this 1D index is mapped to 2d via the - array `native_index_for_slim_index_2d`). - native_to_slim - An array describing the 2D array index that every 1D array index maps too. - - Returns - ------- - int - The total number of border pixels. - """ - - border_pixel_total = 0 - - for i in range(edge_pixels.shape[0]): - if check_if_border_pixel(mask_2d, edge_pixels[i], native_to_slim): - border_pixel_total += 1 - - return border_pixel_total - - -@numba_util.jit() def border_slim_indexes_from(mask_2d: np.ndarray) -> np.ndarray: """ - Returns a slim array of shape [total_unmasked_border_pixels] listing all borders pixel indexes in the mask. - - A borders pixel is a pixel which: + Returns a 1D array listing all border pixel indexes in the mask. - 1) is not fully surrounding by `False` mask values. - 2) Can reach the edge of the array without hitting a masked pixel in one of four directions (upwards, downwards, - left, right). + A border pixel is an unmasked pixel (`False` value) that can reach the edge of the mask without encountering + a masked (`True`) pixel in any of the four cardinal directions (up, down, left, right). The borders pixels are thus pixels which are on the exterior edge of the mask. For example, the inner ring of edge pixels in an annular mask are edge pixels but not borders pixels. @@ -753,39 +680,53 @@ def border_slim_indexes_from(mask_2d: np.ndarray) -> np.ndarray: Parameters ---------- mask_2d - The mask for which the slimmed border pixel indexes are calculated. + A 2D boolean array where `False` values indicate unmasked pixels. Returns ------- - np.ndarray - The slimmed indexes of all border pixels on the mask. - """ + A 1D array of indexes of all border pixels on the mask. - edge_pixels = edge_1d_indexes_from(mask_2d=mask_2d) - native_index_for_slim_index_2d = native_index_for_slim_index_2d_from( - mask_2d=mask_2d, - ) + Examples + -------- + >>> mask = np.array([ + ... [True, True, True, True, True, True, True, True, True], + ... [True, False, False, False, False, False, False, False, True], + ... [True, False, True, True, True, True, True, False, True], + ... [True, False, True, False, False, False, True, False, True], + ... [True, False, True, False, True, False, True, False, True], + ... [True, False, True, False, False, False, True, False, True], + ... [True, False, True, True, True, True, True, False, True], + ... [True, False, False, False, False, False, False, False, True], + ... [True, True, True, True, True, True, True, True, True] + ... ]) + >>> border_slim_indexes_from(mask) + array([0, 1, 2, 3, 5, 6, 7, 11, 12, 15, 16, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]) + """ - border_pixel_total = total_border_pixels_from( - mask_2d=mask_2d, - edge_pixels=edge_pixels, - native_to_slim=native_index_for_slim_index_2d, - ) + # Compute cumulative sums along each direction + up_sums = np.cumsum(mask_2d, axis=0) + down_sums = np.cumsum(mask_2d[::-1, :], axis=0)[::-1, :] + left_sums = np.cumsum(mask_2d, axis=1) + right_sums = np.cumsum(mask_2d[:, ::-1], axis=1)[:, ::-1] - border_pixels = np.zeros(border_pixel_total) + # Get mask dimensions + height, width = mask_2d.shape - border_pixel_index = 0 + # Identify border pixels: where the full length in any direction is True + border_mask = ( + (up_sums == np.arange(height)[:, None]) | + (down_sums == np.arange(height - 1, -1, -1)[:, None]) | + (left_sums == np.arange(width)[None, :]) | + (right_sums == np.arange(width - 1, -1, -1)[None, :]) + ) & ~mask_2d - for edge_pixel_index in range(edge_pixels.shape[0]): - if check_if_border_pixel( - mask_2d=mask_2d, - edge_pixel_slim=edge_pixels[edge_pixel_index], - native_to_slim=native_index_for_slim_index_2d, - ): - border_pixels[border_pixel_index] = edge_pixels[edge_pixel_index] - border_pixel_index += 1 + # Create an index array where False entries get sequential 1D indices + index_array = np.full(mask_2d.shape, fill_value=-1, dtype=int) + false_indices = np.flatnonzero(~mask_2d) + index_array[~mask_2d] = np.arange(len(false_indices)) - return border_pixels + # Return the 1D indexes of the border pixels + return index_array[border_mask] @numba_util.jit() From 3066c773b30199df6eea5ba91fb0a01b744dcf0a Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 20:04:03 +0100 Subject: [PATCH 046/108] all numba decorators removed --- autoarray/mask/mask_2d_util.py | 69 +++++++++++++++--------- test_autoarray/mask/test_mask_2d_util.py | 56 +++++++++---------- 2 files changed, 71 insertions(+), 54 deletions(-) diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index c8765b93d..9ea4e35c0 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -4,9 +4,7 @@ import warnings from autoarray import exc -from autoarray import numba_util -from autoarray import type as ty -from autoarray.numpy_wrapper import use_jax, np as jnp +from autoarray.numpy_wrapper import np as jnp def native_index_for_slim_index_2d_from( mask_2d: np.ndarray, @@ -402,10 +400,6 @@ def mask_2d_via_pixel_coordinates_from( return mask_2d return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) # Apply buf - -import numpy as np - - def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]: """ Compute the minimum 1D distance in the y and x directions from any `False` value at the mask's extreme positions @@ -729,38 +723,61 @@ def border_slim_indexes_from(mask_2d: np.ndarray) -> np.ndarray: return index_array[border_mask] -@numba_util.jit() def buffed_mask_2d_from(mask_2d: np.ndarray, buffer: int = 1) -> np.ndarray: """ - Returns a buffed mask from an input mask, where the buffed mask is the input mask but all `False` entries in the - mask are buffed by an integer amount in all 8 surrouning pixels. + Returns a buffed mask from an input mask, where all `False` entries in the mask are "buffed" (set to `False`) + within a specified buffer range in all 8 surrounding directions. + + A "buffed" mask is created by marking all the pixels within a square of size `buffer` around each `False` + entry as `False`. This process simulates expanding the masked region around each `False` entry by the specified + buffer distance. Parameters ---------- mask_2d - The mask whose `False` entries are buffed. + A 2D boolean array where `False` values indicate unmasked pixels. buffer - The number of pixels around each `False` entry that pixel are buffed in all 8 directions. + The number of pixels around each `False` entry that should be buffed in all 8 surrounding directions. + This controls how far the "buffed" region extends from each `False` value. Returns ------- - np.ndarray - The buffed mask. + A new 2D boolean array where all `False` entries in the input mask are expanded by the specified buffer + distance, setting all pixels within the buffer range to `False`. + + Examples + -------- + >>> mask = np.array([ + ... [True, False, True], + ... [False, False, False], + ... [True, True, False] + ... ]) + >>> buffed_mask_2d_from(mask, buffer=1) + array([[False, False, False], + [False, False, False], + [False, False, False]]) """ + # Initialize buffed mask as a copy of the input mask buffed_mask_2d = mask_2d.copy() - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - for y0 in range(y - buffer, y + 1 + buffer): - for x0 in range(x - buffer, x + 1 + buffer): - if ( - y0 >= 0 - and x0 >= 0 - and y0 <= mask_2d.shape[0] - 1 - and x0 <= mask_2d.shape[1] - 1 - ): - buffed_mask_2d[y0, x0] = False + # Identify the coordinates of all False entries + false_coords = np.nonzero(~mask_2d) + + # Create grid of offsets for the neighboring pixels (buffer range) + buffer_range = np.arange(-buffer, buffer + 1) + + # Generate all possible neighbors for each False entry + dy, dx = np.meshgrid(buffer_range, buffer_range, indexing='ij') + neighbors = np.stack([dy.ravel(), dx.ravel()], axis=-1) + + # Calculate all neighboring positions for all False coordinates + all_neighbors = np.add(np.array(false_coords).T[:, np.newaxis], neighbors) + + # Clip the neighbors to stay within the bounds of the mask + valid_neighbors = np.clip(all_neighbors, [0, 0], [mask_2d.shape[0] - 1, mask_2d.shape[1] - 1]) + + # Update the buffed mask: set all the neighbors to False + buffed_mask_2d[valid_neighbors[:, :, 0], valid_neighbors[:, :, 1]] = False return buffed_mask_2d diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index 240e91f7d..bd2ebd84a 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -5,6 +5,34 @@ import pytest +def test__native_index_for_slim_index_2d_from(): + mask = np.array([[True, True, True], [True, False, True], [True, True, True]]) + + sub_mask_index_for_sub_mask_1d_index = ( + util.mask_2d.native_index_for_slim_index_2d_from(mask_2d=mask) + ) + + assert (sub_mask_index_for_sub_mask_1d_index == np.array([[1, 1]])).all() + + mask = np.array( + [ + [True, False, True], + [False, False, False], + [True, False, True], + [True, True, False], + ] + ) + + sub_mask_index_for_sub_mask_1d_index = ( + util.mask_2d.native_index_for_slim_index_2d_from(mask_2d=mask) + ) + + assert ( + sub_mask_index_for_sub_mask_1d_index + == np.array([[0, 1], [1, 0], [1, 1], [1, 2], [2, 1], [3, 2]]) + ).all() + + def test__mask_2d_circular_from(): mask = util.mask_2d.mask_2d_circular_from( shape_native=(3, 3), pixel_scales=(1.0, 1.0), radius=0.5 @@ -921,34 +949,6 @@ def test__border_slim_indexes_from(): ).all() -def test__native_index_for_slim_index_2d_from(): - mask = np.array([[True, True, True], [True, False, True], [True, True, True]]) - - sub_mask_index_for_sub_mask_1d_index = ( - util.mask_2d.native_index_for_slim_index_2d_from(mask_2d=mask) - ) - - assert (sub_mask_index_for_sub_mask_1d_index == np.array([[1, 1]])).all() - - mask = np.array( - [ - [True, False, True], - [False, False, False], - [True, False, True], - [True, True, False], - ] - ) - - sub_mask_index_for_sub_mask_1d_index = ( - util.mask_2d.native_index_for_slim_index_2d_from(mask_2d=mask) - ) - - assert ( - sub_mask_index_for_sub_mask_1d_index - == np.array([[0, 1], [1, 0], [1, 1], [1, 2], [2, 1], [3, 2]]) - ).all() - - def test__rescaled_mask_2d_from(): mask = np.array( [ From b317010a74eebf1331cb8d0662624f82fa3638c4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 20:47:52 +0100 Subject: [PATCH 047/108] fix mask tests by using == --- autoarray/mask/abstract_mask.py | 9 ++------- autoarray/mask/mask_2d.py | 2 +- autoarray/structures/grids/uniform_2d.py | 2 +- test_autoarray/mask/test_mask_1d.py | 18 ++++++++---------- 4 files changed, 12 insertions(+), 19 deletions(-) diff --git a/autoarray/mask/abstract_mask.py b/autoarray/mask/abstract_mask.py index 30a778aed..d6ed67c2c 100644 --- a/autoarray/mask/abstract_mask.py +++ b/autoarray/mask/abstract_mask.py @@ -2,13 +2,8 @@ from abc import ABC import logging - -from autoarray.numpy_wrapper import np, use_jax - -if use_jax: - import jax -from pathlib import Path -from typing import Dict, Union +import numpy as np +from typing import Dict from autoconf.fitsable import output_to_fits diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 60559e9d4..b6dbf7977 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -793,7 +793,7 @@ def resized_from(self, new_shape, pad_value: int = 0.0) -> Mask2D: """ resized_mask = array_2d_util.resized_array_2d_from( - array_2d=np.array(self), + array_2d=np.array(self._array), resized_shape=new_shape, pad_value=pad_value, ).astype("bool") diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 41559e083..64e2738f9 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -184,7 +184,7 @@ def __init__( over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=np.array(self.mask), pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.over_sampler.sub_size).astype("int"), + sub_size=np.array(self.over_sampler.sub_size._array).astype("int"), origin=self.mask.origin, ) ) diff --git a/test_autoarray/mask/test_mask_1d.py b/test_autoarray/mask/test_mask_1d.py index 41f02d3b7..cb5fe2c56 100644 --- a/test_autoarray/mask/test_mask_1d.py +++ b/test_autoarray/mask/test_mask_1d.py @@ -1,6 +1,4 @@ -from astropy.io import fits import numpy as np -import os from os import path import pytest @@ -53,34 +51,34 @@ def test__constructor__input_is_2d_mask__raises_exception(): def test__is_all_true(): mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[False, True, False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[True, True, True, True], pixel_scales=1.0) - assert mask.is_all_true is True + assert mask.is_all_true == True def test__is_all_false(): mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask1D(mask=[False, False], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask1D(mask=[False, True, False, False], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False mask = aa.Mask1D(mask=[True, True, False, False], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False From e1899b40cc0313a29a2cefd1ab07ee8d9af24e0a Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 20:49:03 +0100 Subject: [PATCH 048/108] fix more tests using == --- test_autoarray/mask/test_mask_2d.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index 7e12e5b82..399510a0f 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -366,37 +366,37 @@ def test__mask__input_is_1d_mask__no_shape_native__raises_exception(): def test__is_all_true(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[False, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[True, True], [True, True]], pixel_scales=1.0) - assert mask.is_all_true is True + assert mask.is_all_true == True def test__is_all_false(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask2D(mask=[[False, False]], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask2D(mask=[[False, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False mask = aa.Mask2D(mask=[[True, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False def test__shape_native_masked_pixels(): From cb5f13b72ec5770405b41920e6cbf5427eae5454 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 21:33:03 +0100 Subject: [PATCH 049/108] fixes to get basic func_grad to work --- autoarray/dataset/imaging/dataset.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index bc529ef96..7ded4d409 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -162,7 +162,7 @@ def __init__( if psf is not None and use_normalized_psf: psf = Kernel2D.no_mask( - values=psf.native, pixel_scales=psf.pixel_scales, normalize=True + values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True ) self.psf = psf diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index f21cf5a80..c03ceaf32 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -53,7 +53,7 @@ def __init__( ) if normalize: - self._array[:] = np.divide(self._array, np.sum(self._array)) + self._array = np.divide(self._array, np.sum(self._array)) @classmethod def no_mask( From 15bf0dbdbb9a7a76980cd0babac4cac5b3b87c98 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:25:02 +0100 Subject: [PATCH 050/108] progress stopped at convolver --- autoarray/dataset/imaging/dataset.py | 2 +- autoarray/operators/convolver.py | 2 +- autoarray/structures/grids/uniform_2d.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 7ded4d409..20bb8b0be 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -193,7 +193,7 @@ def convolver(self): The convolver given the masked imaging data's mask and PSF. """ - return Convolver(mask=self.mask, kernel=self.psf) + return Convolver(mask=self.mask, kernel=Kernel2D(values=self.psf._array, mask=self.psf.mask, header=self.psf.header)) @cached_property def w_tilde(self): diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index 5a52329b0..73d3767d5 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -221,7 +221,7 @@ def __init__(self, mask, kernel): coordinates=(x, y), mask=np.array(mask), mask_index_array=self.mask_index_array, - kernel_2d=np.array(self.kernel.native[:, :]), + kernel_2d=self.kernel.native, ) self.image_frame_1d_indexes[mask_1d_index, :] = ( image_frame_1d_indexes diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 64e2738f9..782298eb9 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -545,7 +545,7 @@ def from_mask( """ grid_1d = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=np.array(mask), + mask_2d=mask._array, pixel_scales=mask.pixel_scales, origin=mask.origin, ) From d3649ff116d77dda5a54134b796dbc10923f9f26 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:39:14 +0100 Subject: [PATCH 051/108] updated grid_2d_slim_via_mask_from to be JAX implementation --- autoarray/geometry/geometry_util.py | 2 -- autoarray/structures/grids/grid_2d_util.py | 35 +++++++--------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index a795d42ee..d09089ef6 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -180,7 +180,6 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales -@numba_util.jit() def central_pixel_coordinates_2d_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: @@ -205,7 +204,6 @@ def central_pixel_coordinates_2d_from( return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) -@numba_util.jit() def central_scaled_coordinate_2d_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 44c75c7c5..e83eaab14 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,4 +1,7 @@ from __future__ import annotations + +import jax.numpy as jnp + from autoarray.numpy_wrapper import np, use_jax if use_jax: @@ -233,7 +236,6 @@ def grid_2d_centre_from(grid_2d_slim: np.ndarray) -> Tuple[float, float]: return centre_y, centre_x -@numba_util.jit() def grid_2d_slim_via_mask_from( mask_2d: np.ndarray, pixel_scales: ty.PixelScales, @@ -273,33 +275,18 @@ def grid_2d_slim_via_mask_from( grid_slim = grid_2d_slim_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - total_pixels = np.sum(~mask_2d) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_slim = ( - (np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled) - * sign - * pixel_scales - ) - else: - index = 0 - grid_slim = np.zeros(shape=(total_pixels, 2)) - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] - grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] - index += 1 - - return grid_slim + centres_scaled = jnp.array(centres_scaled) + pixel_scales = jnp.array(pixel_scales) + sign = jnp.array([-1.0, 1.0]) + return ( + (jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled) + * sign + * pixel_scales + ) def grid_2d_via_mask_from( From adf5eadac1554872326ee158309cda0ef8546bc4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:41:19 +0100 Subject: [PATCH 052/108] remove numba from grid_2d_centre_from --- autoarray/structures/grids/grid_2d_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index e83eaab14..522bb0525 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -216,7 +216,6 @@ def convert_grid_2d_to_native( ) -@numba_util.jit() def grid_2d_centre_from(grid_2d_slim: np.ndarray) -> Tuple[float, float]: """ Returns the centre of a grid from a 1D grid. From 31cdd33d54942bcd276d45ac873333e7576dfb56 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:42:18 +0100 Subject: [PATCH 053/108] remove numba from pixel_coordinates_2d_from -> fixes is circular --- autoarray/geometry/geometry_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index d09089ef6..0477fdbc9 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -242,7 +242,6 @@ def central_scaled_coordinate_2d_from( return (y_pixel, x_pixel) -@numba_util.jit() def pixel_coordinates_2d_from( scaled_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], From ff1e811e9cc8cb26faeb0f77e8fc3bbae6018744 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:57:12 +0100 Subject: [PATCH 054/108] fixing grid_2d_slim_over_sampled_via_mask_from to use numba --- .../over_sampling/over_sample_util.py | 86 +++++++++---------- autoarray/structures/grids/uniform_2d.py | 4 +- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index a98276896..70dc220c4 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -1,8 +1,6 @@ from __future__ import annotations import numpy as np from typing import TYPE_CHECKING, Union, List, Tuple -from autoarray.numpy_wrapper import np, register_pytree_node_class, use_jax, jit - from typing import List, Tuple from autoarray.structures.arrays.uniform_2d import Array2D @@ -364,11 +362,11 @@ def sub_size_radial_bins_from( for i in range(radial_grid.shape[0]): for j in range(len(radial_list)): if radial_grid[i] < radial_list[j]: - if use_jax: - # while this makes it run, it is very, very slow - sub_size = sub_size.at[i].set(sub_size_list[j]) - else: - sub_size[i] = sub_size_list[j] + # if use_jax: + # # while this makes it run, it is very, very slow + # sub_size = sub_size.at[i].set(sub_size_list[j]) + # else: + sub_size[i] = sub_size_list[j] break return sub_size @@ -447,35 +445,35 @@ def grid_2d_slim_over_sampled_via_mask_from( for y1 in range(sub): for x1 in range(sub): - if use_jax: - # while this makes it run, it is very, very slow - grid_slim = grid_slim.at[sub_index, 0].set( - -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) - ) - ) - grid_slim = grid_slim.at[sub_index, 1].set( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) - ) - else: - grid_slim[sub_index, 0] = -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) - ) - grid_slim[sub_index, 1] = ( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) - ) + # if use_jax: + # # while this makes it run, it is very, very slow + # grid_slim = grid_slim.at[sub_index, 0].set( + # -( + # y_scaled + # - y_sub_half + # + y1 * y_sub_step + # + (y_sub_step / 2.0) + # ) + # ) + # grid_slim = grid_slim.at[sub_index, 1].set( + # x_scaled + # - x_sub_half + # + x1 * x_sub_step + # + (x_sub_step / 2.0) + # ) + # else: + grid_slim[sub_index, 0] = -( + y_scaled + - y_sub_half + + y1 * y_sub_step + + (y_sub_step / 2.0) + ) + grid_slim[sub_index, 1] = ( + x_scaled + - x_sub_half + + x1 * x_sub_step + + (x_sub_step / 2.0) + ) sub_index += 1 index += 1 @@ -544,14 +542,14 @@ def binned_array_2d_from( for y1 in range(sub): for x1 in range(sub): - if use_jax: - binned_array_2d_slim = binned_array_2d_slim.at[index].add( - array_2d[sub_index] * sub_fraction[index] - ) - else: - binned_array_2d_slim[index] += ( - array_2d[sub_index] * sub_fraction[index] - ) + # if use_jax: + # binned_array_2d_slim = binned_array_2d_slim.at[index].add( + # array_2d[sub_index] * sub_fraction[index] + # ) + # else: + binned_array_2d_slim[index] += ( + array_2d[sub_index] * sub_fraction[index] + ) sub_index += 1 index += 1 diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 782298eb9..846d15cc8 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1,8 +1,10 @@ from __future__ import annotations -from autoarray.numpy_wrapper import np, use_jax +import numpy as np from pathlib import Path from typing import List, Optional, Tuple, Union +from autoarray.numpy_wrapper import use_jax + from autoconf import conf from autoconf import cached_property from autoconf.fitsable import ndarray_via_fits_from From b322a3faf36ea820d5f878cef30d542698778957 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:00:32 +0100 Subject: [PATCH 055/108] removed use of use_jax in one function --- autoarray/geometry/geometry_util.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 0477fdbc9..bc99dfbd7 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,4 +1,6 @@ +import jax.numpy as jnp from typing import Tuple, Union + from autoarray.numpy_wrapper import np, use_jax from autoarray import numba_util @@ -179,7 +181,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales - +@numba_util.jit() def central_pixel_coordinates_2d_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: @@ -203,7 +205,7 @@ def central_pixel_coordinates_2d_from( """ return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) - +@numba_util.jit() def central_scaled_coordinate_2d_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, @@ -379,18 +381,21 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - if use_jax: - shifted_grid_2d = grid_2d.array - np.array(centre) - else: - shifted_grid_2d = grid_2d - np.array(centre) - radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1)) - theta_coordinate_to_profile = np.arctan2( + # if use_jax: + # shifted_grid_2d = grid_2d.array - np.array(centre) + # else: + # shifted_grid_2d = grid_2d - np.array(centre) + + shifted_grid_2d = grid_2d.array - jnp.array(centre) + + radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) + theta_coordinate_to_profile = jnp.arctan2( shifted_grid_2d[:, 0], shifted_grid_2d[:, 1] - ) - np.radians(angle) - return np.vstack( + ) - jnp.radians(angle) + return jnp.vstack( [ - radius * np.sin(theta_coordinate_to_profile), - radius * np.cos(theta_coordinate_to_profile), + radius * jnp.sin(theta_coordinate_to_profile), + radius * jnp.cos(theta_coordinate_to_profile), ] ).T From 9e3c76c3b974d808972d4358505088798456ea3b Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:04:04 +0100 Subject: [PATCH 056/108] grid_pixels_2d_slim_from now uses native numpy, could support JAX --- autoarray/geometry/geometry_util.py | 93 +++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 25 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index bc99dfbd7..48a271341 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -182,7 +182,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales @numba_util.jit() -def central_pixel_coordinates_2d_from( +def central_pixel_coordinates_2d_numba_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: """ @@ -206,7 +206,7 @@ def central_pixel_coordinates_2d_from( return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) @numba_util.jit() -def central_scaled_coordinate_2d_from( +def central_scaled_coordinate_2d_numba_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, origin: Tuple[float, float] = (0.0, 0.0), @@ -234,7 +234,7 @@ def central_scaled_coordinate_2d_from( The central coordinates of the 2D data structure in scaled units. """ - central_pixel_coordinates = central_pixel_coordinates_2d_from( + central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( shape_native=shape_native ) @@ -244,6 +244,67 @@ def central_scaled_coordinate_2d_from( return (y_pixel, x_pixel) +def central_pixel_coordinates_2d_from( + shape_native: Tuple[int, int], +) -> Tuple[float, float]: + """ + Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) + from the shape of that data structure. + + Examples of the central pixels are as follows: + + - For a 3x3 image, the central pixel is pixel [1, 1]. + - For a 4x4 image, the central pixel is [1.5, 1.5]. + + Parameters + ---------- + shape_native + The dimensions of the data structure, which can be in 1D, 2D or higher dimensions. + + Returns + ------- + The central pixel coordinates of the data structure. + """ + return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) + + +def central_scaled_coordinate_2d_from( + shape_native: Tuple[int, int], + pixel_scales: ty.PixelScales, + origin: Tuple[float, float] = (0.0, 0.0), +) -> Tuple[float, float]: + """ + Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) + from the shape of that data structure. + + This is computed by using the data structure's shape and converting it to scaled units using an input + pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`. + + The origin of the scaled grid can also be input and moved from (0.0, 0.0). + + Parameters + ---------- + shape_native + The 2D shape of the data structure whose central scaled coordinates are computed. + pixel_scales + The (y,x) scaled units to pixel units conversion factor of the 2D data structure. + origin + The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on. + + Returns + ------- + The central coordinates of the 2D data structure in scaled units. + """ + + central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( + shape_native=shape_native + ) + + y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0]) + x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1]) + + return (y_pixel, x_pixel) + def pixel_coordinates_2d_from( scaled_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], @@ -437,7 +498,6 @@ def transform_grid_2d_from_reference_frame( return np.vstack((y, x)).T -@numba_util.jit() def grid_pixels_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -478,30 +538,13 @@ def grid_pixels_2d_slim_from( grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2), pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1, 1]) - return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - else: - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = ( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = ( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d_slim + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 @numba_util.jit() From ead617ee95c9fd7b54b0c4d2ed8b28f1af255388 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:05:13 +0100 Subject: [PATCH 057/108] grid_pixel_centres_2d_slim_from, could support JAX --- autoarray/geometry/geometry_util.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 48a271341..a39cba235 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -547,7 +547,6 @@ def grid_pixels_2d_slim_from( return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 -@numba_util.jit() def grid_pixel_centres_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -592,29 +591,13 @@ def grid_pixel_centres_2d_slim_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_pixels_2d_slim = ( - (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - ).astype(int) - else: - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = int( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = int( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + return ( + (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 + ).astype(int) - return grid_pixels_2d_slim @numba_util.jit() From 2769aaf1a77ddfe6834a2e15b93e49051ae947c8 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:06:40 +0100 Subject: [PATCH 058/108] grid_pixel_indexes_2d_slim_from, could support JAX --- autoarray/geometry/geometry_util.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index a39cba235..0d7cf82a9 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -600,7 +600,6 @@ def grid_pixel_centres_2d_slim_from( -@numba_util.jit() def grid_pixel_indexes_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -653,22 +652,12 @@ def grid_pixel_indexes_2d_slim_from( origin=origin, ) - if use_jax: - grid_pixel_indexes_2d_slim = ( - (grid_pixels_2d_slim * np.array([shape_native[1], 1])) - .sum(axis=1) - .astype(int) - ) - else: - grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) - - for slim_index in range(grid_pixels_2d_slim.shape[0]): - grid_pixel_indexes_2d_slim[slim_index] = int( - grid_pixels_2d_slim[slim_index, 0] * shape_native[1] - + grid_pixels_2d_slim[slim_index, 1] - ) + return ( + (grid_pixels_2d_slim * np.array([shape_native[1], 1])) + .sum(axis=1) + .astype(int) + ) - return grid_pixel_indexes_2d_slim @numba_util.jit() From b2ba6bd45017bc6c8325a96209c0439180925836 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:07:02 +0100 Subject: [PATCH 059/108] grid_scaled_2d_slim_from, could support JAX --- autoarray/geometry/geometry_util.py | 31 ++++++++--------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 0d7cf82a9..aab4c94aa 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -599,7 +599,6 @@ def grid_pixel_centres_2d_slim_from( ).astype(int) - def grid_pixel_indexes_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -660,7 +659,6 @@ def grid_pixel_indexes_2d_slim_from( -@numba_util.jit() def grid_scaled_2d_slim_from( grid_pixels_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -699,30 +697,17 @@ def grid_scaled_2d_slim_from( grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2), pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1, 1]) - grid_scaled_2d_slim = ( - (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign - ) - else: - grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_scaled_2d_slim[slim_index, 0] = ( - -(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5) - * pixel_scales[0] - ) - grid_scaled_2d_slim[slim_index, 1] = ( - grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5 - ) * pixel_scales[1] - - return grid_scaled_2d_slim + + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return ( + (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign + ) + @numba_util.jit() From 05321044405cacfa803d033b45201df15c672c51 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:07:48 +0100 Subject: [PATCH 060/108] grid_pixel_centres_2d_from, could support JAX --- autoarray/geometry/geometry_util.py | 32 ++++++----------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index aab4c94aa..5dca66c87 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -709,8 +709,6 @@ def grid_scaled_2d_slim_from( ) - -@numba_util.jit() def grid_pixel_centres_2d_from( grid_scaled_2d: np.ndarray, shape_native: Tuple[int, int], @@ -755,30 +753,12 @@ def grid_pixel_centres_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_pixels_2d = ( - (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 - ).astype(int) - else: - grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) - - for y in range(grid_scaled_2d.shape[0]): - for x in range(grid_scaled_2d.shape[1]): - grid_pixels_2d[y, x, 0] = int( - (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d[y, x, 1] = int( - (grid_scaled_2d[y, x, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + return ( + (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 + ).astype(int) def extent_symmetric_from( From d90ff2ebb130e86b8cd9b97c3868fc95562cb06c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:08:10 +0100 Subject: [PATCH 061/108] explciit separate imports --- autoarray/geometry/geometry_util.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 5dca66c87..5da893fac 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,7 +1,7 @@ import jax.numpy as jnp +import numpy as np from typing import Tuple, Union -from autoarray.numpy_wrapper import np, use_jax from autoarray import numba_util from autoarray import type as ty @@ -442,11 +442,6 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - # if use_jax: - # shifted_grid_2d = grid_2d.array - np.array(centre) - # else: - # shifted_grid_2d = grid_2d - np.array(centre) - shifted_grid_2d = grid_2d.array - jnp.array(centre) radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) From 59b21e999aa1879d1ff0fe8818f6d195027d8671 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:14:36 +0100 Subject: [PATCH 062/108] fix unit test in test__transform_2d_grid_from_reference_frame --- autoarray/geometry/geometry_util.py | 3 +-- test_autoarray/geometry/test_geometry_util.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 5da893fac..405262f96 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -442,7 +442,7 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - shifted_grid_2d = grid_2d.array - jnp.array(centre) + shifted_grid_2d = np.array(grid_2d) - jnp.array(centre) radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) theta_coordinate_to_profile = jnp.arctan2( @@ -653,7 +653,6 @@ def grid_pixel_indexes_2d_slim_from( ) - def grid_scaled_2d_slim_from( grid_pixels_2d_slim: np.ndarray, shape_native: Tuple[int, int], diff --git a/test_autoarray/geometry/test_geometry_util.py b/test_autoarray/geometry/test_geometry_util.py index b6b8b38d6..2fa8ce63c 100644 --- a/test_autoarray/geometry/test_geometry_util.py +++ b/test_autoarray/geometry/test_geometry_util.py @@ -1071,7 +1071,8 @@ def test__transform_2d_grid_from_reference_frame(): grid_2d=transformed_grid_2d, centre=(8.0, 5.0), angle=137.0 ) - assert grid_2d == pytest.approx(original_grid_2d, 1.0e-4) + + assert (np.abs(grid_2d - original_grid_2d) < 1e-4).all() def test__grid_pixels_2d_slim_from(): From c453a3c1951a0e3e3c9416e6f041c16a38bd61ac Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:17:43 +0100 Subject: [PATCH 063/108] use absolute tolerance to fix geomtry util unit tests --- test_autoarray/geometry/test_geometry_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test_autoarray/geometry/test_geometry_util.py b/test_autoarray/geometry/test_geometry_util.py index 2fa8ce63c..4bc2706dc 100644 --- a/test_autoarray/geometry/test_geometry_util.py +++ b/test_autoarray/geometry/test_geometry_util.py @@ -980,7 +980,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]) + np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -994,7 +994,7 @@ def test__transform_2d_grid_to_reference_frame(): [0.0, np.sqrt(2)], [np.sqrt(2) / 2.0, np.sqrt(2) / 2.0], ] - ) + ), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1002,7 +1002,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[-1.0, 0.0], [-1.0, 1.0], [0.0, 1.0]]) + np.array([[-1.0, 0.0], [-1.0, 1.0], [0.0, 1.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1010,7 +1010,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]) + np.array([[0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1072,7 +1072,7 @@ def test__transform_2d_grid_from_reference_frame(): ) - assert (np.abs(grid_2d - original_grid_2d) < 1e-4).all() + assert grid_2d == pytest.approx(original_grid_2d, abs=1.0e-4) def test__grid_pixels_2d_slim_from(): From 0c4bb3094f4398fe1c69d606f1961d01e0f4c5a0 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:19:09 +0100 Subject: [PATCH 064/108] fix test__pixel_coordinates_2d_from --- autoarray/geometry/geometry_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 405262f96..b646c7d08 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -412,7 +412,7 @@ def scaled_coordinates_2d_from( origin=(0.0, 0.0) ) """ - central_scaled_coordinates = central_scaled_coordinate_2d_from( + central_scaled_coordinates = central_scaled_coordinate_2d_numba_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origins ) From d891947e6040d52f3057836d8f62e6557e95e0f6 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:26:06 +0100 Subject: [PATCH 065/108] cleaned up jax imports of array_2d_util to make more tests pass --- autoarray/structures/arrays/array_2d_util.py | 103 +++++-------------- 1 file changed, 28 insertions(+), 75 deletions(-) diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index c75cc9750..796c5d965 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -1,5 +1,7 @@ from __future__ import annotations - +import jax +import jax.numpy as jnp +import numpy as np from typing import TYPE_CHECKING, List, Tuple, Union if TYPE_CHECKING: @@ -9,11 +11,8 @@ from autoarray.mask import mask_2d_util from autoarray import exc -from autoarray.numpy_wrapper import use_jax, np, jit from functools import partial -if use_jax: - import jax def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: @@ -25,13 +24,9 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - if use_jax: - array = jax.lax.cond( + array = jax.lax.cond( type(array) is list, lambda _: np.asarray(array), lambda _: array, None ) - elif type(array) is list: - array = np.asarray(array) - return array @@ -42,13 +37,9 @@ def exception_message(): ) cond = len(array_2d.shape) != 1 - if use_jax: - jax.lax.cond( - cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None - ) - elif cond: - exception_message() - + jax.lax.cond( + cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None + ) def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ @@ -88,15 +79,12 @@ def exception_message_1(): array_2d.shape[0] != mask_2d.pixels_in_mask ) - if use_jax: - jax.lax.cond( - cond_1, - lambda _: jax.debug.callback(exception_message_1), - lambda _: None, - None, - ) - elif cond_1: - exception_message_1() + jax.lax.cond( + cond_1, + lambda _: jax.debug.callback(exception_message_1), + lambda _: None, + None, + ) def exception_message_2(): raise exc.ArrayException( @@ -114,16 +102,12 @@ def exception_message_2(): cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) - if use_jax: - jax.lax.cond( - cond_2, - lambda _: jax.debug.callback(exception_message_2), - lambda _: None, - None, - ) - elif cond_2: - exception_message_2() - + jax.lax.cond( + cond_2, + lambda _: jax.debug.callback(exception_message_2), + lambda _: None, + None, + ) def convert_array_2d( array_2d: Union[np.ndarray, List], @@ -159,8 +143,7 @@ def convert_array_2d( check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 - if use_jax: - mask_2d = mask_2d.array + mask_2d = np.array(mask_2d) if is_native and not skip_mask: array_2d *= np.invert(mask_2d) @@ -526,8 +509,6 @@ def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.nda return index_slim_for_index_native_2d - -@numba_util.jit() def array_2d_slim_from( array_2d_native: np.ndarray, mask_2d: np.ndarray, @@ -571,23 +552,7 @@ def array_2d_slim_from( array_2d_slim = array_2d_slim_from(mask=mask, array_2d=array_2d) """ - - if use_jax: - array_2d_slim = array_2d_native[~mask_2d.astype(bool)] - else: - total_pixels = np.sum(~mask_2d) - - array_2d_slim = np.zeros(shape=total_pixels) - index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - array_2d_slim[index] = array_2d_native[y, x] - index += 1 - - return array_2d_slim - + return array_2d_native[~mask_2d.astype(bool)] def array_2d_native_from( array_2d_slim: np.ndarray, @@ -641,8 +606,7 @@ def array_2d_native_from( ) -@partial(jit, static_argnums=(1,)) -@numba_util.jit() +@partial(jax.jit, static_argnums=(1,)) def array_2d_via_indexes_from( array_2d_slim: np.ndarray, shape: Tuple[int, int], @@ -675,22 +639,11 @@ def array_2d_via_indexes_from( ndarray The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ - if use_jax: - array_native_2d = ( - np.zeros(shape) - .at[tuple(native_index_for_slim_index_2d.T)] - .set(array_2d_slim) - ) - else: - array_native_2d = np.zeros(shape) - - for slim_index in range(len(native_index_for_slim_index_2d)): - array_native_2d[ - native_index_for_slim_index_2d[slim_index, 0], - native_index_for_slim_index_2d[slim_index, 1], - ] = array_2d_slim[slim_index] - - return array_native_2d + return ( + jnp.zeros(shape) + .at[tuple(native_index_for_slim_index_2d.T)] + .set(array_2d_slim) + ) @numba_util.jit() @@ -725,7 +678,7 @@ def array_2d_slim_complex_from( A 1D array of values mapped from the 2D array with dimensions (total_unmasked_pixels). """ - total_pixels = np.sum(~mask_2d) + total_pixels = np.sum(~mask) array_1d = 0 + 0j * np.zeros(shape=total_pixels) index = 0 From ea7aa9d7a6402c7c2dc3eff460b3ea6c335295fc Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:30:04 +0100 Subject: [PATCH 066/108] cleanup imports of grid_2d_util --- autoarray/structures/grids/grid_2d_util.py | 62 +++++++--------------- 1 file changed, 20 insertions(+), 42 deletions(-) diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 522bb0525..574c8898d 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,11 +1,7 @@ from __future__ import annotations - +import numpy as np import jax.numpy as jnp - -from autoarray.numpy_wrapper import np, use_jax - -if use_jax: - import jax +import jax from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -16,7 +12,6 @@ from autoarray.structures.arrays import array_2d_util from autoarray.geometry import geometry_util from autoarray import numba_util -from autoarray.mask import mask_2d_util from autoarray import type as ty @@ -73,15 +68,12 @@ def exception_message(): """ ) - if use_jax: - jax.lax.cond( - grid_2d.shape[0] != mask_2d.pixels_in_mask, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif grid_2d.shape[0] != mask_2d.pixels_in_mask: - exception_message() + jax.lax.cond( + grid_2d.shape[0] != mask_2d.pixels_in_mask, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None, + ) elif len(grid_2d.shape) == 3: @@ -96,15 +88,12 @@ def exception_message(): """ ) - if use_jax: - jax.lax.cond( - (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: - exception_message() + jax.lax.cond( + (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None, + ) def convert_grid_2d( @@ -140,12 +129,8 @@ def convert_grid_2d( is_native = len(grid_2d.shape) == 3 if is_native: - if use_jax: - grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) - grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) - else: - grid_2d[:, :, 0] *= np.invert(mask_2d) - grid_2d[:, :, 1] *= np.invert(mask_2d) + grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) + grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) if is_native == store_native: return grid_2d @@ -154,17 +139,10 @@ def convert_grid_2d( grid_2d_native=np.array(grid_2d), mask=np.array(mask_2d), ) - if use_jax: - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d.array), - mask_2d=np.array(mask_2d), - ) - else: - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d), - mask_2d=np.array(mask_2d), - ) - + return grid_2d_native_from( + grid_2d_slim=np.array(grid_2d.array), + mask_2d=np.array(mask_2d), + ) def convert_grid_2d_to_slim( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D From 4014d0378048796028a15fd99fcf2d12e2325a4d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:38:26 +0100 Subject: [PATCH 067/108] convert methods in grid_2d_util assume ndarray --- .../operators/over_sampling/over_sampler.py | 4 +-- autoarray/structures/grids/grid_2d_util.py | 29 +++++-------------- autoarray/structures/grids/uniform_2d.py | 4 +-- 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6492c00b7..aa80392b4 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,11 +147,11 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): ) def tree_flatten(self): - return (self.mask,), () + return (self.mask, self.sub_size), () @classmethod def tree_unflatten(cls, aux_data, children): - return cls(mask=children[0]) + return cls(mask=children[0], sub_size=children[1]) @property def sub_total(self): diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 574c8898d..6c72c00ef 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -55,29 +55,20 @@ def check_grid_2d(grid_2d: np.ndarray): def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): if len(grid_2d.shape) == 2: - - def exception_message(): + if grid_2d.shape[0] != mask_2d.pixels_in_mask: raise exc.GridException( f""" The input 2D grid does not have the same number of values as pixels in the mask. - + The shape of the input grid_2d is {grid_2d.shape}. The mask shape_native is {mask_2d.shape_native}. The mask number of pixels is {mask_2d.pixels_in_mask}. """ ) - jax.lax.cond( - grid_2d.shape[0] != mask_2d.pixels_in_mask, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif len(grid_2d.shape) == 3: - - def exception_message(): + if (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: raise exc.GridException( f""" The input 2D grid is not the same dimensions as the mask @@ -88,13 +79,6 @@ def exception_message(): """ ) - jax.lax.cond( - (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - def convert_grid_2d( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D, store_native: bool = False @@ -129,8 +113,8 @@ def convert_grid_2d( is_native = len(grid_2d.shape) == 3 if is_native: - grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) - grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) + grid_2d[:, :, 0] *= np.invert(mask_2d) + grid_2d[:, :, 1] *= np.invert(mask_2d) if is_native == store_native: return grid_2d @@ -140,10 +124,11 @@ def convert_grid_2d( mask=np.array(mask_2d), ) return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d.array), + grid_2d_slim=np.array(grid_2d), mask_2d=np.array(mask_2d), ) + def convert_grid_2d_to_slim( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D ) -> np.ndarray: diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 846d15cc8..f0d20042c 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1,4 +1,5 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np from pathlib import Path from typing import List, Optional, Tuple, Union @@ -14,7 +15,6 @@ from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.structures.arrays import array_2d_util from autoarray.structures.grids import grid_2d_util from autoarray.geometry import geometry_util from autoarray.operators.over_sampling import over_sample_util @@ -162,7 +162,7 @@ def __init__( over sampled grid is not passed in it is computed assuming uniformity. """ values = grid_2d_util.convert_grid_2d( - grid_2d=values, + grid_2d=np.array(values), mask_2d=mask, store_native=store_native, ) From 075654fc7fe33d3dcc0b3db0acc595239bb35dec Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:42:57 +0100 Subject: [PATCH 068/108] more simlpifying of convert functions --- autoarray/structures/arrays/array_2d_util.py | 83 +++++++------------- autoarray/structures/arrays/uniform_2d.py | 7 +- 2 files changed, 30 insertions(+), 60 deletions(-) diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index 796c5d965..f269d0106 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -15,6 +15,7 @@ + def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: """ If the input array input a convert is of type list, convert it to type NumPy array. @@ -24,23 +25,17 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - array = jax.lax.cond( - type(array) is list, lambda _: np.asarray(array), lambda _: array, None - ) + array = np.asarray(array) + return array def check_array_2d(array_2d: np.ndarray): - def exception_message(): + if len(array_2d.shape) != 1: raise exc.ArrayException( "An array input into the Array2D.__new__ method is not of shape 1." ) - cond = len(array_2d.shape) != 1 - jax.lax.cond( - cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None - ) - def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an @@ -57,57 +52,38 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): mask_2d The mask of the output Array2D. """ + if len(array_2d.shape) == 1: + if array_2d.shape[0] != mask_2d.pixels_in_mask: + raise exc.ArrayException( + f""" + The input array is a slim 1D array, but it does not have the same number of entries as pixels in + the mask. - def exception_message_1(): - raise exc.ArrayException( - f""" - The input array is a slim 1D array, but it does not have the same number of entries as pixels in - the mask. - - This indicates that the number of unmaksed pixels in the mask is different to the input slim array - shape. - - The shapes of the two arrays (which this exception is raised because they are different) are as follows: - - Input array_2d_slim.shape = {array_2d.shape[0]} - Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} - Input mask_2d.shape_native = {mask_2d.shape_native} - """ - ) - - cond_1 = (len(array_2d.shape) == 1) and ( - array_2d.shape[0] != mask_2d.pixels_in_mask - ) - - jax.lax.cond( - cond_1, - lambda _: jax.debug.callback(exception_message_1), - lambda _: None, - None, - ) + This indicates that the number of unmaksed pixels in the mask is different to the input slim array + shape. - def exception_message_2(): - raise exc.ArrayException( - f""" - The input array is 2D but not the same dimensions as the mask. + The shapes of the two arrays (which this exception is raised because they are different) are as follows: - This indicates the mask's shape is different to the input array shape. + Input array_2d_slim.shape = {array_2d.shape[0]} + Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} + Input mask_2d.shape_native = {mask_2d.shape_native} + """ + ) - The shapes of the two arrays (which this exception is raised because they are different) are as follows: + if len(array_2d.shape) == 2: + if array_2d.shape != mask_2d.shape_native: + raise exc.ArrayException( + f""" + The input array is 2D but not the same dimensions as the mask. - Input array_2d shape = {array_2d.shape} - Input mask_2d shape_native = {mask_2d.shape_native} - """ - ) + This indicates the mask's shape is different to the input array shape. - cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) + The shapes of the two arrays (which this exception is raised because they are different) are as follows: - jax.lax.cond( - cond_2, - lambda _: jax.debug.callback(exception_message_2), - lambda _: None, - None, - ) + Input array_2d shape = {array_2d.shape} + Input mask_2d shape_native = {mask_2d.shape_native} + """ + ) def convert_array_2d( array_2d: Union[np.ndarray, List], @@ -143,7 +119,6 @@ def convert_array_2d( check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 - mask_2d = np.array(mask_2d) if is_native and not skip_mask: array_2d *= np.invert(mask_2d) diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 4406db105..9c01a8a1e 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -229,16 +229,11 @@ def __init__( print(array_2d.native) # masked 2D data representation. """ - try: - values = values._array - except AttributeError: - pass - if conf.instance["general"]["structures"]["native_binned_only"]: store_native = True values = array_2d_util.convert_array_2d( - array_2d=values, + array_2d=np.array(values), mask_2d=mask, store_native=store_native, skip_mask=skip_mask, From 17817b81c843ea8949bd1e49fafa054821c9eef8 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:52:55 +0100 Subject: [PATCH 069/108] mask derive fixed --- autoarray/mask/derive/indexes_2d.py | 4 ++-- autoarray/structures/arrays/array_2d_util.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index d8cf61a6e..4807799d9 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -198,7 +198,7 @@ def edge_slim(self) -> np.ndarray: print(derive_indexes_2d.edge_slim) """ - return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask)).astype( + return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask).astype("bool")).astype( "int" ) @@ -301,7 +301,7 @@ def border_slim(self) -> np.ndarray: print(derive_indexes_2d.border_slim) """ return mask_2d_util.border_slim_indexes_from( - mask_2d=np.array(self.mask) + mask_2d=np.array(self.mask).astype("bool") ).astype("int") @property diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index f269d0106..7048c07a0 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -14,8 +14,6 @@ from functools import partial - - def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: """ If the input array input a convert is of type list, convert it to type NumPy array. From b76cc9aac5e78c2d88b218e4a9b49b0c499981d7 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:59:50 +0100 Subject: [PATCH 070/108] another way to make hecks only use ndarray --- autoarray/structures/arrays/uniform_2d.py | 7 ++++++- .../arrays/files/array/output_test/array.fits | Bin 5760 -> 5760 bytes 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 9c01a8a1e..12ed86b0d 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -232,8 +232,13 @@ def __init__( if conf.instance["general"]["structures"]["native_binned_only"]: store_native = True + try: + values = values._array + except AttributeError: + values = values + values = array_2d_util.convert_array_2d( - array_2d=np.array(values), + array_2d=values, mask_2d=mask, store_native=store_native, skip_mask=skip_mask, diff --git a/test_autoarray/structures/arrays/files/array/output_test/array.fits b/test_autoarray/structures/arrays/files/array/output_test/array.fits index 0cff0a8f7db90d3ded28c644cd66e3791627feeb..dab9da2fb676efc023e53b73faa54fe78fc4000c 100644 GIT binary patch delta 56 kcmZqBZP1;N!(?o$Vgmz%Jzl(dBKLc)$qTsk0j$6e2mk;8 delta 84 hcmZqBZP1;N!(?W%G4C>$;|B&XuqT_|T*&>83jh#~5uX46 From c9e275ddd8ad0887bab644366d47e9343338ed95 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:23:29 +0100 Subject: [PATCH 071/108] fixes which ensure grad works on real LH function --- autoarray/fit/fit_util.py | 6 +++--- .../operators/over_sampling/over_sample_util.py | 2 +- autoarray/operators/over_sampling/over_sampler.py | 12 +++++++----- autoarray/structures/arrays/array_2d_util.py | 11 +++++++++-- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index e5d34675b..d40f55d1c 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -1,6 +1,6 @@ from functools import wraps +import jax.numpy as np -from autoarray.numpy_wrapper import np from autoarray.mask.abstract_mask import Mask from autoarray import type as ty @@ -83,7 +83,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return np.sum(chi_squared_map) + return np.sum(chi_squared_map._array) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: @@ -97,7 +97,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return np.sum(np.log(2 * np.pi * noise_map**2.0)) + return np.sum(np.log(2 * np.pi * noise_map._array**2.0)) def normalized_residual_map_complex_from( diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index 70dc220c4..8966e785e 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -422,7 +422,7 @@ def grid_2d_slim_over_sampled_via_mask_from( grid_slim = np.zeros(shape=(total_sub_pixels, 2)) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( + centres_scaled = geometry_util.central_scaled_coordinate_2d_numba_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index aa80392b4..6440a2ee7 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -220,11 +220,13 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": except AttributeError: pass - binned_array_2d = over_sample_util.binned_array_2d_from( - array_2d=np.array(array), - mask_2d=np.array(self.mask), - sub_size=np.array(self.sub_size).astype("int"), - ) + # binned_array_2d = over_sample_util.binned_array_2d_from( + # array_2d=np.array(array), + # mask_2d=np.array(self.mask), + # sub_size=np.array(self.sub_size).astype("int"), + # ) + + binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]).mean(axis=1) return Array2D( values=binned_array_2d, diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index 7048c07a0..0e694846d 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -23,8 +23,15 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - array = np.asarray(array) - + if isinstance(array, np.ndarray) or isinstance(array, list): + array = np.asarray(array) + elif isinstance(array, jnp.ndarray): + array = jax.lax.cond( + type(array) is list, + lambda _: jnp.asarray(array), + lambda _: array, + None + ) return array From 70c021242f51da937340ac221a1618474ee492d2 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:32:26 +0100 Subject: [PATCH 072/108] fix all uniform_2d unit tests --- autoarray/geometry/geometry_2d.py | 3 ++- autoarray/structures/arrays/uniform_2d.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index 29c604405..e78f0f75a 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -184,8 +184,9 @@ def scaled_coordinates_2d_from( ------- A 2D (y,x) pixel-value coordinate. """ + return geometry_util.scaled_coordinates_2d_from( - pixel_coordinates_2d=pixel_coordinates_2d, + pixel_coordinates_2d=np.array(pixel_coordinates_2d), shape_native=self.shape_native, pixel_scales=self.pixel_scales, origins=self.origin, diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 12ed86b0d..11c478ad5 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -464,7 +464,7 @@ def zoomed_around_mask(self, buffer: int = 1) -> "Array2D": """ extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native), + array_2d=np.array(self.native._array), y0=self.mask.zoom_region[0] - buffer, y1=self.mask.zoom_region[1] + buffer, x0=self.mask.zoom_region[2] - buffer, @@ -498,7 +498,7 @@ def extent_of_zoomed_array(self, buffer: int = 1) -> np.ndarray: The number pixels around the extracted array used as a buffer. """ extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native), + array_2d=np.array(self.native._array), y0=self.mask.zoom_region[0] - buffer, y1=self.mask.zoom_region[1] + buffer, x0=self.mask.zoom_region[2] - buffer, @@ -532,7 +532,7 @@ def resized_from( """ resized_array_2d = array_2d_util.resized_array_2d_from( - array_2d=np.array(self.native), resized_shape=new_shape + array_2d=np.array(self.native._array), resized_shape=new_shape ) resized_mask = self.mask.resized_from( @@ -599,7 +599,7 @@ def trimmed_after_convolution_from( resized_mask = self.mask.resized_from(new_shape=trimmed_array_2d.shape) array = array_2d_util.convert_array_2d( - array_2d=trimmed_array_2d, mask_2d=resized_mask + array_2d=trimmed_array_2d._array, mask_2d=resized_mask ) return Array2D( From c417511eb43fbda0925c968128a17d0b2fa235b2 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:36:29 +0100 Subject: [PATCH 073/108] fix all of kernel 2d --- autoarray/structures/arrays/kernel_2d.py | 6 +++--- test_autoarray/structures/arrays/test_kernel_2d.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index c03ceaf32..f80e3f6e3 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -383,7 +383,7 @@ def rescaled_with_odd_dimensions_from( try: kernel_rescaled = rescale( - self.native, + np.array(self.native._array), rescale_factor, anti_aliasing=False, mode="constant", @@ -391,7 +391,7 @@ def rescaled_with_odd_dimensions_from( ) except TypeError: kernel_rescaled = rescale( - self.native, + np.array(self.native._array), rescale_factor, anti_aliasing=False, mode="constant", @@ -472,7 +472,7 @@ def convolved_array_from(self, array: Array2D) -> Array2D: array_2d = array.native - convolved_array_2d = scipy.signal.convolve2d(array_2d, self.native, mode="same") + convolved_array_2d = scipy.signal.convolve2d(array_2d._array, np.array(self.native._array), mode="same") convolved_array_1d = array_2d_util.array_2d_slim_from( mask_2d=np.array(array_2d.mask), diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 292d628f7..9106b43a1 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -457,4 +457,4 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy normalize=True, ) - assert kernel_astropy == pytest.approx(kernel_2d.native, 1e-4) + assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) From db9cfb7aefe1bcd83afc24b31d249e5ef98df844 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:37:37 +0100 Subject: [PATCH 074/108] fix repr --- test_autoarray/structures/arrays/test_repr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_autoarray/structures/arrays/test_repr.py b/test_autoarray/structures/arrays/test_repr.py index cf0bfa964..8ea8c6b9d 100644 --- a/test_autoarray/structures/arrays/test_repr.py +++ b/test_autoarray/structures/arrays/test_repr.py @@ -3,4 +3,4 @@ def test_repr(): array = aa.Array2D.no_mask([[1, 2], [3, 4]], pixel_scales=1) - assert repr(array) == "Array2D([1., 2., 3., 4.])" + assert repr(array) == "Array2D([1, 2, 3, 4])" From 3cb3f7622d1030349825cba5d7dd8665f9e527f3 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:41:11 +0100 Subject: [PATCH 075/108] remove relocate_to_radial_minimum test as all functionality is to be removed --- .../structures/grids/test_uniform_2d.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index 4084908ba..d11457416 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -780,37 +780,6 @@ def test__grid_with_coordinates_within_distance_removed_from(): ).all() -def test__grid_radial_minimum(): - grid_2d = np.array([[2.5, 0.0], [4.0, 0.0], [6.0, 0.0]]) - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - assert (deflections == grid_2d).all() - - grid_2d = np.array([[2.0, 0.0], [1.0, 0.0], [6.0, 0.0]]) - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - - assert (deflections == np.array([[2.5, 0.0], [2.5, 0.0], [6.0, 0.0]])).all() - - grid_2d = np.array( - [ - [np.sqrt(2.0), np.sqrt(2.0)], - [1.0, np.sqrt(8.0)], - [np.sqrt(8.0), np.sqrt(8.0)], - ] - ) - - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - - assert deflections == pytest.approx( - np.array([[1.7677, 1.7677], [1.0, np.sqrt(8.0)], [np.sqrt(8), np.sqrt(8.0)]]), - 1.0e-4, - ) - def test__recursive_shape_storage(): grid_2d = aa.Grid2D.no_mask( From 467d1ea614eecafdb70d5741d1390ff32a9a124e Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:48:36 +0100 Subject: [PATCH 076/108] fix Grid2D test_unifrom --- autoarray/structures/grids/uniform_2d.py | 36 ++++++++++--------- .../structures/grids/test_uniform_2d.py | 6 ++-- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index f0d20042c..4d565dd27 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -4,8 +4,6 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -from autoarray.numpy_wrapper import use_jax - from autoconf import conf from autoconf import cached_property from autoconf.fitsable import ndarray_via_fits_from @@ -820,7 +818,7 @@ def grid_with_coordinates_within_distance_removed_from( distance_mask += distances.native < distance mask = Mask2D( - mask=distance_mask, + mask=np.array(distance_mask), pixel_scales=self.pixel_scales, origin=self.origin, ) @@ -841,8 +839,8 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ - if use_jax: - squared_distances = np.square(self.array[:, 0] - coordinate[0]) + np.square( + if isinstance(self, jnp.ndarray): + squared_distances = jnp.square(self.array[:, 0] - coordinate[0]) + jnp.square( self.array[:, 1] - coordinate[1] ) else: @@ -1015,10 +1013,16 @@ def shape_native_scaled_interior(self) -> Tuple[float, float]: of the grid's (y,x) values, whereas the `shape_native_scaled` uses the uniform geometry of the grid and its ``pixel_scales``, which means it has a buffer at each edge of half a ``pixel_scale``. """ - return ( - np.amax(self[:, 0]) - np.amin(self[:, 0]), - np.amax(self[:, 1]) - np.amin(self[:, 1]), - ) + if isinstance(self, jnp.ndarray): + return ( + np.amax(self.array[:, 0]) - np.amin(self.array[:, 0]), + np.amax(self.array[:, 1]) - np.amin(self.array[:, 1]), + ) + else: + return ( + np.amax(self[:, 0]) - np.amin(self[:, 0]), + np.amax(self[:, 1]) - np.amin(self[:, 1]), + ) @property def scaled_minima(self) -> Tuple: @@ -1026,10 +1030,10 @@ def scaled_minima(self) -> Tuple: The (y,x) minimum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - if use_jax: + if isinstance(self, jnp.ndarray): return ( - np.amin(self.array[:, 0]).astype("float"), - np.amin(self.array[:, 1]).astype("float"), + jnp.amin(self.array[:, 0]).astype("float"), + jnp.amin(self.array[:, 1]).astype("float"), ) else: return ( @@ -1043,10 +1047,10 @@ def scaled_maxima(self) -> Tuple: The (y,x) maximum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - if use_jax: + if isinstance(self, jnp.ndarray): return ( - np.amax(self.array[:, 0]).astype("float"), - np.amax(self.array[:, 1]).astype("float"), + jnp.amax(self.array[:, 0]).astype("float"), + jnp.amax(self.array[:, 1]).astype("float"), ) else: return ( @@ -1103,7 +1107,7 @@ def padded_grid_from(self, kernel_shape_native: Tuple[int, int]) -> "Grid2D": ) over_sample_size = np.pad( - self.over_sample_size.native, + self.over_sample_size.native._array, pad_width, mode="constant", constant_values=1, diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index d11457416..686721c70 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -481,9 +481,10 @@ def test__to_and_from_fits_methods(): def test__shape_native_scaled(): - mask = aa.Mask2D.circular(shape_native=(3, 3), radius=1.0, pixel_scales=(1.0, 1.0)) + mask = aa.Mask2D.circular(shape_native=(3, 3), radius=1.1, pixel_scales=(1.0, 1.0)) grid_2d = aa.Grid2D.from_mask(mask=mask) + assert grid_2d.shape_native_scaled_interior == (2.0, 2.0) mask = aa.Mask2D.elliptical( @@ -574,7 +575,8 @@ def test__grid_2d_radial_projected_shape_slim_from(): pixel_scales=grid_2d.pixel_scales, ) - assert (grid_radii == grid_radii_util).all() + + assert grid_radii == pytest.approx(grid_radii_util, 1.0e-4) assert grid_radial_shape_slim == grid_radii_util.shape[0] grid_radii = grid_2d.grid_2d_radial_projected_from(centre=(0.3, 0.1), angle=60.0) From 7751080c4d627c8cad7626868f1fa372a9110f22 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:52:35 +0100 Subject: [PATCH 077/108] fix grid test_uniform_1d --- autoarray/structures/grids/uniform_1d.py | 6 +----- test_autoarray/structures/grids/test_uniform_1d.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/autoarray/structures/grids/uniform_1d.py b/autoarray/structures/grids/uniform_1d.py index d5870aeca..53a9ec756 100644 --- a/autoarray/structures/grids/uniform_1d.py +++ b/autoarray/structures/grids/uniform_1d.py @@ -1,10 +1,6 @@ from __future__ import annotations import numpy as np -from typing import TYPE_CHECKING, List, Union, Tuple - -if TYPE_CHECKING: - from autoarray.structures.arrays.uniform_1d import Array1D - from autoarray.structures.grids.uniform_2d import Grid2D +from typing import List, Union, Tuple from autoarray.structures.abstract_structure import Structure from autoarray.structures.grids.irregular_2d import Grid2DIrregular diff --git a/test_autoarray/structures/grids/test_uniform_1d.py b/test_autoarray/structures/grids/test_uniform_1d.py index 147bd4b17..b99da3608 100644 --- a/test_autoarray/structures/grids/test_uniform_1d.py +++ b/test_autoarray/structures/grids/test_uniform_1d.py @@ -129,7 +129,7 @@ def test__grid_2d_radial_projected_from(): grid_2d = grid_1d.grid_2d_radial_projected_from(angle=90.0) assert grid_2d.slim == pytest.approx( - np.array([[-1.0, 0.0], [-2.0, 0.0], [-3.0, 0.0], [-4.0, 0.0]]), 1.0e-4 + np.array([[-1.0, 0.0], [-2.0, 0.0], [-3.0, 0.0], [-4.0, 0.0]]), abs=1.0e-4 ) grid_2d = grid_1d.grid_2d_radial_projected_from(angle=45.0) From f4c3269576995a8c2b042bb750a7a134ef062879 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:55:19 +0100 Subject: [PATCH 078/108] hammer hammer hammer --- autoarray/structures/grids/irregular_2d.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index 68ad0dce4..57d4c4a6f 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -1,15 +1,12 @@ import logging -from typing import List, Optional, Tuple, Union +import numpy as np +from typing import List, Tuple, Union -from autoarray.numpy_wrapper import np from autoarray.abstract_ndarray import AbstractNDArray from autoarray.geometry.geometry_2d_irregular import Geometry2DIrregular from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.irregular import ArrayIrregular -from autoarray.structures.grids import grid_2d_util -from autoarray.geometry import geometry_util - logger = logging.getLogger(__name__) From bbed38d73450fa24588d9d9822125d28e4dee5bf Mon Sep 17 00:00:00 2001 From: Richard Hayes Date: Wed, 2 Apr 2025 10:15:29 +0100 Subject: [PATCH 079/108] Update autoarray/plot/multi_plotters.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- autoarray/plot/multi_plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/plot/multi_plotters.py b/autoarray/plot/multi_plotters.py index a58d08c02..84926a416 100644 --- a/autoarray/plot/multi_plotters.py +++ b/autoarray/plot/multi_plotters.py @@ -301,7 +301,7 @@ def output_to_fits( The list of function names that are called to plot the figures on the subplot. figure_name_list The list of figure names that are plotted on the subplot. - filenane + filename The filename that the .fits file is output to. tag_list The list of tags that are used to set the `EXTNAME` of each hdu of the .fits file. From 8d2b338c18663a7908ff9d49843eb8aa7a8f087c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 15:57:22 +0100 Subject: [PATCH 080/108] fix over sampler test --- autoarray/operators/over_sampling/over_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6440a2ee7..0aafbe008 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -226,7 +226,7 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": # sub_size=np.array(self.sub_size).astype("int"), # ) - binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]).mean(axis=1) + binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]**2).mean(axis=1) return Array2D( values=binned_array_2d, From e2ddf0f2e7ad565e6eae90cbfee7a4634c868ae7 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:14:15 +0100 Subject: [PATCH 081/108] removal of convolver and switch to psf where required --- autoarray/__init__.py | 2 - autoarray/dataset/imaging/dataset.py | 34 +- autoarray/dataset/interferometer/dataset.py | 2 +- autoarray/dataset/preprocess.py | 2 +- autoarray/fixtures.py | 6 - .../inversion/inversion/dataset_interface.py | 6 +- autoarray/inversion/inversion/factory.py | 3 +- .../inversion/inversion/imaging/abstract.py | 15 +- .../inversion/inversion/imaging/mapping.py | 10 +- .../inversion/inversion/imaging/w_tilde.py | 9 +- autoarray/inversion/linear_obj/linear_obj.py | 2 +- .../inversion/mock/mock_inversion_imaging.py | 8 +- autoarray/mask/derive/mask_2d.py | 4 +- autoarray/mock.py | 1 - autoarray/operators/convolver.py | 592 --------- autoarray/operators/mock/mock_convolver.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 25 +- test_autoarray/conftest.py | 5 - .../dataset/imaging/test_dataset.py | 6 +- .../inversion/imaging/test_imaging.py | 20 +- .../imaging/test_inversion_imaging_util.py | 12 +- test_autoarray/operators/test_convolver.py | 1164 ----------------- .../structures/arrays/test_kernel_2d.py | 334 ++++- 23 files changed, 374 insertions(+), 1890 deletions(-) delete mode 100644 autoarray/operators/convolver.py delete mode 100644 test_autoarray/operators/test_convolver.py diff --git a/autoarray/__init__.py b/autoarray/__init__.py index a212002f5..1a6d79c53 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -54,8 +54,6 @@ from .mask.derive.grid_2d import DeriveGrid2D from .mask.mask_1d import Mask1D from .mask.mask_2d import Mask2D -from .operators.convolver import Convolver -from .operators.convolver import Convolver from .operators.transformer import TransformerDFT from .operators.transformer import TransformerNUFFT from .operators.over_sampling.decorator import over_sample diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 20bb8b0be..7fa3c515f 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -9,7 +9,6 @@ from autoarray.dataset.grids import GridsDataset from autoarray.dataset.imaging.w_tilde import WTildeImaging from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.operators.convolver import Convolver from autoarray.structures.arrays.kernel_2d import Kernel2D from autoarray.mask.mask_2d import Mask2D from autoarray import type as ty @@ -30,7 +29,7 @@ def __init__( noise_covariance_matrix: Optional[np.ndarray] = None, over_sample_size_lp: Union[int, Array2D] = 4, over_sample_size_pixelization: Union[int, Array2D] = 4, - pad_for_convolver: bool = False, + pad_for_psf: bool = False, use_normalized_psf: Optional[bool] = True, check_noise_map: bool = True, ): @@ -77,7 +76,7 @@ def __init__( over_sample_size_pixelization How over sampling is performed for the grid which is associated with a pixelization, which is therefore passed into the calculations performed in the `inversion` module. - pad_for_convolver + pad_for_psf The PSF convolution may extend beyond the edges of the image mask, which can lead to edge effects in the convolved image. If `True`, the image and noise-map are padded to ensure the PSF convolution does not extend beyond the edge of the image. @@ -90,9 +89,9 @@ def __init__( self.unmasked = None - self.pad_for_convolver = pad_for_convolver + self.pad_for_psf = pad_for_psf - if pad_for_convolver and psf is not None: + if pad_for_psf and psf is not None: try: data.mask.derive_mask.blurring_from( kernel_shape_native=psf.shape_native @@ -176,25 +175,6 @@ def grids(self): psf=self.psf, ) - @cached_property - def convolver(self): - """ - Returns a `Convolver` from a mask and 2D PSF kernel. - - The `Convolver` stores in memory the array indexing between the mask and PSF, enabling efficient 2D PSF - convolution of images and matrices used for linear algebra calculations (see `operators.convolver`). - - This uses lazy allocation such that the calculation is only performed when the convolver is used, ensuring - efficient set up of the `Imaging` class. - - Returns - ------- - Convolver - The convolver given the masked imaging data's mask and PSF. - """ - - return Convolver(mask=self.mask, kernel=Kernel2D(values=self.psf._array, mask=self.psf.mask, header=self.psf.header)) - @cached_property def w_tilde(self): """ @@ -370,7 +350,7 @@ def apply_mask(self, mask: Mask2D) -> "Imaging": noise_covariance_matrix=noise_covariance_matrix, over_sample_size_lp=over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization, - pad_for_convolver=True, + pad_for_psf=True, ) dataset.unmasked = unmasked_dataset @@ -463,7 +443,7 @@ def apply_noise_scaling( noise_covariance_matrix=self.noise_covariance_matrix, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, - pad_for_convolver=False, + pad_for_psf=False, check_noise_map=False, ) @@ -511,7 +491,7 @@ def apply_over_sampling( over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization or self.over_sample_size_pixelization, - pad_for_convolver=False, + pad_for_psf=False, check_noise_map=False, ) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 21dd600a6..06892ea68 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -276,5 +276,5 @@ def output_to_fits( ) @property - def convolver(self): + def psf(self): return None diff --git a/autoarray/dataset/preprocess.py b/autoarray/dataset/preprocess.py index 046ab3ea8..5c7338204 100644 --- a/autoarray/dataset/preprocess.py +++ b/autoarray/dataset/preprocess.py @@ -328,7 +328,7 @@ def background_noise_map_via_edges_from(image, no_edges): def psf_with_odd_dimensions_from(psf): """ If the PSF kernel has one or two even-sized dimensions, return a PSF object where the kernel has odd-sized - dimensions (odd-sized dimensions are required by a *Convolver*). + dimensions (odd-sized dimensions are required for 2D convolution). Kernels are rescaled using the scikit-image routine rescale, which performs rescaling via an interpolation routine. This may lead to loss of accuracy in the PSF kernel and it is advised that users, where possible, diff --git a/autoarray/fixtures.py b/autoarray/fixtures.py index acfd277df..68a8a73f8 100644 --- a/autoarray/fixtures.py +++ b/autoarray/fixtures.py @@ -110,12 +110,6 @@ def make_blurring_grid_2d_7x7(): return aa.Grid2D.from_mask(mask=make_blurring_mask_2d_7x7()) -# CONVOLVERS # - - -def make_convolver_7x7(): - return aa.Convolver(mask=make_mask_2d_7x7(), kernel=make_psf_3x3()) - def make_image_7x7(): return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0)) diff --git a/autoarray/inversion/inversion/dataset_interface.py b/autoarray/inversion/inversion/dataset_interface.py index abc780411..0e417e71a 100644 --- a/autoarray/inversion/inversion/dataset_interface.py +++ b/autoarray/inversion/inversion/dataset_interface.py @@ -4,7 +4,7 @@ def __init__( data, noise_map, grids=None, - convolver=None, + psf=None, transformer=None, w_tilde=None, noise_covariance_matrix=None, @@ -41,7 +41,7 @@ def __init__( border_relocator The border relocator, which relocates coordinates outside the border of the source-plane data grid to its edge. - convolver + psf Perform 2D convolution of the imaigng data's PSF when computing the operated mapping matrix. transformer Performs a Fourier transform of the image-data from real-space to visibilities when computing the @@ -59,7 +59,7 @@ def __init__( self.data = data self.noise_map = noise_map self.grids = grids - self.convolver = convolver + self.psf = psf self.transformer = transformer self.w_tilde = w_tilde self.noise_covariance_matrix = noise_covariance_matrix diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index 31c69d57b..327262786 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -172,8 +172,7 @@ def inversion_interferometer_from( dataset The dataset (e.g. `Interferometer`) whose data is reconstructed via the `Inversion`. w_tilde - Object which uses the `Imaging` dataset's PSF / `Convolver` operateor to perform the `Inversion` using the - w-tilde formalism. + Object which uses the `Imaging` dataset's PSF to perform the `Inversion` using the w-tilde formalism. linear_obj_list The list of linear objects (e.g. analytic functions, a mapper with a pixelized grid) which reconstruct the input dataset's data and whose values are solved for via the inversion. diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 1ea826f88..5efc4d0a9 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -32,8 +32,7 @@ def __init__( of the linear object parameters that best reconstruct the dataset to be solved, via linear matrix algebra. This object contains matrices and vectors which perform an inversion for fits to an `Imaging` dataset. This - includes operations which use a PSF / `Convolver` in order to incorporate blurring into the solved for - linear object pixels. + includes operations which use a PSF in order to incorporate blurring into the solved for linear object pixels. The inversion may be regularized, whereby the parameters of the linear objects used to reconstruct the data are smoothed with one another such that their solved for values conform to certain properties (e.g. smoothness @@ -76,8 +75,8 @@ def __init__( ) @property - def convolver(self): - return self.dataset.convolver + def psf(self): + return self.dataset.psf @property def operated_mapping_matrix_list(self) -> List[np.ndarray]: @@ -88,7 +87,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: This is used to construct the simultaneous linear equations which reconstruct the data. This property returns the a list of each linear object's blurred mapping matrix, which is computed by - blurring each linear object's `mapping_matrix` property with the `Convolver` operator. + blurring each linear object's `mapping_matrix` property with the `psf` operator. A linear object may have a `operated_mapping_matrix_override` property, which bypasses the `mapping_matrix` computation and convolution operator and is directly placed in the `operated_mapping_matrix_list`. @@ -96,7 +95,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: return [ ( - self.convolver.convolve_mapping_matrix( + self.psf.convolve_mapping_matrix( mapping_matrix=linear_obj.mapping_matrix ) if linear_obj.operated_mapping_matrix_override is None @@ -139,7 +138,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: if linear_func.operated_mapping_matrix_override is not None: operated_mapping_matrix = linear_func.operated_mapping_matrix_override else: - operated_mapping_matrix = self.convolver.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolve_mapping_matrix( mapping_matrix=linear_func.mapping_matrix ) @@ -221,7 +220,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: mapper_operated_mapping_matrix_dict = {} for mapper in self.cls_list_from(cls=AbstractMapper): - operated_mapping_matrix = self.convolver.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolve_mapping_matrix( mapping_matrix=mapper.mapping_matrix ) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 5b3e40966..03d73ff63 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -39,10 +39,8 @@ def __init__( Parameters ---------- - noise_map - The noise-map of the observed imaging data which values are solved for. - convolver - The convolver which performs a 2D convolution on the mapping matrix with the imaging data's PSF. + dataset + The dataset containing the image data, noise-map and psf which is fitted by the inversion. linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. @@ -80,7 +78,7 @@ def _data_vector_mapper(self) -> np.ndarray: mapper = mapper_list[i] param_range = mapper_param_range_list[i] - operated_mapping_matrix = self.convolver.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolve_mapping_matrix( mapping_matrix=mapper.mapping_matrix ) @@ -142,7 +140,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_i = mapper_list[i] mapper_param_range_i = mapper_param_range_list[i] - operated_mapping_matrix = self.convolver.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolve_mapping_matrix( mapping_matrix=mapper_i.mapping_matrix ) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 5f791873c..c249f4ce9 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -42,11 +42,8 @@ def __init__( Parameters ---------- - noise_map - The noise-map of the observed imaging data which values are solved for. - convolver - The convolver used to perform 2D convolution of the imaigng data's PSF when computing the operated - mapping matrix. + dataset + The dataset containing the image data, noise-map and psf which is fitted by the inversion. w_tilde An object containing matrices that construct the linear equations via the w-tilde formalism which bypasses the mapping matrix. @@ -76,7 +73,7 @@ def w_tilde_data(self): return inversion_imaging_util.w_tilde_data_imaging_from( image_native=np.array(self.data.native), noise_map_native=np.array(self.noise_map.native), - kernel_native=np.array(self.convolver.kernel.native), + kernel_native=np.array(self.psf.kernel.native), native_index_for_slim_index=self.data.mask.derive_indexes.native_for_slim, ) diff --git a/autoarray/inversion/linear_obj/linear_obj.py b/autoarray/inversion/linear_obj/linear_obj.py index 3bcc1ae57..1e0bacc85 100644 --- a/autoarray/inversion/linear_obj/linear_obj.py +++ b/autoarray/inversion/linear_obj/linear_obj.py @@ -124,7 +124,7 @@ def pixel_signals_from(self, signal_scale) -> np.ndarray: def operated_mapping_matrix_override(self) -> Optional[np.ndarray]: """ An `Inversion` takes the `mapping_matrix` of each linear object and combines it with the data's operators - (e.g. a `Convolver` for `Imaging` data) to compute the `operated_mapping_matrix`. + (e.g. a PSF for `Imaging` data) to compute the `operated_mapping_matrix`. If this property is overwritten this operation is not performed, with the `operated_mapping_matrix` output by this property automatically used instead. diff --git a/autoarray/inversion/mock/mock_inversion_imaging.py b/autoarray/inversion/mock/mock_inversion_imaging.py index 64076db56..de7f2baa1 100644 --- a/autoarray/inversion/mock/mock_inversion_imaging.py +++ b/autoarray/inversion/mock/mock_inversion_imaging.py @@ -12,7 +12,7 @@ def __init__( self, data=None, noise_map=None, - convolver=None, + psf=None, linear_obj_list=None, operated_mapping_matrix=None, linear_func_operated_mapping_matrix_dict=None, @@ -22,7 +22,7 @@ def __init__( dataset = DatasetInterface( data=data, noise_map=noise_map, - convolver=convolver, + psf=psf, ) super().__init__( @@ -70,7 +70,7 @@ def __init__( self, data=None, noise_map=None, - convolver=None, + psf=None, w_tilde=None, linear_obj_list=None, curvature_matrix_mapper_diag=None, @@ -79,7 +79,7 @@ def __init__( dataset = DatasetInterface( data=data, noise_map=noise_map, - convolver=convolver, + psf=psf, ) super().__init__( diff --git a/autoarray/mask/derive/mask_2d.py b/autoarray/mask/derive/mask_2d.py index 4d3090a52..9332e273d 100644 --- a/autoarray/mask/derive/mask_2d.py +++ b/autoarray/mask/derive/mask_2d.py @@ -146,8 +146,8 @@ def blurring_from(self, kernel_shape_native: Tuple[int, int]) -> Mask2D: Returns a blurring ``Mask2D``, representing all masked pixels (given by ``True``) whose values are blurred into unmasked pixels (given by ``False``) when a 2D convolution is performed. - This mask is used by the ``Convolver2D`` object to ensure that 2D convolution can be performed on masked - data structures without missing values. + This mask is used by the PSF to ensure that 2D convolution can be performed on masked data structures without + missing values. For example, for the following ``Mask2D``: diff --git a/autoarray/mock.py b/autoarray/mock.py index 78a857e7e..2261ba5e2 100644 --- a/autoarray/mock.py +++ b/autoarray/mock.py @@ -15,7 +15,6 @@ from autoarray.fit.mock.mock_fit_imaging import MockFitImaging from autoarray.fit.mock.mock_fit_interferometer import MockFitInterferometer from autoarray.mask.mock.mock_mask import MockMask -from autoarray.operators.mock.mock_convolver import MockConvolver from autoarray.structures.mock.mock_grid import MockGrid2DMesh from autoarray.structures.mock.mock_grid import MockMeshGrid from autoarray.structures.mock.mock_decorators import MockGridRadialMinimum diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py deleted file mode 100644 index 73d3767d5..000000000 --- a/autoarray/operators/convolver.py +++ /dev/null @@ -1,592 +0,0 @@ -from autoarray import numba_util -import numpy as np - -from autoarray.structures.arrays.uniform_2d import Array2D - -from autoarray import exc -from autoarray.mask import mask_2d_util - -from os import environ - -use_jax = environ.get("USE_JAX", "0") == "1" - -if use_jax: - import jax - import jax.numpy as jnp - - -class Convolver: - def __init__(self, mask, kernel): - """ - Class to setup the 1D convolution of an / mapping matrix. - - Take a simple 3x3 and masks: - - [[2, 8, 2], - [5, 7, 5], - [3, 1, 4]] - - [[True, False, True], (True means that the value is masked) - [False, False, False], - [True, False, True]] - - A set of values in a corresponding 1d array of this might be represented as: - - [2, 8, 2, 5, 7, 5, 3, 1, 4] - - and after masking as: - - [8, 5, 7, 5, 1] - - Setup is required to perform 2D real-space convolution on the masked array. This module finds the \ - relationship between the unmasked 2D data, masked data and kernel, so that 2D real-space convolutions \ - can be efficiently applied to reduced 1D masked structures. - - This calculation also accounts for the blurring of light outside of the masked regions which blurs into \ - the masked region. - - - **IMAGE FRAMES:** - - For a masked in 2D, one can compute for every pixel all of the unmasked pixels it will blur light into for \ - a given PSF kernel size, e.g.: - - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI This is an imaging.Mask2D, where: - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI x = `True` (Pixel is masked and excluded from lens) - IxIxIxIoIoIoIxIxIxIxI o = `False` (Pixel is not masked and included in lens) - IxIxIxIoIoIoIxIxIxIxI - IxIxIxIoIoIoIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - - Here, there are 9 unmasked pixels. Indexing of each unmasked pixel goes from the top-left corner right and \ - downwards, therefore: - - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxI0I1I2IxIxIxIxI - IxIxIxI3I4I5IxIxIxIxI - IxIxIxI6I7I8IxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - - For every unmasked pixel, the Convolver over-lays the PSF and computes three quantities; - - image_frame_indexes - The indexes of all masked pixels it will blur light into. - image_frame_kernels - The kernel values that overlap each masked pixel it will blur light into. - image_frame_length - The number of masked pixels it will blur light into (unmasked pixels are excluded) - - For example, if we had the following 3x3 kernel: - - I0.1I0.2I0.3I - I0.4I0.5I0.6I - I0.7I0.8I0.9I - - For pixel 0 above, when we overlap the kernel 4 unmasked pixels overlap this kernel, such that: - - image_frame_indexes = [0, 1, 3, 4] - image_frame_kernels = [0.5, 0.6, 0.8, 0.9] - image_frame_length = 4 - - Noting that the other 5 kernel values (0.1, 0.2, 0.3, 0.4, 0.7) overlap masked pixels and are thus discarded. - - For pixel 1, we get the following results: - - image_frame_indexes = [0, 1, 2, 3, 4, 5] - image_frame_kernels = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9] - image_frame_lengths = 6 - - In the majority of cases, the kernel will overlap only unmasked pixels. This is the case above for \ - central pixel 4, where: - - image_frame_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 8] - image_frame_kernels = [0,1, 0.2, 0,3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] - image_frame_lengths = 9 - - Once we have set up all these quantities, the convolution routine simply uses them to convolve a 1D array of a - masked or the masked of a util in the inversion module. - - - **BLURRING FRAMES:** - - Whilst the scheme above accounts for all blurred light within the masks, it does not account for the fact that \ - pixels outside of the masks will also blur light into it. This effect is accounted for using blurring frames. - - It is omitted for mapping matrix blurring, as an inversion does not fit data outside of the masks. - - First, a blurring masks is computed from a masks, which describes all pixels which are close enough to the masks \ - to blur light into it for a given kernel size. Following the example above, the following blurring masks is \ - computed: - - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI This is an example grid.Mask2D, where: - IxIxIxIxIxIxIxIxIxIxI - IxIxIoIoIoIoIoIxIxIxI x = `True` (Pixel is masked and excluded from lens) - IxIxIoIxIxIxIoIxIxIxI o = `False` (Pixel is not masked and included in lens) - IxIxIoIxIxIxIoIxIxIxI - IxIxIoIxIxIxIoIxIxIxI - IxIxIoIoIoIoIoIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - - Indexing again goes from the top-left corner right and downwards: - - IxIxI xI xI xI xI xIxIxIxI - IxIxI xI xI xI xI xIxIxIxI - IxIxI xI xI xI xI xIxIxIxI - IxIxI 0I 1I 2I 3I 4IxIxIxI - IxIxI 5I xI xI xI 6IxIxIxI - IxIxI 7I xI xI xI 8IxIxIxI - IxIxI 9I xI xI xI10IxIxIxI - IxIxI11I12I13I14I15IxIxIxI - IxIxI xI xI xI xI xIxIxIxI - IxIxI xI xI xI xI xIxIxIxI - - For every unmasked blurring-pixel, the Convolver over-lays the PSF kernel and computes three quantities; - - blurring_frame_indexes - The indexes of all unmasked pixels (not unmasked blurring pixels) it will \ - blur light into. - bluring_frame_kernels - The kernel values that overlap each pixel it will blur light into. - blurring_frame_length - The number of pixels it will blur light into. - - The blurring frame therefore does not perform any blurring which blurs light into other blurring pixels. \ - It only performs computations which add light inside of the masks. - - For pixel 0 above, when we overlap the 3x3 kernel above only 1 unmasked pixels overlaps the kernel, such that: - - blurring_frame_indexes = [0] (This 0 refers to pixel 0 within the masks, not blurring_frame_pixel 0) - blurring_frame_kernels = [0.9] - blurring_frame_length = 1 - - For pixel 1 above, when we overlap the 3x3 kernel above 2 unmasked pixels overlap the kernel, such that: - - blurring_frame_indexes = [0, 1] (This 0 and 1 refer to pixels 0 and 1 within the masks) - blurring_frame_kernels = [0.8, 0.9] - blurring_frame_length = 2 - - For pixel 3 above, when we overlap the 3x3 kernel above 3 unmasked pixels overlap the kernel, such that: - - blurring_frame_indexes = [0, 1, 2] (Again, these are pixels 0, 1 and 2) - blurring_frame_kernels = [0.7, 0.8, 0.9] - blurring_frame_length = 3 - - Parameters - ---------- - mask - The mask within which the convolved signal is calculated. - blurring_mask - A masks of pixels outside the masks but whose light blurs into it after PSF convolution. - kernel : grid.PSF or ndarray - An array representing a PSF. - """ - if kernel.shape_native[0] % 2 == 0 or kernel.shape_native[1] % 2 == 0: - raise exc.KernelException("PSF kernel must be odd") - - self.mask = mask - - self.mask_index_array = np.full(mask.shape, -1) - self.pixels_in_mask = int(np.size(mask) - np.sum(mask)) - - count = 0 - for x in range(mask.shape[0]): - for y in range(mask.shape[1]): - if not mask[x, y]: - self.mask_index_array[x, y] = count - count += 1 - - self.kernel = kernel - self.kernel_max_size = self.kernel.shape_native[0] * self.kernel.shape_native[1] - - mask_1d_index = 0 - self.image_frame_1d_indexes = np.zeros( - (self.pixels_in_mask, self.kernel_max_size), dtype="int" - ) - self.image_frame_1d_kernels = np.zeros( - (self.pixels_in_mask, self.kernel_max_size) - ) - self.image_frame_1d_lengths = np.zeros((self.pixels_in_mask), dtype="int") - for x in range(self.mask_index_array.shape[0]): - for y in range(self.mask_index_array.shape[1]): - if not mask[x][y]: - ( - image_frame_1d_indexes, - image_frame_1d_kernels, - ) = self.frame_at_coordinates_jit( - coordinates=(x, y), - mask=np.array(mask), - mask_index_array=self.mask_index_array, - kernel_2d=self.kernel.native, - ) - self.image_frame_1d_indexes[mask_1d_index, :] = ( - image_frame_1d_indexes - ) - self.image_frame_1d_kernels[mask_1d_index, :] = ( - image_frame_1d_kernels - ) - self.image_frame_1d_lengths[mask_1d_index] = image_frame_1d_indexes[ - image_frame_1d_indexes >= 0 - ].shape[0] - mask_1d_index += 1 - - self.blurring_mask = mask_2d_util.blurring_mask_2d_from( - mask_2d=np.array(mask), - kernel_shape_native=kernel.shape_native, - ) - - self.pixels_in_blurring_mask = int( - np.size(self.blurring_mask) - np.sum(self.blurring_mask) - ) - - mask_1d_index = 0 - self.blurring_frame_1d_indexes = np.zeros( - (self.pixels_in_blurring_mask, self.kernel_max_size), dtype="int" - ) - self.blurring_frame_1d_kernels = np.zeros( - (self.pixels_in_blurring_mask, self.kernel_max_size) - ) - self.blurring_frame_1d_lengths = np.zeros( - (self.pixels_in_blurring_mask), dtype="int" - ) - for x in range(mask.shape[0]): - for y in range(mask.shape[1]): - if mask[x][y] and not self.blurring_mask[x, y]: - ( - image_frame_1d_indexes, - image_frame_1d_kernels, - ) = self.frame_at_coordinates_jit( - coordinates=(x, y), - mask=np.array(mask), - mask_index_array=np.array(self.mask_index_array), - kernel_2d=np.array(self.kernel.native), - ) - self.blurring_frame_1d_indexes[mask_1d_index, :] = ( - image_frame_1d_indexes - ) - self.blurring_frame_1d_kernels[mask_1d_index, :] = ( - image_frame_1d_kernels - ) - self.blurring_frame_1d_lengths[mask_1d_index] = ( - image_frame_1d_indexes[image_frame_1d_indexes >= 0].shape[0] - ) - mask_1d_index += 1 - - @staticmethod - @numba_util.jit() - def frame_at_coordinates_jit(coordinates, mask, mask_index_array, kernel_2d): - """ - Returns the frame (indexes of pixels light is blurred into) and kernel_frame (kernel kernel values of those \ - pixels) for a given coordinate in a masks and its PSF. - - Parameters - ---------- - coordinates: Tuple[int, int] - The coordinates of mask_index_array on which the frame should be centred - kernel_shape_native: Tuple[int, int] - The shape of the kernel for which this frame will be used - """ - - kernel_shape_native = kernel_2d.shape - kernel_max_size = kernel_shape_native[0] * kernel_shape_native[1] - - half_x = int(kernel_shape_native[0] / 2) - half_y = int(kernel_shape_native[1] / 2) - - frame = -1 * np.ones((kernel_max_size)) - kernel_frame = -1.0 * np.ones((kernel_max_size)) - - count = 0 - for i in range(kernel_shape_native[0]): - for j in range(kernel_shape_native[1]): - x = coordinates[0] - half_x + i - y = coordinates[1] - half_y + j - if ( - 0 <= x < mask_index_array.shape[0] - and 0 <= y < mask_index_array.shape[1] - ): - value = mask_index_array[x, y] - if value >= 0 and not mask[x, y]: - frame[count] = value - kernel_frame[count] = kernel_2d[i, j] - count += 1 - - return frame, kernel_frame - - def jax_convolve(self, image, blurring_image, method="auto"): - slim_to_2D_index_image = jnp.nonzero( - jnp.logical_not(self.mask.array), size=image.shape[0] - ) - slim_to_2D_index_blurring = jnp.nonzero( - jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] - ) - expanded_image_native = jnp.zeros(self.mask.shape) - expanded_image_native = expanded_image_native.at[slim_to_2D_index_image].set( - image.array - ) - expanded_image_native = expanded_image_native.at[slim_to_2D_index_blurring].set( - blurring_image.array - ) - kernel = np.array(self.kernel.native.array) - convolve_native = jax.scipy.signal.convolve( - expanded_image_native, kernel, mode="same", method=method - ) - convolve_slim = convolve_native[slim_to_2D_index_image] - return convolve_slim - - def convolve_image(self, image, blurring_image, jax_method="fft"): - """ - For a given 1D array and blurring array, convolve the two using this convolver. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the convolver's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. - """ - - def exception_message(): - raise exc.KernelException( - "You cannot use the convolve_image function of a Convolver if the Convolver was" - "not created with a blurring_mask." - ) - - if use_jax: - jax.lax.cond( - self.blurring_mask is None, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - - return self.jax_convolve(image, blurring_image, method=jax_method) - - else: - if self.blurring_mask is None: - exception_message() - - convolved_image = self.convolve_jit( - image_1d_array=np.array(image.slim), - image_frame_1d_indexes=self.image_frame_1d_indexes, - image_frame_1d_kernels=self.image_frame_1d_kernels, - image_frame_1d_lengths=self.image_frame_1d_lengths, - blurring_1d_array=np.array(blurring_image.slim), - blurring_frame_1d_indexes=self.blurring_frame_1d_indexes, - blurring_frame_1d_kernels=self.blurring_frame_1d_kernels, - blurring_frame_1d_lengths=self.blurring_frame_1d_lengths, - ) - - return Array2D(values=convolved_image, mask=self.mask) - - @staticmethod - @numba_util.jit() - def convolve_jit( - image_1d_array, - image_frame_1d_indexes, - image_frame_1d_kernels, - image_frame_1d_lengths, - blurring_1d_array, - blurring_frame_1d_indexes, - blurring_frame_1d_kernels, - blurring_frame_1d_lengths, - ): - blurred_image_1d = np.zeros(image_1d_array.shape) - - for image_1d_index in range(len(image_1d_array)): - frame_1d_indexes = image_frame_1d_indexes[image_1d_index] - frame_1d_kernel = image_frame_1d_kernels[image_1d_index] - frame_1d_length = image_frame_1d_lengths[image_1d_index] - image_value = image_1d_array[image_1d_index] - - for kernel_1d_index in range(frame_1d_length): - vector_index = frame_1d_indexes[kernel_1d_index] - kernel_value = frame_1d_kernel[kernel_1d_index] - blurred_image_1d[vector_index] += image_value * kernel_value - - for blurring_1d_index in range(len(blurring_1d_array)): - frame_1d_indexes = blurring_frame_1d_indexes[blurring_1d_index] - frame_1d_kernel = blurring_frame_1d_kernels[blurring_1d_index] - frame_1d_length = blurring_frame_1d_lengths[blurring_1d_index] - image_value = blurring_1d_array[blurring_1d_index] - - for kernel_1d_index in range(frame_1d_length): - vector_index = frame_1d_indexes[kernel_1d_index] - kernel_value = frame_1d_kernel[kernel_1d_index] - blurred_image_1d[vector_index] += image_value * kernel_value - - return blurred_image_1d - - def convolve_image_no_blurring(self, image): - """For a given 1D array and blurring array, convolve the two using this convolver. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the convolver's PSF. - """ - - convolved_image = self.convolve_no_blurring_jit( - image_1d_array=np.array(image.slim), - image_frame_1d_indexes=self.image_frame_1d_indexes, - image_frame_1d_kernels=self.image_frame_1d_kernels, - image_frame_1d_lengths=self.image_frame_1d_lengths, - ) - - return Array2D(values=convolved_image, mask=self.mask) - - def convolve_image_no_blurring_interpolation(self, image): - """For a given 1D array and blurring array, convolve the two using this convolver. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the convolver's PSF. - """ - - convolved_image = self.convolve_no_blurring_jit( - image_1d_array=image, - image_frame_1d_indexes=self.image_frame_1d_indexes, - image_frame_1d_kernels=self.image_frame_1d_kernels, - image_frame_1d_lengths=self.image_frame_1d_lengths, - ) - - return Array2D(values=convolved_image, mask=self.mask) - - @staticmethod - @numba_util.jit() - def convolve_no_blurring_jit( - image_1d_array, - image_frame_1d_indexes, - image_frame_1d_kernels, - image_frame_1d_lengths, - ): - blurred_image_1d = np.zeros(image_1d_array.shape) - - for image_1d_index in range(len(image_1d_array)): - frame_1d_indexes = image_frame_1d_indexes[image_1d_index] - frame_1d_kernel = image_frame_1d_kernels[image_1d_index] - frame_1d_length = image_frame_1d_lengths[image_1d_index] - image_value = image_1d_array[image_1d_index] - - for kernel_1d_index in range(frame_1d_length): - vector_index = frame_1d_indexes[kernel_1d_index] - kernel_value = frame_1d_kernel[kernel_1d_index] - blurred_image_1d[vector_index] += image_value * kernel_value - - return blurred_image_1d - - def convolve_mapping_matrix(self, mapping_matrix): - """For a given inversion mapping matrix, convolve every pixel's mapped with the PSF kernel. - - A mapping matrix provides non-zero entries in all elements which map two pixels to one another - (see *inversions.mappers*). - - For example, lets take an which is masked using a 'cross' of 5 pixels: - - [[ True, False, True]], - [[False, False, False]], - [[ True, False, True]] - - As example mapping matrix of this cross is as follows (5 pixels x 3 source pixels): - - [1, 0, 0] [0->0] - [1, 0, 0] [1->0] - [0, 1, 0] [2->1] - [0, 1, 0] [3->1] - [0, 0, 1] [4->2] - - For each source-pixel, we can create an of its unit-surface brightnesses by util the non-zero - entries back to masks. For example, doing this for source pixel 1 gives: - - [[0.0, 1.0, 0.0]], - [[1.0, 0.0, 0.0]] - [[0.0, 0.0, 0.0]] - - And source pixel 2: - - [[0.0, 0.0, 0.0]], - [[0.0, 1.0, 1.0]] - [[0.0, 0.0, 0.0]] - - We then convolve each of these with our PSF kernel, in 2 dimensions, like we would a grid. For - example, using the kernel below: - - kernel: - - [[0.0, 0.1, 0.0]] - [[0.1, 0.6, 0.1]] - [[0.0, 0.1, 0.0]] - - Blurred Source Pixel 1 (we don't need to perform the convolution into masked pixels): - - [[0.0, 0.6, 0.0]], - [[0.6, 0.0, 0.0]], - [[0.0, 0.0, 0.0]] - - Blurred Source pixel 2: - - [[0.0, 0.0, 0.0]], - [[0.0, 0.7, 0.7]], - [[0.0, 0.0, 0.0]] - - Finally, we map each of these blurred back to a blurred mapping matrix, which is analogous to the - mapping matrix. - - [0.6, 0.0, 0.0] [0->0] - [0.6, 0.0, 0.0] [1->0] - [0.0, 0.7, 0.0] [2->1] - [0.0, 0.7, 0.0] [3->1] - [0.0, 0.0, 0.6] [4->2] - - If the mapping matrix is sub-gridded, we perform the convolution on the fractional surface brightnesses in an - identical fashion to above. - - Parameters - ---------- - mapping_matrix - The 2D mapping matrix describing how every inversion pixel maps to a pixel on the data pixel. - """ - return self.convolve_matrix_jit( - mapping_matrix=mapping_matrix, - image_frame_1d_indexes=self.image_frame_1d_indexes, - image_frame_1d_kernels=self.image_frame_1d_kernels, - image_frame_1d_lengths=self.image_frame_1d_lengths, - ) - - @staticmethod - @numba_util.jit() - def convolve_matrix_jit( - mapping_matrix, - image_frame_1d_indexes, - image_frame_1d_kernels, - image_frame_1d_lengths, - ): - blurred_mapping_matrix = np.zeros(mapping_matrix.shape) - - for pixel_1d_index in range(mapping_matrix.shape[1]): - for image_1d_index in range(mapping_matrix.shape[0]): - value = mapping_matrix[image_1d_index, pixel_1d_index] - - if value > 0: - frame_1d_indexes = image_frame_1d_indexes[image_1d_index] - frame_1d_kernel = image_frame_1d_kernels[image_1d_index] - frame_1d_length = image_frame_1d_lengths[image_1d_index] - - for kernel_1d_index in range(frame_1d_length): - vector_index = frame_1d_indexes[kernel_1d_index] - kernel_value = frame_1d_kernel[kernel_1d_index] - blurred_mapping_matrix[vector_index, pixel_1d_index] += ( - value * kernel_value - ) - - return blurred_mapping_matrix diff --git a/autoarray/operators/mock/mock_convolver.py b/autoarray/operators/mock/mock_convolver.py index 290be228f..7dc5dfcbb 100644 --- a/autoarray/operators/mock/mock_convolver.py +++ b/autoarray/operators/mock/mock_convolver.py @@ -1,4 +1,4 @@ -class MockConvolver: +class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index f80e3f6e3..09ff331ef 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -1,4 +1,6 @@ from astropy import units +import jax +import jax.numpy as jnp import numpy as np import scipy.signal from pathlib import Path @@ -361,7 +363,7 @@ def rescaled_with_odd_dimensions_from( ) -> "Kernel2D": """ If the PSF kernel has one or two even-sized dimensions, return a PSF object where the kernel has odd-sized - dimensions (odd-sized dimensions are required by a *Convolver*). + dimensions (odd-sized dimensions are required for 2D convolution). The PSF can be scaled to larger / smaller sizes than the input size, if the rescale factor uses values that deviate furher from 1.0. @@ -511,3 +513,24 @@ def convolved_array_with_mask_from(self, array: Array2D, mask: Mask2D) -> Array2 ) return Array2D(values=convolved_array_1d, mask=mask) + + def jax_convolve(self, image, blurring_image, method="auto"): + slim_to_2D_index_image = jnp.nonzero( + jnp.logical_not(self.mask.array), size=image.shape[0] + ) + slim_to_2D_index_blurring = jnp.nonzero( + jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] + ) + expanded_image_native = jnp.zeros(self.mask.shape) + expanded_image_native = expanded_image_native.at[slim_to_2D_index_image].set( + image.array + ) + expanded_image_native = expanded_image_native.at[slim_to_2D_index_blurring].set( + blurring_image.array + ) + kernel = np.array(self.kernel.native.array) + convolve_native = jax.scipy.signal.convolve( + expanded_image_native, kernel, mode="same", method=method + ) + convolve_slim = convolve_native[slim_to_2D_index_image] + return convolve_slim \ No newline at end of file diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index 1dbe19e19..657ac4b0d 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -119,11 +119,6 @@ def make_psf_3x3(): return fixtures.make_psf_3x3() -@pytest.fixture(name="convolver_7x7") -def make_convolver_7x7(): - return fixtures.make_convolver_7x7() - - @pytest.fixture(name="imaging_7x7") def make_imaging_7x7(): return fixtures.make_imaging_7x7() diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index 6025733f5..9c079b94d 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -37,14 +37,14 @@ def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map(): psf = aa.Kernel2D.ones(shape_native=(3, 3), pixel_scales=1.0) dataset = aa.Imaging( - data=image, noise_map=noise_map, psf=psf, pad_for_convolver=False + data=image, noise_map=noise_map, psf=psf, pad_for_psf=False ) assert dataset.data.shape_native == (3, 3) assert dataset.noise_map.shape_native == (3, 3) dataset = aa.Imaging( - data=image, noise_map=noise_map, psf=psf, pad_for_convolver=True + data=image, noise_map=noise_map, psf=psf, pad_for_psf=True ) assert dataset.data.shape_native == (5, 5) @@ -144,7 +144,6 @@ def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): assert (masked_imaging_7x7.psf.slim == (1.0 / 3.0) * psf_3x3.slim).all() assert type(masked_imaging_7x7.psf) == aa.Kernel2D - assert type(masked_imaging_7x7.convolver) == aa.Convolver assert masked_imaging_7x7.w_tilde.curvature_preload.shape == (35,) assert masked_imaging_7x7.w_tilde.indexes.shape == (35,) assert masked_imaging_7x7.w_tilde.lengths.shape == (9,) @@ -226,7 +225,6 @@ def test__different_imaging_without_mock_objects__customize_constructor_inputs() assert masked_dataset.psf.native == pytest.approx( (1.0 / 49.0) * np.ones((7, 7)), 1.0e-4 ) - assert masked_dataset.convolver.kernel.shape_native == (7, 7) assert (masked_dataset.data == np.array([1.0])).all() assert (masked_dataset.noise_map == np.array([2.0])).all() diff --git a/test_autoarray/inversion/inversion/imaging/test_imaging.py b/test_autoarray/inversion/inversion/imaging/test_imaging.py index 46d726bad..df615ad85 100644 --- a/test_autoarray/inversion/inversion/imaging/test_imaging.py +++ b/test_autoarray/inversion/inversion/imaging/test_imaging.py @@ -11,18 +11,18 @@ directory = path.dirname(path.realpath(__file__)) -def test__operated_mapping_matrix_property(convolver_7x7, rectangular_mapper_7x7_3x3): +def test__operated_mapping_matrix_property(psf_7x7, rectangular_mapper_7x7_3x3): inversion = aa.m.MockInversionImaging( - convolver=convolver_7x7, linear_obj_list=[rectangular_mapper_7x7_3x3] + psf=psf_7x7, linear_obj_list=[rectangular_mapper_7x7_3x3] ) assert inversion.operated_mapping_matrix_list[0][0, 0] == pytest.approx(1.0, 1e-4) assert inversion.operated_mapping_matrix[0, 0] == pytest.approx(1.0, 1e-4) - convolver = aa.m.MockConvolver(operated_mapping_matrix=np.ones((2, 2))) + psf = aa.m.MockPSF(operated_mapping_matrix=np.ones((2, 2))) inversion = aa.m.MockInversionImaging( - convolver=convolver, + psf=psf, linear_obj_list=[rectangular_mapper_7x7_3x3, rectangular_mapper_7x7_3x3], ) @@ -42,9 +42,9 @@ def test__operated_mapping_matrix_property(convolver_7x7, rectangular_mapper_7x7 def test__operated_mapping_matrix_property__with_operated_mapping_matrix_override( - convolver_7x7, rectangular_mapper_7x7_3x3 + psf_7x7, rectangular_mapper_7x7_3x3 ): - convolver = aa.m.MockConvolver(operated_mapping_matrix=np.ones((2, 2))) + psf = aa.m.MockPSF(operated_mapping_matrix=np.ones((2, 2))) operated_mapping_matrix_override = np.array([[1.0, 2.0], [3.0, 4.0]]) @@ -54,7 +54,7 @@ def test__operated_mapping_matrix_property__with_operated_mapping_matrix_overrid ) inversion = aa.m.MockInversionImaging( - convolver=convolver, linear_obj_list=[rectangular_mapper_7x7_3x3, linear_obj] + psf=psf, linear_obj_list=[rectangular_mapper_7x7_3x3, linear_obj] ) operated_mapping_matrix_0 = np.array([[1.0, 1.0], [1.0, 1.0]]) @@ -73,7 +73,7 @@ def test__operated_mapping_matrix_property__with_operated_mapping_matrix_overrid def test__curvature_matrix(rectangular_mapper_7x7_3x3): noise_map = np.ones(2) - convolver = aa.m.MockConvolver(operated_mapping_matrix=np.ones((2, 10))) + psf = aa.m.MockPSF(operated_mapping_matrix=np.ones((2, 10))) operated_mapping_matrix_override = np.array([[1.0, 2.0], [3.0, 4.0]]) @@ -87,7 +87,7 @@ def test__curvature_matrix(rectangular_mapper_7x7_3x3): dataset = aa.DatasetInterface( data=np.ones(2), noise_map=noise_map, - convolver=convolver, + psf=psf, ) inversion = aa.InversionImagingMapping( @@ -135,7 +135,7 @@ def test__w_tilde_checks_noise_map_and_raises_exception_if_preloads_dont_match_n dataset = aa.DatasetInterface( data=np.ones(9), noise_map=np.ones(9), - convolver=aa.m.MockConvolver(matrix_shape), + psf=aa.m.MockPSF(matrix_shape), ) # noinspection PyTypeChecker diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index 711135ebc..b616a95d4 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -181,7 +181,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): shape_native=(7, 7), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True ) - convolver = aa.Convolver(mask=mask, kernel=kernel) + psf = kernel pixelization = aa.mesh.Rectangular(shape=(20, 20)) @@ -203,7 +203,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): mapping_matrix = mapper.mapping_matrix - blurred_mapping_matrix = convolver.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolve_mapping_matrix( mapping_matrix=mapping_matrix ) @@ -258,7 +258,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): shape_native=(7, 7), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True ) - convolver = aa.Convolver(mask=mask, kernel=kernel) + psf = kernel pixelization = aa.mesh.Rectangular(shape=(20, 20)) @@ -282,7 +282,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): w_tilde=w_tilde, mapping_matrix=mapping_matrix ) - blurred_mapping_matrix = convolver.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolve_mapping_matrix( mapping_matrix=mapping_matrix ) @@ -303,7 +303,7 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): shape_native=(7, 7), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True ) - convolver = aa.Convolver(mask=mask, kernel=kernel) + psf = kernel pixelization = aa.mesh.Rectangular(shape=(20, 20)) @@ -356,7 +356,7 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): pix_pixels=pixelization.pixels, ) - blurred_mapping_matrix = convolver.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolve_mapping_matrix( mapping_matrix=mapping_matrix ) diff --git a/test_autoarray/operators/test_convolver.py b/test_autoarray/operators/test_convolver.py deleted file mode 100644 index c298dda31..000000000 --- a/test_autoarray/operators/test_convolver.py +++ /dev/null @@ -1,1164 +0,0 @@ -import numpy as np -import pytest - -import autoarray as aa -from autoarray import exc - - -@pytest.fixture(name="simple_mask_2d_7x7") -def make_simple_mask_2d_7x7(): - mask = [ - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, False, False, False, True, True], - [True, True, False, False, False, True, True], - [True, True, False, False, False, True, True], - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - ] - - return aa.Mask2D(mask=mask, pixel_scales=1.0) - - -@pytest.fixture(name="simple_mask_5x5") -def make_simple_mask_5x5(): - mask = [ - [True, True, True, True, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, False, False, False, True], - [True, True, True, True, True], - ] - - return aa.Mask2D(mask=mask, pixel_scales=1.0) - - -@pytest.fixture(name="cross_mask") -def make_cross_mask(): - mask = np.full((5, 5), True) - - mask[2, 2] = False - mask[1, 2] = False - mask[3, 2] = False - mask[2, 1] = False - mask[2, 3] = False - - return aa.Mask2D(mask=mask, pixel_scales=1.0) - - -def test__numbering__uses_mask_correctly(simple_mask_5x5, cross_mask): - convolver = aa.Convolver( - mask=simple_mask_5x5, - kernel=aa.Kernel2D.ones(shape_native=(1, 1), pixel_scales=1.0), - ) - - mask_index_array = convolver.mask_index_array - - assert mask_index_array.shape == (5, 5) - # noinspection PyUnresolvedReferences - assert ( - mask_index_array - == np.array( - [ - [-1, -1, -1, -1, -1], - [-1, 0, 1, 2, -1], - [-1, 3, 4, 5, -1], - [-1, 6, 7, 8, -1], - [-1, -1, -1, -1, -1], - ] - ) - ).all() - - convolver = aa.Convolver( - mask=cross_mask, kernel=aa.Kernel2D.ones(shape_native=(1, 1), pixel_scales=1.0) - ) - - assert ( - convolver.mask_index_array - == np.array( - [ - [-1, -1, -1, -1, -1], - [-1, -1, 0, -1, -1], - [-1, 1, 2, 3, -1], - [-1, -1, 4, -1, -1], - [-1, -1, -1, -1, -1], - ] - ) - ).all() - - -def test__even_kernel_failure(): - with pytest.raises(exc.KernelException): - aa.Convolver( - mask=np.full((3, 3), False), - kernel=aa.Kernel2D.ones(shape_native=(2, 2), pixel_scales=1.0), - ) - - -def test__frame_extraction__frame_and_kernel_frame_at_coords(simple_mask_5x5): - convolver = aa.Convolver( - mask=simple_mask_5x5, - kernel=aa.Kernel2D.no_mask( - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], pixel_scales=1.0 - ), - ) - - frame, kernel_frame = convolver.frame_at_coordinates_jit( - coordinates=(2, 2), - mask=np.array(simple_mask_5x5), - mask_index_array=convolver.mask_index_array, - kernel_2d=np.array(convolver.kernel.native), - ) - - assert (frame == np.array([i for i in range(9)])).all() - - assert ( - kernel_frame == np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]]) - ).all() - - corner_frame = np.array([0, 1, 3, 4, -1, -1, -1, -1, -1]) - - frame, kernel_frame = convolver.frame_at_coordinates_jit( - coordinates=(1, 1), - mask=np.array(simple_mask_5x5), - mask_index_array=convolver.mask_index_array, - kernel_2d=np.array(convolver.kernel.native), - ) - - assert (frame == corner_frame).all() - - frame, kernel_frame = convolver.frame_at_coordinates_jit( - coordinates=(1, 1), - mask=np.array(simple_mask_5x5), - mask_index_array=convolver.mask_index_array, - kernel_2d=np.array(convolver.kernel.native), - ) - - assert (kernel_frame == np.array([5.0, 6.0, 8.0, 9.0, -1, -1, -1, -1, -1])).all() - - assert 9 == len(convolver.image_frame_1d_indexes) - - assert ( - convolver.image_frame_1d_indexes[4] == np.array([i for i in range(9)]) - ).all() - - -def test__frame_extraction__more_complicated_frames(simple_mask_2d_7x7): - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.no_mask( - [ - [1.0, 2.0, 3.0, 4.0, 5.0], - [6.0, 7.0, 8.0, 9.0, 10.0], - [11.0, 12.0, 13.0, 14.0, 15.0], - [16.0, 17.0, 18.0, 19.0, 20.0], - [21.0, 22.0, 23.0, 24.0, 25.0], - ], - pixel_scales=1.0, - ), - ) - - frame, kernel_frame = convolver.frame_at_coordinates_jit( - coordinates=(2, 2), - mask=np.array(simple_mask_2d_7x7), - mask_index_array=convolver.mask_index_array, - kernel_2d=np.array(convolver.kernel.native), - ) - - assert ( - kernel_frame - == np.array( - [ - 13.0, - 14.0, - 15.0, - 18.0, - 19.0, - 20.0, - 23.0, - 24.0, - 25.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ] - ) - ).all() - - frame, kernel_frame = convolver.frame_at_coordinates_jit( - coordinates=(3, 2), - mask=np.array(simple_mask_2d_7x7), - mask_index_array=convolver.mask_index_array, - kernel_2d=np.array(convolver.kernel.native), - ) - - assert ( - kernel_frame - == np.array( - [ - 8.0, - 9.0, - 10.0, - 13.0, - 14.0, - 15.0, - 18.0, - 19.0, - 20.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1, - ] - ) - ).all() - - frame, kernel_frame = convolver.frame_at_coordinates_jit( - coordinates=(3, 3), - mask=np.array(simple_mask_2d_7x7), - mask_index_array=convolver.mask_index_array, - kernel_2d=np.array(convolver.kernel.native), - ) - - assert ( - kernel_frame - == np.array( - [ - 7.0, - 8.0, - 9.0, - 12.0, - 13.0, - 14.0, - 17.0, - 18.0, - 19.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1, - ] - ) - ).all() - - -def test__image_frame_indexes__for_different_masks(cross_mask, simple_mask_2d_7x7): - convolver = aa.Convolver( - mask=cross_mask, - kernel=aa.Kernel2D.no_mask( - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], pixel_scales=1.0 - ), - ) - - assert 5 == len(convolver.image_frame_1d_indexes) - - assert ( - convolver.image_frame_1d_indexes[0] - == np.array([0, 1, 2, 3, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[1] - == np.array([0, 1, 2, 4, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[2] == np.array([0, 1, 2, 3, 4, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[3] - == np.array([0, 2, 3, 4, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[4] - == np.array([1, 2, 3, 4, -1, -1, -1, -1, -1]) - ).all() - - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.ones(shape_native=(3, 5), pixel_scales=1.0), - ) - - assert ( - convolver.image_frame_1d_indexes[0] - == np.array([0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[1] - == np.array([0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[2] - == np.array([0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[3] - == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[4] - == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[5] - == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[6] - == np.array([3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[7] - == np.array([3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[8] - == np.array([3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.ones(shape_native=(5, 3), pixel_scales=1.0), - ) - - assert ( - convolver.image_frame_1d_indexes[0] - == np.array([0, 1, 3, 4, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[1] - == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[2] - == np.array([1, 2, 4, 5, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[3] - == np.array([0, 1, 3, 4, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[4] - == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[5] - == np.array([1, 2, 4, 5, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[6] - == np.array([0, 1, 3, 4, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[7] - == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_indexes[8] - == np.array([1, 2, 4, 5, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.ones(shape_native=(5, 5), pixel_scales=1.0), - ) - - assert ( - convolver.image_frame_1d_indexes[0] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[1] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[2] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[3] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[4] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[5] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[6] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[7] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - assert ( - convolver.image_frame_1d_indexes[8] - == np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - ] - ) - ).all() - - -def test_image_frame_kernels__different_shape_masks( - simple_mask_5x5, simple_mask_2d_7x7 -): - convolver = aa.Convolver( - mask=simple_mask_5x5, - kernel=aa.Kernel2D.no_mask( - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], pixel_scales=1.0 - ), - ) - - assert 9 == len(convolver.image_frame_1d_indexes) - - assert ( - convolver.image_frame_1d_kernels[4] - == np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]]) - ).all() - - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.no_mask( - [ - [1.0, 2.0, 3.0, 4.0, 5.0], - [6.0, 7.0, 8.0, 9.0, 10.0], - [11.0, 12.0, 13.0, 14.0, 15.0], - ], - pixel_scales=1.0, - ), - ) - - assert ( - convolver.image_frame_1d_kernels[0] - == np.array( - [8.0, 9.0, 10.0, 13.0, 14.0, 15.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[1] - == np.array( - [7.0, 8.0, 9.0, 12.0, 13.0, 14.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[2] - == np.array( - [6.0, 7.0, 8.0, 11.0, 12.0, 13.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[3] - == np.array( - [3.0, 4.0, 5.0, 8.0, 9.0, 10.0, 13.0, 14.0, 15.0, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[4] - == np.array( - [2.0, 3.0, 4.0, 7.0, 8.0, 9.0, 12.0, 13.0, 14.0, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[5] - == np.array( - [1.0, 2.0, 3.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[6] - == np.array([3.0, 4.0, 5.0, 8.0, 9.0, 10.0, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_kernels[7] - == np.array([2.0, 3.0, 4.0, 7.0, 8.0, 9.0, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_kernels[8] - == np.array([1.0, 2.0, 3.0, 6.0, 7.0, 8.0, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.no_mask( - [ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0], - [10.0, 11.0, 12.0], - [13.0, 14.0, 15.0], - ], - pixel_scales=1.0, - ), - ) - - assert ( - convolver.image_frame_1d_kernels[0] - == np.array( - [8.0, 9.0, 11.0, 12.0, 14.0, 15.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[1] - == np.array( - [7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[2] - == np.array( - [7.0, 8.0, 10.0, 11.0, 13.0, 14.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[3] - == np.array( - [5.0, 6.0, 8.0, 9.0, 11.0, 12.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[4] - == np.array( - [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[5] - == np.array( - [4.0, 5.0, 7.0, 8.0, 10.0, 11.0, -1, -1, -1, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[6] - == np.array([2.0, 3.0, 5.0, 6.0, 8.0, 9.0, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.image_frame_1d_kernels[7] - == np.array( - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, -1, -1, -1, -1, -1, -1] - ) - ).all() - assert ( - convolver.image_frame_1d_kernels[8] - == np.array([1.0, 2.0, 4.0, 5.0, 7.0, 8.0, -1, -1, -1, -1, -1, -1, -1, -1, -1]) - ).all() - - -def test__blurring_frame_indexes__blurring_region_3x3_kernel(cross_mask): - convolver = aa.Convolver( - mask=cross_mask, kernel=aa.Kernel2D.ones(shape_native=(3, 3), pixel_scales=1.0) - ) - - assert ( - convolver.blurring_frame_1d_indexes[4] - == np.array([0, 1, 2, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.blurring_frame_1d_indexes[5] - == np.array([0, 2, 3, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.blurring_frame_1d_indexes[10] - == np.array([1, 2, 4, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.blurring_frame_1d_indexes[11] - == np.array([2, 3, 4, -1, -1, -1, -1, -1, -1]) - ).all() - - -def test__blurring_frame_kernels__blurring_region_3x3_kernel(cross_mask): - convolver = aa.Convolver( - mask=cross_mask, - kernel=aa.Kernel2D.no_mask( - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], pixel_scales=1.0 - ), - ) - - assert ( - convolver.blurring_frame_1d_kernels[4] - == np.array([6.0, 8.0, 9.0, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.blurring_frame_1d_kernels[5] - == np.array([4.0, 7.0, 8.0, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.blurring_frame_1d_kernels[10] - == np.array([2.0, 3.0, 6.0, -1, -1, -1, -1, -1, -1]) - ).all() - assert ( - convolver.blurring_frame_1d_kernels[11] - == np.array([1.0, 2.0, 4.0, -1, -1, -1, -1, -1, -1]) - ).all() - - -def test__frame_lengths__frames_are_from_examples_above__lengths_are_right( - simple_mask_2d_7x7, -): - convolver = aa.Convolver( - mask=simple_mask_2d_7x7, - kernel=aa.Kernel2D.ones(shape_native=(3, 5), pixel_scales=1.0), - ) - - # convolver_image.image_frame_indexes[0] == np.array([0, 1, 2, 3, 4, 5]) - # convolver_image.image_frame_indexes[1] == np.array([0, 1, 2, 3, 4, 5]) - # convolver_image.image_frame_indexes[2] == np.array([0, 1, 2, 3, 4, 5]) - # convolver_image.image_frame_indexes[3] == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) - # (convolver_image.image_frame_indexes[4] == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) - # convolver_image.image_frame_indexes[5] == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) - # convolver_image.image_frame_indexes[6] == np.array([3, 4, 5, 6, 7, 8]) - # convolver_image.image_frame_indexes[7] == np.array([3, 4, 5, 6, 7, 8]) - # convolver_image.image_frame_indexes[8] == np.array([3, 4, 5, 6, 7, 8]) - - assert ( - convolver.image_frame_1d_lengths == np.array([6, 6, 6, 9, 9, 9, 6, 6, 6]) - ).all() - - -def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly(): - mask = np.array( - [ - [True, True, True, True, True, True], - [True, False, False, False, False, True], - [True, False, False, False, False, True], - [True, False, False, False, False, True], - [True, False, False, False, False, True], - [True, True, True, True, True, True], - ] - ) - - asymmetric_kernel = aa.Kernel2D.no_mask( - values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 - ) - - convolver = aa.Convolver(mask=mask, kernel=asymmetric_kernel) - - mapping = np.array( - [ - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [ - 0, - 1, - 0, - ], # The 0.3 should be 'chopped' from this pixel as it is on the right-most edge - [0, 0, 0], - [1, 0, 0], - [0, 0, 1], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - ] - ) - - blurred_mapping = convolver.convolve_mapping_matrix(mapping) - - assert ( - blurred_mapping - == np.array( - [ - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0.4, 0], - [0, 0.2, 0], - [0.4, 0, 0], - [0.2, 0, 0.4], - [0.3, 0, 0.2], - [0, 0.1, 0.3], - [0, 0, 0], - [0.1, 0, 0], - [0, 0, 0.1], - [0, 0, 0], - ] - ) - ).all() - - asymmetric_kernel = aa.Kernel2D.no_mask( - values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 - ) - - convolver = aa.Convolver(mask=mask, kernel=asymmetric_kernel) - - mapping = np.array( - [ - [0, 1, 0], - [0, 1, 0], - [0, 1, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [ - 0, - 1, - 0, - ], # The 0.3 should be 'chopped' from this pixel as it is on the right-most edge - [1, 0, 0], - [1, 0, 0], - [0, 0, 1], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - ] - ) - - blurred_mapping = convolver.convolve_mapping_matrix(mapping) - - assert blurred_mapping == pytest.approx( - np.array( - [ - [0, 0.6, 0], - [0, 0.9, 0], - [0, 0.5, 0], - [0, 0.3, 0], - [0, 0.1, 0], - [0, 0.1, 0], - [0, 0.5, 0], - [0, 0.2, 0], - [0.6, 0, 0], - [0.5, 0, 0.4], - [0.3, 0, 0.2], - [0, 0.1, 0.3], - [0.1, 0, 0], - [0.1, 0, 0], - [0, 0, 0.1], - [0, 0, 0], - ] - ), - 1e-4, - ) - - -def test__convolution__cross_mask_with_blurring_entries__returns_array(): - cross_mask = aa.Mask2D( - mask=[ - [True, True, True, True, True], - [True, True, False, True, True], - [True, False, False, False, True], - [True, True, False, True, True], - [True, True, True, True, True], - ], - pixel_scales=0.1, - ) - - kernel = aa.Kernel2D.no_mask( - values=[[0, 0.2, 0], [0.2, 0.4, 0.2], [0, 0.2, 0]], pixel_scales=0.1 - ) - - convolver = aa.Convolver(mask=cross_mask, kernel=kernel) - - image_array = aa.Array2D(values=[1, 0, 0, 0, 0], mask=cross_mask) - - blurring_mask = cross_mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native - ) - - blurring_array = aa.Array2D( - values=[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mask=blurring_mask - ) - - result = convolver.convolve_image(image=image_array, blurring_image=blurring_array) - - assert (np.round(result, 1) == np.array([0.6, 0.2, 0.2, 0.0, 0.0])).all() - - -def test__compare_to_full_2d_convolution(): - # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. - - import scipy.signal - - mask = aa.Mask2D.circular( - shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 - ) - kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) - image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) - - blurred_image_via_scipy = scipy.signal.convolve2d( - image.native, kernel.native, mode="same" - ) - blurred_image_via_scipy = aa.Array2D.no_mask( - values=blurred_image_via_scipy, pixel_scales=1.0 - ) - blurred_masked_image_via_scipy = aa.Array2D( - values=blurred_image_via_scipy.native, mask=mask - ) - - # Now reproduce this data using the frame convolver_image - - masked_image = aa.Array2D(values=image.native, mask=mask) - - blurring_mask = mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native - ) - - convolver = aa.Convolver(mask=mask, kernel=kernel) - - blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) - - blurred_masked_im_1 = convolver.convolve_image( - image=masked_image, blurring_image=blurring_image - ) - - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) - - -def test__compare_to_full_2d_convolution__no_blurring_image(): - # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. - - import scipy.signal - - mask = aa.Mask2D.circular( - shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 - ) - kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) - image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) - - blurring_mask = mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native - ) - blurred_image_via_scipy = scipy.signal.convolve2d( - image.native * blurring_mask, kernel.native, mode="same" - ) - blurred_image_via_scipy = aa.Array2D.no_mask( - values=blurred_image_via_scipy, pixel_scales=1.0 - ) - blurred_masked_image_via_scipy = aa.Array2D( - values=blurred_image_via_scipy.native, mask=mask - ) - - # Now reproduce this data using the frame convolver_image - - masked_image = aa.Array2D(values=image.native, mask=mask) - - convolver = aa.Convolver(mask=mask, kernel=kernel) - - blurred_masked_im_1 = convolver.convolve_image_no_blurring(image=masked_image) - - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) - - -def test__summed_convolved_array_from(): - mask = aa.Mask2D( - mask=[ - [True, True, True, True, True], - [True, True, True, True, True], - [True, False, False, False, True], - [True, True, True, True, True], - [True, True, True, True, True], - ], - pixel_scales=0.1, - ) - - kernel = aa.Kernel2D.no_mask( - values=[[0, 0.0, 0], [0.5, 1.0, 0.5], [0, 0.0, 0]], pixel_scales=0.1 - ) - - convolver = aa.Convolver(mask=mask, kernel=kernel) - - image_array = aa.Array2D(values=[1.0, 2.0, 3.0], mask=mask) - - summed_convolved_array = convolver.convolve_image_no_blurring(image=image_array) - - assert summed_convolved_array == pytest.approx(np.array([2.0, 4.0, 4.0]), 1.0e-4) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 9106b43a1..23f66de19 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -248,6 +248,58 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): assert (kernel_2d.native == (1.0 / 15.0) * np.ones((5, 3))).all() +def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy_gaussian_model(): + pixel_scales = 0.1 + + x_stddev = ( + 1.0e-5 + * (units.deg).to(units.arcsec) + / pixel_scales + / (2.0 * np.sqrt(2.0 * np.log(2.0))) + ) + y_stddev = ( + 2.0e-5 + * (units.deg).to(units.arcsec) + / pixel_scales + / (2.0 * np.sqrt(2.0 * np.log(2.0))) + ) + + theta_deg = 230.0 + theta = Angle(theta_deg, "deg").radian + + gaussian_astropy = functional_models.Gaussian2D( + amplitude=1.0, + x_mean=1.0, + y_mean=1.0, + x_stddev=x_stddev, + y_stddev=y_stddev, + theta=theta, + ) + + shape = (3, 3) + y, x = np.mgrid[0 : shape[1], 0 : shape[0]] + kernel_astropy = gaussian_astropy(x, y) + kernel_astropy /= np.sum(kernel_astropy) + + kernel_2d = aa.Kernel2D.from_as_gaussian_via_alma_fits_header_parameters( + shape_native=shape, + pixel_scales=pixel_scales, + y_stddev=2.0e-5, + x_stddev=1.0e-5, + theta=theta_deg, + normalize=True, + ) + + assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) + + +def test__convolved_array_from__not_odd_x_odd_kernel__raises_error(): + kernel_2d = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) + + with pytest.raises(exc.KernelException): + kernel_2d.convolved_array_from(np.ones((5, 5))) + + def test__convolved_array_from(): array_2d = aa.Array2D.no_mask( values=[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], pixel_scales=1.0 @@ -408,53 +460,261 @@ def test__convolved_array_from(): ).all() -def test__convolved_array_from__not_odd_x_odd_kernel__raises_error(): - kernel_2d = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) +def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly(): + mask = np.array( + [ + [True, True, True, True, True, True], + [True, False, False, False, False, True], + [True, False, False, False, False, True], + [True, False, False, False, False, True], + [True, False, False, False, False, True], + [True, True, True, True, True, True], + ] + ) - with pytest.raises(exc.KernelException): - kernel_2d.convolved_array_from(np.ones((5, 5))) + asymmetric_kernel = aa.Kernel2D.no_mask( + values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 + ) + convolver = aa.Convolver(mask=mask, kernel=asymmetric_kernel) -def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy_gaussian_model(): - pixel_scales = 0.1 + mapping = np.array( + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [ + 0, + 1, + 0, + ], # The 0.3 should be 'chopped' from this pixel as it is on the right-most edge + [0, 0, 0], + [1, 0, 0], + [0, 0, 1], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ] + ) + + blurred_mapping = convolver.convolve_mapping_matrix(mapping) - x_stddev = ( - 1.0e-5 - * (units.deg).to(units.arcsec) - / pixel_scales - / (2.0 * np.sqrt(2.0 * np.log(2.0))) + assert ( + blurred_mapping + == np.array( + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0.4, 0], + [0, 0.2, 0], + [0.4, 0, 0], + [0.2, 0, 0.4], + [0.3, 0, 0.2], + [0, 0.1, 0.3], + [0, 0, 0], + [0.1, 0, 0], + [0, 0, 0.1], + [0, 0, 0], + ] + ) + ).all() + + asymmetric_kernel = aa.Kernel2D.no_mask( + values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 ) - y_stddev = ( - 2.0e-5 - * (units.deg).to(units.arcsec) - / pixel_scales - / (2.0 * np.sqrt(2.0 * np.log(2.0))) + + convolver = aa.Convolver(mask=mask, kernel=asymmetric_kernel) + + mapping = np.array( + [ + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [ + 0, + 1, + 0, + ], # The 0.3 should be 'chopped' from this pixel as it is on the right-most edge + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ] + ) + + blurred_mapping = convolver.convolve_mapping_matrix(mapping) + + assert blurred_mapping == pytest.approx( + np.array( + [ + [0, 0.6, 0], + [0, 0.9, 0], + [0, 0.5, 0], + [0, 0.3, 0], + [0, 0.1, 0], + [0, 0.1, 0], + [0, 0.5, 0], + [0, 0.2, 0], + [0.6, 0, 0], + [0.5, 0, 0.4], + [0.3, 0, 0.2], + [0, 0.1, 0.3], + [0.1, 0, 0], + [0.1, 0, 0], + [0, 0, 0.1], + [0, 0, 0], + ] + ), + 1e-4, ) - theta_deg = 230.0 - theta = Angle(theta_deg, "deg").radian - gaussian_astropy = functional_models.Gaussian2D( - amplitude=1.0, - x_mean=1.0, - y_mean=1.0, - x_stddev=x_stddev, - y_stddev=y_stddev, - theta=theta, +def test__convolution__cross_mask_with_blurring_entries__returns_array(): + cross_mask = aa.Mask2D( + mask=[ + [True, True, True, True, True], + [True, True, False, True, True], + [True, False, False, False, True], + [True, True, False, True, True], + [True, True, True, True, True], + ], + pixel_scales=0.1, ) - shape = (3, 3) - y, x = np.mgrid[0 : shape[1], 0 : shape[0]] - kernel_astropy = gaussian_astropy(x, y) - kernel_astropy /= np.sum(kernel_astropy) + kernel = aa.Kernel2D.no_mask( + values=[[0, 0.2, 0], [0.2, 0.4, 0.2], [0, 0.2, 0]], pixel_scales=0.1 + ) - kernel_2d = aa.Kernel2D.from_as_gaussian_via_alma_fits_header_parameters( - shape_native=shape, - pixel_scales=pixel_scales, - y_stddev=2.0e-5, - x_stddev=1.0e-5, - theta=theta_deg, - normalize=True, + convolver = aa.Convolver(mask=cross_mask, kernel=kernel) + + image_array = aa.Array2D(values=[1, 0, 0, 0, 0], mask=cross_mask) + + blurring_mask = cross_mask.derive_mask.blurring_from( + kernel_shape_native=kernel.shape_native ) - assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) + blurring_array = aa.Array2D( + values=[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mask=blurring_mask + ) + + result = convolver.convolve_image(image=image_array, blurring_image=blurring_array) + + assert (np.round(result, 1) == np.array([0.6, 0.2, 0.2, 0.0, 0.0])).all() + + +def test__compare_to_full_2d_convolution(): + # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. + + import scipy.signal + + mask = aa.Mask2D.circular( + shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 + ) + kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) + image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + + blurred_image_via_scipy = scipy.signal.convolve2d( + image.native, kernel.native, mode="same" + ) + blurred_image_via_scipy = aa.Array2D.no_mask( + values=blurred_image_via_scipy, pixel_scales=1.0 + ) + blurred_masked_image_via_scipy = aa.Array2D( + values=blurred_image_via_scipy.native, mask=mask + ) + + # Now reproduce this data using the frame convolver_image + + masked_image = aa.Array2D(values=image.native, mask=mask) + + blurring_mask = mask.derive_mask.blurring_from( + kernel_shape_native=kernel.shape_native + ) + + convolver = aa.Convolver(mask=mask, kernel=kernel) + + blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) + + blurred_masked_im_1 = convolver.convolve_image( + image=masked_image, blurring_image=blurring_image + ) + + assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) + + +def test__compare_to_full_2d_convolution__no_blurring_image(): + # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. + + import scipy.signal + + mask = aa.Mask2D.circular( + shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 + ) + kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) + image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + + blurring_mask = mask.derive_mask.blurring_from( + kernel_shape_native=kernel.shape_native + ) + blurred_image_via_scipy = scipy.signal.convolve2d( + image.native * blurring_mask, kernel.native, mode="same" + ) + blurred_image_via_scipy = aa.Array2D.no_mask( + values=blurred_image_via_scipy, pixel_scales=1.0 + ) + blurred_masked_image_via_scipy = aa.Array2D( + values=blurred_image_via_scipy.native, mask=mask + ) + + # Now reproduce this data using the frame convolver_image + + masked_image = aa.Array2D(values=image.native, mask=mask) + + convolver = aa.Convolver(mask=mask, kernel=kernel) + + blurred_masked_im_1 = convolver.convolve_image_no_blurring(image=masked_image) + + assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) + + +def test__summed_convolved_array_from(): + mask = aa.Mask2D( + mask=[ + [True, True, True, True, True], + [True, True, True, True, True], + [True, False, False, False, True], + [True, True, True, True, True], + [True, True, True, True, True], + ], + pixel_scales=0.1, + ) + + kernel = aa.Kernel2D.no_mask( + values=[[0, 0.0, 0], [0.5, 1.0, 0.5], [0, 0.0, 0]], pixel_scales=0.1 + ) + + convolver = aa.Convolver(mask=mask, kernel=kernel) + + image_array = aa.Array2D(values=[1.0, 2.0, 3.0], mask=mask) + + summed_convolved_array = convolver.convolve_image_no_blurring(image=image_array) + + assert summed_convolved_array == pytest.approx(np.array([2.0, 4.0, 4.0]), 1.0e-4) From 0b4763ce2d1838b97479e46eb800e319a3ae841d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:17:04 +0100 Subject: [PATCH 082/108] cleaned up test_kernel_2d --- .../structures/arrays/test_kernel_2d.py | 181 +----------------- 1 file changed, 7 insertions(+), 174 deletions(-) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 23f66de19..bd52796a0 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -301,107 +301,6 @@ def test__convolved_array_from__not_odd_x_odd_kernel__raises_error(): def test__convolved_array_from(): - array_2d = aa.Array2D.no_mask( - values=[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], pixel_scales=1.0 - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[0.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 0.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert (blurred_array_2d == kernel_2d).all() - - array_2d = aa.Array2D.no_mask( - values=[ - [0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[0.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 0.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array=array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [ - [0.0, 1.0, 0.0, 0.0], - [1.0, 2.0, 1.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ] - ) - ).all() - - array_2d = aa.Array2D.no_mask( - values=[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[0.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 0.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [[0.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] - ) - ).all() - - array_2d = aa.Array2D.no_mask( - values=[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[0.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 0.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array([[0.0, 1.0, 0.0, 0.0], [1.0, 2.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]]) - ).all() - - array_2d = aa.Array2D.no_mask( - values=[ - [0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [ - [1.0, 1.0, 1.0, 0.0], - [2.0, 3.0, 2.0, 1.0], - [1.0, 5.0, 5.0, 1.0], - [0.0, 1.0, 3.0, 3.0], - ] - ) - ).all() array_2d = aa.Array2D.no_mask( [ @@ -460,7 +359,7 @@ def test__convolved_array_from(): ).all() -def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly(): +def test__convolve_mapping_matrix(): mask = np.array( [ [True, True, True, True, True, True], @@ -472,12 +371,10 @@ def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly ] ) - asymmetric_kernel = aa.Kernel2D.no_mask( + kernel = aa.Kernel2D.no_mask( values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 ) - convolver = aa.Convolver(mask=mask, kernel=asymmetric_kernel) - mapping = np.array( [ [0, 0, 0], @@ -503,7 +400,7 @@ def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly ] ) - blurred_mapping = convolver.convolve_mapping_matrix(mapping) + blurred_mapping = kernel.convolve_mapping_matrix(mapping) assert ( blurred_mapping @@ -529,12 +426,10 @@ def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly ) ).all() - asymmetric_kernel = aa.Kernel2D.no_mask( + kernel = aa.Kernel2D.no_mask( values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 ) - convolver = aa.Convolver(mask=mask, kernel=asymmetric_kernel) - mapping = np.array( [ [0, 1, 0], @@ -560,7 +455,7 @@ def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly ] ) - blurred_mapping = convolver.convolve_mapping_matrix(mapping) + blurred_mapping = kernel.convolve_mapping_matrix(mapping) assert blurred_mapping == pytest.approx( np.array( @@ -587,39 +482,6 @@ def test__convolve_mapping_matrix__asymetric_convolver__matrix_blurred_correctly ) -def test__convolution__cross_mask_with_blurring_entries__returns_array(): - cross_mask = aa.Mask2D( - mask=[ - [True, True, True, True, True], - [True, True, False, True, True], - [True, False, False, False, True], - [True, True, False, True, True], - [True, True, True, True, True], - ], - pixel_scales=0.1, - ) - - kernel = aa.Kernel2D.no_mask( - values=[[0, 0.2, 0], [0.2, 0.4, 0.2], [0, 0.2, 0]], pixel_scales=0.1 - ) - - convolver = aa.Convolver(mask=cross_mask, kernel=kernel) - - image_array = aa.Array2D(values=[1, 0, 0, 0, 0], mask=cross_mask) - - blurring_mask = cross_mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native - ) - - blurring_array = aa.Array2D( - values=[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mask=blurring_mask - ) - - result = convolver.convolve_image(image=image_array, blurring_image=blurring_array) - - assert (np.round(result, 1) == np.array([0.6, 0.2, 0.2, 0.0, 0.0])).all() - - def test__compare_to_full_2d_convolution(): # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. @@ -649,11 +511,9 @@ def test__compare_to_full_2d_convolution(): kernel_shape_native=kernel.shape_native ) - convolver = aa.Convolver(mask=mask, kernel=kernel) - blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) - blurred_masked_im_1 = convolver.convolve_image( + blurred_masked_im_1 = kernel.convolve_image( image=masked_image, blurring_image=blurring_image ) @@ -688,33 +548,6 @@ def test__compare_to_full_2d_convolution__no_blurring_image(): masked_image = aa.Array2D(values=image.native, mask=mask) - convolver = aa.Convolver(mask=mask, kernel=kernel) - - blurred_masked_im_1 = convolver.convolve_image_no_blurring(image=masked_image) + blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image) assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) - - -def test__summed_convolved_array_from(): - mask = aa.Mask2D( - mask=[ - [True, True, True, True, True], - [True, True, True, True, True], - [True, False, False, False, True], - [True, True, True, True, True], - [True, True, True, True, True], - ], - pixel_scales=0.1, - ) - - kernel = aa.Kernel2D.no_mask( - values=[[0, 0.0, 0], [0.5, 1.0, 0.5], [0, 0.0, 0]], pixel_scales=0.1 - ) - - convolver = aa.Convolver(mask=mask, kernel=kernel) - - image_array = aa.Array2D(values=[1.0, 2.0, 3.0], mask=mask) - - summed_convolved_array = convolver.convolve_image_no_blurring(image=image_array) - - assert summed_convolved_array == pytest.approx(np.array([2.0, 4.0, 4.0]), 1.0e-4) From 5762fcf320794b12f66359d5d29a56ddc3eb130c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:19:43 +0100 Subject: [PATCH 083/108] remopve convolved_array_with_mask_From --- autoarray/structures/arrays/kernel_2d.py | 31 ------------------------ 1 file changed, 31 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 09ff331ef..d384f8a4b 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -483,37 +483,6 @@ def convolved_array_from(self, array: Array2D) -> Array2D: return Array2D(values=convolved_array_1d, mask=array_2d.mask) - def convolved_array_with_mask_from(self, array: Array2D, mask: Mask2D) -> Array2D: - """ - Convolve an array with this Kernel2D - - Parameters - ---------- - image - An array representing the image the Kernel2D is convolved with. - - Returns - ------- - convolved_image - An array representing the image after convolution. - - Raises - ------ - KernelException if either Kernel2D psf dimension is odd - """ - - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - - convolved_array_2d = scipy.signal.convolve2d(array, self.native, mode="same") - - convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=np.array(mask), - array_2d_native=np.array(convolved_array_2d), - ) - - return Array2D(values=convolved_array_1d, mask=mask) - def jax_convolve(self, image, blurring_image, method="auto"): slim_to_2D_index_image = jnp.nonzero( jnp.logical_not(self.mask.array), size=image.shape[0] From 6fa35a6f838fe89e4fb0209bc25c9f960e99e152 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:21:20 +0100 Subject: [PATCH 084/108] simplify jax_convolve --- autoarray/structures/arrays/kernel_2d.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index d384f8a4b..def73f067 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -484,22 +484,29 @@ def convolved_array_from(self, array: Array2D) -> Array2D: return Array2D(values=convolved_array_1d, mask=array_2d.mask) def jax_convolve(self, image, blurring_image, method="auto"): - slim_to_2D_index_image = jnp.nonzero( + + slim_to_native = jnp.nonzero( jnp.logical_not(self.mask.array), size=image.shape[0] ) - slim_to_2D_index_blurring = jnp.nonzero( + slim_to_native_blurring = jnp.nonzero( jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] ) + expanded_image_native = jnp.zeros(self.mask.shape) - expanded_image_native = expanded_image_native.at[slim_to_2D_index_image].set( + + expanded_image_native = expanded_image_native.at[slim_to_native].set( image.array ) - expanded_image_native = expanded_image_native.at[slim_to_2D_index_blurring].set( + expanded_image_native = expanded_image_native.at[slim_to_native_blurring].set( blurring_image.array ) + kernel = np.array(self.kernel.native.array) + convolve_native = jax.scipy.signal.convolve( expanded_image_native, kernel, mode="same", method=method ) - convolve_slim = convolve_native[slim_to_2D_index_image] + + convolve_slim = convolve_native[slim_to_native] + return convolve_slim \ No newline at end of file From 0386bddb1860869c91766b3ead0672f3b87418e7 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:47:45 +0100 Subject: [PATCH 085/108] test__convolve_image --- autoarray/structures/arrays/kernel_2d.py | 66 ++++++++++++++++--- .../structures/arrays/test_kernel_2d.py | 52 ++++++++++++--- 2 files changed, 98 insertions(+), 20 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index def73f067..28ee92725 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -8,7 +8,6 @@ from autoconf.fitsable import header_obj_from -from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import AbstractArray2D from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D @@ -17,6 +16,7 @@ from autoarray import exc from autoarray import type as ty from autoarray.structures.arrays import array_2d_util +from autoarray.mask.mask_2d import mask_2d_util class Kernel2D(AbstractArray2D): @@ -483,28 +483,74 @@ def convolved_array_from(self, array: Array2D) -> Array2D: return Array2D(values=convolved_array_1d, mask=array_2d.mask) - def jax_convolve(self, image, blurring_image, method="auto"): + def convolve_image(self, image, blurring_image, jax_method="fft"): + """ + For a given 1D array and blurring array, convolve the two using this convolver. + + Parameters + ---------- + image + 1D array of the values which are to be blurred with the convolver's PSF. + blurring_image + 1D array of the blurring values which blur into the array after PSF convolution. + jax_method + If JAX is enabled this keyword will indicate what method is used for the PSF + convolution. Can be either `direct` to calculate it in real space or `fft` + to calculated it via a fast Fourier transform. `fft` is typically faster for + kernels that are more than about 5x5. Default is `fft`. + """ + + if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") + + print(type(image.native + blurring_image.native)) + print(type(self.native)) + + convolved_array_2d = scipy.signal.convolve2d((image.native + blurring_image.native)._array, self.native._array, mode="same") + convolved_array_1d = array_2d_util.array_2d_slim_from( + mask_2d=np.array(image.mask), + array_2d_native=convolved_array_2d, + ) + + return Array2D(values=convolved_array_1d, mask=image.mask) + + def convolve_image_jax_from(self, array, blurring_array, method="auto"): + """ + For a given 1D array and blurring array, convolve the two using this convolver. + + Parameters + ---------- + array + 1D array of the values which are to be blurred with the convolver's PSF. + blurring_array + 1D array of the blurring values which blur into the array after PSF convolution. + jax_method + If JAX is enabled this keyword will indicate what method is used for the PSF + convolution. Can be either `direct` to calculate it in real space or `fft` + to calculated it via a fast Fourier transform. `fft` is typically faster for + kernels that are more than about 5x5. Default is `fft`. + """ slim_to_native = jnp.nonzero( - jnp.logical_not(self.mask.array), size=image.shape[0] + jnp.logical_not(self.mask.array), size=array.shape[0] ) slim_to_native_blurring = jnp.nonzero( - jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] + jnp.logical_not(self.blurring_mask), size=blurring_array.shape[0] ) - expanded_image_native = jnp.zeros(self.mask.shape) + expanded_array_native = jnp.zeros(self.mask.shape) - expanded_image_native = expanded_image_native.at[slim_to_native].set( - image.array + expanded_array_native = expanded_array_native.at[slim_to_native].set( + array.array ) - expanded_image_native = expanded_image_native.at[slim_to_native_blurring].set( - blurring_image.array + expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set( + blurring_array.array ) kernel = np.array(self.kernel.native.array) convolve_native = jax.scipy.signal.convolve( - expanded_image_native, kernel, mode="same", method=method + expanded_array_native, kernel, mode="same", method=method ) convolve_slim = convolve_native[slim_to_native] diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index bd52796a0..50c931ae5 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -1,11 +1,10 @@ -from astropy.io import fits from astropy import units from astropy.modeling import functional_models from astropy.coordinates import Angle +import jax.numpy as jnp import numpy as np import pytest from os import path -import os import autoarray as aa from autoarray import exc @@ -359,6 +358,36 @@ def test__convolved_array_from(): ).all() +def test__convolved_array_from__input_jax_array(): + + array_2d = jnp.array( + [ + [0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0], + ]) + + kernel_2d = aa.Kernel2D.no_mask( + values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 + ) + + blurred_array_2d = kernel_2d.convolved_array_from(array_2d) + + assert ( + blurred_array_2d.native + == np.array( + [ + [1.0, 1.0, 0.0, 0.0], + [2.0, 1.0, 1.0, 1.0], + [3.0, 3.0, 2.0, 2.0], + [0.0, 0.0, 1.0, 3.0], + ] + ) + ).all() + + + def test__convolve_mapping_matrix(): mask = np.array( [ @@ -482,19 +511,19 @@ def test__convolve_mapping_matrix(): ) -def test__compare_to_full_2d_convolution(): - # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. - - import scipy.signal +def test__convolve_image(): mask = aa.Mask2D.circular( shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 ) - kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) - image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + + import scipy.signal + + kernel = np.arange(49).reshape(7, 7) + image = np.arange(900).reshape(30, 30) blurred_image_via_scipy = scipy.signal.convolve2d( - image.native, kernel.native, mode="same" + image, kernel, mode="same" ) blurred_image_via_scipy = aa.Array2D.no_mask( values=blurred_image_via_scipy, pixel_scales=1.0 @@ -503,7 +532,10 @@ def test__compare_to_full_2d_convolution(): values=blurred_image_via_scipy.native, mask=mask ) - # Now reproduce this data using the frame convolver_image + # Now reproduce this data using the convolve_image function + + image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) masked_image = aa.Array2D(values=image.native, mask=mask) From 537b5ef8683d4a4ef4168fc0e9583dc94604c787 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:54:04 +0100 Subject: [PATCH 086/108] convolve_image now only uses JAX --- autoarray/structures/arrays/kernel_2d.py | 52 +++++-------------- .../structures/arrays/test_kernel_2d.py | 5 +- 2 files changed, 16 insertions(+), 41 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 28ee92725..3f34c49bc 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -54,6 +54,9 @@ def __init__( store_native=store_native, ) + if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") + if normalize: self._array = np.divide(self._array, np.sum(self._array)) @@ -500,59 +503,28 @@ def convolve_image(self, image, blurring_image, jax_method="fft"): kernels that are more than about 5x5. Default is `fft`. """ - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - - print(type(image.native + blurring_image.native)) - print(type(self.native)) - - convolved_array_2d = scipy.signal.convolve2d((image.native + blurring_image.native)._array, self.native._array, mode="same") - - convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=np.array(image.mask), - array_2d_native=convolved_array_2d, - ) - - return Array2D(values=convolved_array_1d, mask=image.mask) - - def convolve_image_jax_from(self, array, blurring_array, method="auto"): - """ - For a given 1D array and blurring array, convolve the two using this convolver. - - Parameters - ---------- - array - 1D array of the values which are to be blurred with the convolver's PSF. - blurring_array - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. - """ slim_to_native = jnp.nonzero( - jnp.logical_not(self.mask.array), size=array.shape[0] + jnp.logical_not(image.mask.array), size=image.shape[0] ) slim_to_native_blurring = jnp.nonzero( - jnp.logical_not(self.blurring_mask), size=blurring_array.shape[0] + jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] ) - expanded_array_native = jnp.zeros(self.mask.shape) + expanded_array_native = jnp.zeros(image.mask.shape) expanded_array_native = expanded_array_native.at[slim_to_native].set( - array.array + image.array ) expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set( - blurring_array.array + blurring_image.array ) - kernel = np.array(self.kernel.native.array) + kernel = np.array(self.native.array) convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=method + expanded_array_native, kernel, mode="same", method=jax_method ) - convolve_slim = convolve_native[slim_to_native] + convolved_array_1d = convolve_native[slim_to_native] - return convolve_slim \ No newline at end of file + return Array2D(values=convolved_array_1d, mask=image.mask) \ No newline at end of file diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 50c931ae5..11dab836d 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -549,7 +549,10 @@ def test__convolve_image(): image=masked_image, blurring_image=blurring_image ) - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) + print(blurred_masked_image_via_scipy) + print(blurred_masked_im_1) + + assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) def test__compare_to_full_2d_convolution__no_blurring_image(): From 27bf06a9b15b6a20755ffef749a0a05a2dcef70a Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 19:56:18 +0100 Subject: [PATCH 087/108] fix test on array shape --- .../structures/arrays/test_kernel_2d.py | 170 +++++++----------- 1 file changed, 68 insertions(+), 102 deletions(-) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 11dab836d..d70854486 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -292,11 +292,10 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) -def test__convolved_array_from__not_odd_x_odd_kernel__raises_error(): - kernel_2d = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) +def test__not_odd_x_odd_kernel__raises_error(): with pytest.raises(exc.KernelException): - kernel_2d.convolved_array_from(np.ones((5, 5))) + aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) def test__convolved_array_from(): @@ -358,34 +357,78 @@ def test__convolved_array_from(): ).all() -def test__convolved_array_from__input_jax_array(): +def test__convolve_image(): - array_2d = jnp.array( - [ - [0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0], - ]) + mask = aa.Mask2D.circular( + shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 + ) - kernel_2d = aa.Kernel2D.no_mask( - values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 + import scipy.signal + + kernel = np.arange(49).reshape(7, 7) + image = np.arange(900).reshape(30, 30) + + blurred_image_via_scipy = scipy.signal.convolve2d( + image, kernel, mode="same" + ) + blurred_image_via_scipy = aa.Array2D.no_mask( + values=blurred_image_via_scipy, pixel_scales=1.0 + ) + blurred_masked_image_via_scipy = aa.Array2D( + values=blurred_image_via_scipy.native, mask=mask ) - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) + # Now reproduce this data using the convolve_image function - assert ( - blurred_array_2d.native - == np.array( - [ - [1.0, 1.0, 0.0, 0.0], - [2.0, 1.0, 1.0, 1.0], - [3.0, 3.0, 2.0, 2.0], - [0.0, 0.0, 1.0, 3.0], - ] - ) - ).all() + image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) + + masked_image = aa.Array2D(values=image.native, mask=mask) + + blurring_mask = mask.derive_mask.blurring_from( + kernel_shape_native=kernel.shape_native + ) + + blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) + + blurred_masked_im_1 = kernel.convolve_image( + image=masked_image, blurring_image=blurring_image + ) + + assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) + + +def test__compare_to_full_2d_convolution__no_blurring_image(): + # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. + + import scipy.signal + + mask = aa.Mask2D.circular( + shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 + ) + kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) + image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + + blurring_mask = mask.derive_mask.blurring_from( + kernel_shape_native=kernel.shape_native + ) + blurred_image_via_scipy = scipy.signal.convolve2d( + image.native * blurring_mask, kernel.native, mode="same" + ) + blurred_image_via_scipy = aa.Array2D.no_mask( + values=blurred_image_via_scipy, pixel_scales=1.0 + ) + blurred_masked_image_via_scipy = aa.Array2D( + values=blurred_image_via_scipy.native, mask=mask + ) + + # Now reproduce this data using the frame convolver_image + masked_image = aa.Array2D(values=image.native, mask=mask) + + blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image) + + assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) def test__convolve_mapping_matrix(): @@ -509,80 +552,3 @@ def test__convolve_mapping_matrix(): ), 1e-4, ) - - -def test__convolve_image(): - - mask = aa.Mask2D.circular( - shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 - ) - - import scipy.signal - - kernel = np.arange(49).reshape(7, 7) - image = np.arange(900).reshape(30, 30) - - blurred_image_via_scipy = scipy.signal.convolve2d( - image, kernel, mode="same" - ) - blurred_image_via_scipy = aa.Array2D.no_mask( - values=blurred_image_via_scipy, pixel_scales=1.0 - ) - blurred_masked_image_via_scipy = aa.Array2D( - values=blurred_image_via_scipy.native, mask=mask - ) - - # Now reproduce this data using the convolve_image function - - image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) - kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) - - masked_image = aa.Array2D(values=image.native, mask=mask) - - blurring_mask = mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native - ) - - blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) - - blurred_masked_im_1 = kernel.convolve_image( - image=masked_image, blurring_image=blurring_image - ) - - print(blurred_masked_image_via_scipy) - print(blurred_masked_im_1) - - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) - - -def test__compare_to_full_2d_convolution__no_blurring_image(): - # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. - - import scipy.signal - - mask = aa.Mask2D.circular( - shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 - ) - kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) - image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) - - blurring_mask = mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native - ) - blurred_image_via_scipy = scipy.signal.convolve2d( - image.native * blurring_mask, kernel.native, mode="same" - ) - blurred_image_via_scipy = aa.Array2D.no_mask( - values=blurred_image_via_scipy, pixel_scales=1.0 - ) - blurred_masked_image_via_scipy = aa.Array2D( - values=blurred_image_via_scipy.native, mask=mask - ) - - # Now reproduce this data using the frame convolver_image - - masked_image = aa.Array2D(values=image.native, mask=mask) - - blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image) - - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) From 44cd4157b070a56b26077189352eb7aa105d967b Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 20:01:05 +0100 Subject: [PATCH 088/108] convolve_image_no_blurring --- autoarray/structures/arrays/kernel_2d.py | 37 +++++++++++++++++++ .../structures/arrays/test_kernel_2d.py | 20 ++++++---- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 3f34c49bc..cf2f451f9 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -527,4 +527,41 @@ def convolve_image(self, image, blurring_image, jax_method="fft"): convolved_array_1d = convolve_native[slim_to_native] + return Array2D(values=convolved_array_1d, mask=image.mask) + + def convolve_image_no_blurring(self, image, jax_method="fft"): + """ + For a given 1D array and blurring array, convolve the two using this convolver. + + Parameters + ---------- + image + 1D array of the values which are to be blurred with the convolver's PSF. + blurring_image + 1D array of the blurring values which blur into the array after PSF convolution. + jax_method + If JAX is enabled this keyword will indicate what method is used for the PSF + convolution. Can be either `direct` to calculate it in real space or `fft` + to calculated it via a fast Fourier transform. `fft` is typically faster for + kernels that are more than about 5x5. Default is `fft`. + """ + + slim_to_native = jnp.nonzero( + jnp.logical_not(image.mask.array), size=image.shape[0] + ) + + expanded_array_native = jnp.zeros(image.mask.shape) + + expanded_array_native = expanded_array_native.at[slim_to_native].set( + image.array + ) + + kernel = np.array(self.native.array) + + convolve_native = jax.scipy.signal.convolve( + expanded_array_native, kernel, mode="same", method=jax_method + ) + + convolved_array_1d = convolve_native[slim_to_native] + return Array2D(values=convolved_array_1d, mask=image.mask) \ No newline at end of file diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index d70854486..c24092ff1 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -398,22 +398,23 @@ def test__convolve_image(): assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) -def test__compare_to_full_2d_convolution__no_blurring_image(): +def test__convolve_image_no_blurring(): # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. - import scipy.signal - mask = aa.Mask2D.circular( shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 ) - kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) - image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + + import scipy.signal + + kernel = np.arange(49).reshape(7, 7) + image = np.arange(900).reshape(30, 30) blurring_mask = mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape_native + kernel_shape_native=kernel.shape ) blurred_image_via_scipy = scipy.signal.convolve2d( - image.native * blurring_mask, kernel.native, mode="same" + image * blurring_mask, kernel, mode="same" ) blurred_image_via_scipy = aa.Array2D.no_mask( values=blurred_image_via_scipy, pixel_scales=1.0 @@ -424,11 +425,14 @@ def test__compare_to_full_2d_convolution__no_blurring_image(): # Now reproduce this data using the frame convolver_image + kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) + image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) + masked_image = aa.Array2D(values=image.native, mask=mask) blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image) - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4) + assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) def test__convolve_mapping_matrix(): From 4e0b92596ad40168b9908bec6d356a7a4bcce421 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 20:55:00 +0100 Subject: [PATCH 089/108] maapping matrix convolve works --- autoarray/dataset/imaging/dataset.py | 3 +++ .../inversion/inversion/imaging/w_tilde.py | 3 ++- autoarray/structures/arrays/kernel_2d.py | 23 +++++++++++------- .../dataset/imaging/test_dataset.py | 7 ++++++ .../structures/arrays/test_kernel_2d.py | 24 ++++++++----------- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 7fa3c515f..e3ec74d3b 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -166,6 +166,9 @@ def __init__( self.psf = psf + if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") + @cached_property def grids(self): return GridsDataset( diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index c249f4ce9..888be987e 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -526,7 +526,8 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: ) mapped_reconstructed_image = self.convolver.convolve_image_no_blurring( - image=mapped_reconstructed_image + image=mapped_reconstructed_image, + mask=self.mask ) else: diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index cf2f451f9..a9c148709 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -54,9 +54,6 @@ def __init__( store_native=store_native, ) - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - if normalize: self._array = np.divide(self._array, np.sum(self._array)) @@ -529,7 +526,7 @@ def convolve_image(self, image, blurring_image, jax_method="fft"): return Array2D(values=convolved_array_1d, mask=image.mask) - def convolve_image_no_blurring(self, image, jax_method="fft"): + def convolve_image_no_blurring(self, image, mask, jax_method="fft"): """ For a given 1D array and blurring array, convolve the two using this convolver. @@ -547,13 +544,13 @@ def convolve_image_no_blurring(self, image, jax_method="fft"): """ slim_to_native = jnp.nonzero( - jnp.logical_not(image.mask.array), size=image.shape[0] + jnp.logical_not(mask.array), size=image.shape[0] ) - expanded_array_native = jnp.zeros(image.mask.shape) + expanded_array_native = jnp.zeros(mask.shape) expanded_array_native = expanded_array_native.at[slim_to_native].set( - image.array + image ) kernel = np.array(self.native.array) @@ -564,4 +561,14 @@ def convolve_image_no_blurring(self, image, jax_method="fft"): convolved_array_1d = convolve_native[slim_to_native] - return Array2D(values=convolved_array_1d, mask=image.mask) \ No newline at end of file + return Array2D(values=convolved_array_1d, mask=mask) + + def convolve_mapping_matrix(self, mapping_matrix, mask): + """For a given 1D array and blurring array, convolve the two using this convolver. + + Parameters + ---------- + image + 1D array of the values which are to be blurred with the convolver's PSF. + """ + return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))(mapping_matrix, mask).T \ No newline at end of file diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index 9c079b94d..9758f7cdb 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -8,6 +8,8 @@ import autoarray as aa +from autoarray import exc + test_data_path = path.join( "{}".format(path.dirname(path.realpath(__file__))), "files", @@ -241,3 +243,8 @@ def test__noise_map_unmasked_has_zeros_or_negative__raises_exception(): with pytest.raises(aa.exc.DatasetException): aa.Imaging(data=array, noise_map=noise_map) + +def test__psf_not_odd_x_odd_kernel__raises_error(): + + with pytest.raises(exc.KernelException): + aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) \ No newline at end of file diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index c24092ff1..2941801ba 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -292,12 +292,6 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) -def test__not_odd_x_odd_kernel__raises_error(): - - with pytest.raises(exc.KernelException): - aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) - - def test__convolved_array_from(): array_2d = aa.Array2D.no_mask( @@ -430,13 +424,13 @@ def test__convolve_image_no_blurring(): masked_image = aa.Array2D(values=image.native, mask=mask) - blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image) + blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image, mask=mask) assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) def test__convolve_mapping_matrix(): - mask = np.array( + mask = aa.Mask2D(mask=np.array( [ [True, True, True, True, True, True], [True, False, False, False, False, True], @@ -445,7 +439,7 @@ def test__convolve_mapping_matrix(): [True, False, False, False, False, True], [True, True, True, True, True, True], ] - ) + ), pixel_scales=1.0) kernel = aa.Kernel2D.no_mask( values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 @@ -476,11 +470,11 @@ def test__convolve_mapping_matrix(): ] ) - blurred_mapping = kernel.convolve_mapping_matrix(mapping) + blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) assert ( blurred_mapping - == np.array( + == pytest.approx(np.array( [ [0, 0, 0], [0, 0, 0], @@ -500,7 +494,7 @@ def test__convolve_mapping_matrix(): [0, 0, 0], ] ) - ).all() + ), 1.0e-4) kernel = aa.Kernel2D.no_mask( values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 @@ -531,7 +525,9 @@ def test__convolve_mapping_matrix(): ] ) - blurred_mapping = kernel.convolve_mapping_matrix(mapping) + blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) + + print(blurred_mapping) assert blurred_mapping == pytest.approx( np.array( @@ -554,5 +550,5 @@ def test__convolve_mapping_matrix(): [0, 0, 0], ] ), - 1e-4, + abs=1e-4, ) From 731fcb5914e2b11abaf9b6bb520e277c140ac9e0 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 20:58:07 +0100 Subject: [PATCH 090/108] black --- autoarray/fixtures.py | 1 - autoarray/geometry/geometry_util.py | 21 ++--- .../inversion/inversion/imaging/w_tilde.py | 3 +- autoarray/mask/derive/indexes_2d.py | 6 +- autoarray/mask/mask_1d.py | 1 + autoarray/mask/mask_2d_util.py | 33 ++++--- .../over_sampling/over_sample_util.py | 10 +-- .../operators/over_sampling/over_sampler.py | 4 +- autoarray/structures/arrays/array_2d_util.py | 13 ++- autoarray/structures/arrays/kernel_2d.py | 16 ++-- autoarray/structures/grids/uniform_2d.py | 6 +- .../dataset/imaging/test_dataset.py | 11 +-- test_autoarray/geometry/test_geometry_util.py | 4 +- .../imaging/test_inversion_imaging_util.py | 4 +- test_autoarray/mask/test_mask_2d_util.py | 1 - .../structures/arrays/test_kernel_2d.py | 90 ++++++++++--------- .../structures/grids/test_uniform_2d.py | 2 - 17 files changed, 110 insertions(+), 116 deletions(-) diff --git a/autoarray/fixtures.py b/autoarray/fixtures.py index 68a8a73f8..e546fb497 100644 --- a/autoarray/fixtures.py +++ b/autoarray/fixtures.py @@ -110,7 +110,6 @@ def make_blurring_grid_2d_7x7(): return aa.Grid2D.from_mask(mask=make_blurring_mask_2d_7x7()) - def make_image_7x7(): return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0)) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index b646c7d08..e2c6c8898 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -181,6 +181,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales + @numba_util.jit() def central_pixel_coordinates_2d_numba_from( shape_native: Tuple[int, int], @@ -205,6 +206,7 @@ def central_pixel_coordinates_2d_numba_from( """ return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) + @numba_util.jit() def central_scaled_coordinate_2d_numba_from( shape_native: Tuple[int, int], @@ -305,6 +307,7 @@ def central_scaled_coordinate_2d_from( return (y_pixel, x_pixel) + def pixel_coordinates_2d_from( scaled_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], @@ -589,9 +592,9 @@ def grid_pixel_centres_2d_slim_from( centres_scaled = np.array(centres_scaled) pixel_scales = np.array(pixel_scales) sign = np.array([-1.0, 1.0]) - return ( - (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - ).astype(int) + return ((sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5).astype( + int + ) def grid_pixel_indexes_2d_slim_from( @@ -647,9 +650,7 @@ def grid_pixel_indexes_2d_slim_from( ) return ( - (grid_pixels_2d_slim * np.array([shape_native[1], 1])) - .sum(axis=1) - .astype(int) + (grid_pixels_2d_slim * np.array([shape_native[1], 1])).sum(axis=1).astype(int) ) @@ -698,9 +699,7 @@ def grid_scaled_2d_slim_from( centres_scaled = np.array(centres_scaled) pixel_scales = np.array(pixel_scales) sign = np.array([-1, 1]) - return ( - (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign - ) + return (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign def grid_pixel_centres_2d_from( @@ -750,9 +749,7 @@ def grid_pixel_centres_2d_from( centres_scaled = np.array(centres_scaled) pixel_scales = np.array(pixel_scales) sign = np.array([-1.0, 1.0]) - return ( - (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 - ).astype(int) + return ((sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5).astype(int) def extent_symmetric_from( diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 888be987e..aef314586 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -526,8 +526,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: ) mapped_reconstructed_image = self.convolver.convolve_image_no_blurring( - image=mapped_reconstructed_image, - mask=self.mask + image=mapped_reconstructed_image, mask=self.mask ) else: diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 4807799d9..062c8e664 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -198,9 +198,9 @@ def edge_slim(self) -> np.ndarray: print(derive_indexes_2d.edge_slim) """ - return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask).astype("bool")).astype( - "int" - ) + return mask_2d_util.edge_1d_indexes_from( + mask_2d=np.array(self.mask).astype("bool") + ).astype("int") @property def edge_native(self) -> np.ndarray: diff --git a/autoarray/mask/mask_1d.py b/autoarray/mask/mask_1d.py index 9330e62ee..8c36d8866 100644 --- a/autoarray/mask/mask_1d.py +++ b/autoarray/mask/mask_1d.py @@ -25,6 +25,7 @@ class Mask1DKeys(Enum): PIXSCA = "PIXSCA" ORIGIN = "ORIGIN" + class Mask1D(Mask): def __init__( self, diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 9ea4e35c0..462448073 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -6,6 +6,7 @@ from autoarray import exc from autoarray.numpy_wrapper import np as jnp + def native_index_for_slim_index_2d_from( mask_2d: np.ndarray, ) -> np.ndarray: @@ -400,6 +401,7 @@ def mask_2d_via_pixel_coordinates_from( return mask_2d return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) # Apply buf + def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]: """ Compute the minimum 1D distance in the y and x directions from any `False` value at the mask's extreme positions @@ -618,14 +620,18 @@ def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray: array([0, 1, 2, 3, 5, 6, 7, 8]) """ # Pad the mask to handle edge cases without index errors - padded_mask = np.pad(mask_2d, pad_width=1, mode='constant', constant_values=True) + padded_mask = np.pad(mask_2d, pad_width=1, mode="constant", constant_values=True) # Identify neighbors in 3x3 regions around each pixel neighbors = ( - padded_mask[:-2, 1:-1] | padded_mask[2:, 1:-1] | # Up, Down - padded_mask[1:-1, :-2] | padded_mask[1:-1, 2:] | # Left, Right - padded_mask[:-2, :-2] | padded_mask[:-2, 2:] | # Top-left, Top-right - padded_mask[2:, :-2] | padded_mask[2:, 2:] # Bottom-left, Bottom-right + padded_mask[:-2, 1:-1] + | padded_mask[2:, 1:-1] # Up, Down + | padded_mask[1:-1, :-2] + | padded_mask[1:-1, 2:] # Left, Right + | padded_mask[:-2, :-2] + | padded_mask[:-2, 2:] # Top-left, Top-right + | padded_mask[2:, :-2] + | padded_mask[2:, 2:] # Bottom-left, Bottom-right ) # Identify edge pixels: False values with at least one True neighbor @@ -708,10 +714,10 @@ def border_slim_indexes_from(mask_2d: np.ndarray) -> np.ndarray: # Identify border pixels: where the full length in any direction is True border_mask = ( - (up_sums == np.arange(height)[:, None]) | - (down_sums == np.arange(height - 1, -1, -1)[:, None]) | - (left_sums == np.arange(width)[None, :]) | - (right_sums == np.arange(width - 1, -1, -1)[None, :]) + (up_sums == np.arange(height)[:, None]) + | (down_sums == np.arange(height - 1, -1, -1)[:, None]) + | (left_sums == np.arange(width)[None, :]) + | (right_sums == np.arange(width - 1, -1, -1)[None, :]) ) & ~mask_2d # Create an index array where False entries get sequential 1D indices @@ -767,14 +773,16 @@ def buffed_mask_2d_from(mask_2d: np.ndarray, buffer: int = 1) -> np.ndarray: buffer_range = np.arange(-buffer, buffer + 1) # Generate all possible neighbors for each False entry - dy, dx = np.meshgrid(buffer_range, buffer_range, indexing='ij') + dy, dx = np.meshgrid(buffer_range, buffer_range, indexing="ij") neighbors = np.stack([dy.ravel(), dx.ravel()], axis=-1) # Calculate all neighboring positions for all False coordinates all_neighbors = np.add(np.array(false_coords).T[:, np.newaxis], neighbors) # Clip the neighbors to stay within the bounds of the mask - valid_neighbors = np.clip(all_neighbors, [0, 0], [mask_2d.shape[0] - 1, mask_2d.shape[1] - 1]) + valid_neighbors = np.clip( + all_neighbors, [0, 0], [mask_2d.shape[0] - 1, mask_2d.shape[1] - 1] + ) # Update the buffed mask: set all the neighbors to False buffed_mask_2d[valid_neighbors[:, :, 0], valid_neighbors[:, :, 1]] = False @@ -833,6 +841,3 @@ def rescaled_mask_2d_from(mask_2d: np.ndarray, rescale_factor: float) -> np.ndar rescaled_mask_2d[:, 0] = 1 rescaled_mask_2d[:, rescaled_mask_2d.shape[1] - 1] = 1 return np.isclose(rescaled_mask_2d, 1) - - - diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index 8966e785e..084d0bd0d 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -463,16 +463,10 @@ def grid_2d_slim_over_sampled_via_mask_from( # ) # else: grid_slim[sub_index, 0] = -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) + y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) ) grid_slim[sub_index, 1] = ( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) + x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0) ) sub_index += 1 diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 0aafbe008..6f12f4b9f 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -226,7 +226,9 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": # sub_size=np.array(self.sub_size).astype("int"), # ) - binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]**2).mean(axis=1) + binned_array_2d = array.reshape( + self.mask.shape_slim, self.sub_size[0] ** 2 + ).mean(axis=1) return Array2D( values=binned_array_2d, diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index 0e694846d..d8a0f7e10 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -27,10 +27,7 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array = np.asarray(array) elif isinstance(array, jnp.ndarray): array = jax.lax.cond( - type(array) is list, - lambda _: jnp.asarray(array), - lambda _: array, - None + type(array) is list, lambda _: jnp.asarray(array), lambda _: array, None ) return array @@ -41,6 +38,7 @@ def check_array_2d(array_2d: np.ndarray): "An array input into the Array2D.__new__ method is not of shape 1." ) + def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an @@ -90,6 +88,7 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ ) + def convert_array_2d( array_2d: Union[np.ndarray, List], mask_2d: Mask2D, @@ -489,6 +488,7 @@ def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.nda return index_slim_for_index_native_2d + def array_2d_slim_from( array_2d_native: np.ndarray, mask_2d: np.ndarray, @@ -534,6 +534,7 @@ def array_2d_slim_from( """ return array_2d_native[~mask_2d.astype(bool)] + def array_2d_native_from( array_2d_slim: np.ndarray, mask_2d: np.ndarray, @@ -620,9 +621,7 @@ def array_2d_via_indexes_from( The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ return ( - jnp.zeros(shape) - .at[tuple(native_index_for_slim_index_2d.T)] - .set(array_2d_slim) + jnp.zeros(shape).at[tuple(native_index_for_slim_index_2d.T)].set(array_2d_slim) ) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index a9c148709..5938198cf 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -474,7 +474,9 @@ def convolved_array_from(self, array: Array2D) -> Array2D: array_2d = array.native - convolved_array_2d = scipy.signal.convolve2d(array_2d._array, np.array(self.native._array), mode="same") + convolved_array_2d = scipy.signal.convolve2d( + array_2d._array, np.array(self.native._array), mode="same" + ) convolved_array_1d = array_2d_util.array_2d_slim_from( mask_2d=np.array(array_2d.mask), @@ -543,15 +545,11 @@ def convolve_image_no_blurring(self, image, mask, jax_method="fft"): kernels that are more than about 5x5. Default is `fft`. """ - slim_to_native = jnp.nonzero( - jnp.logical_not(mask.array), size=image.shape[0] - ) + slim_to_native = jnp.nonzero(jnp.logical_not(mask.array), size=image.shape[0]) expanded_array_native = jnp.zeros(mask.shape) - expanded_array_native = expanded_array_native.at[slim_to_native].set( - image - ) + expanded_array_native = expanded_array_native.at[slim_to_native].set(image) kernel = np.array(self.native.array) @@ -571,4 +569,6 @@ def convolve_mapping_matrix(self, mapping_matrix, mask): image 1D array of the values which are to be blurred with the convolver's PSF. """ - return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))(mapping_matrix, mask).T \ No newline at end of file + return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))( + mapping_matrix, mask + ).T diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 4d565dd27..670cbfcbe 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -840,9 +840,9 @@ def squared_distances_to_coordinate_from( The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ if isinstance(self, jnp.ndarray): - squared_distances = jnp.square(self.array[:, 0] - coordinate[0]) + jnp.square( - self.array[:, 1] - coordinate[1] - ) + squared_distances = jnp.square( + self.array[:, 0] - coordinate[0] + ) + jnp.square(self.array[:, 1] - coordinate[1]) else: squared_distances = np.square(self[:, 0] - coordinate[0]) + np.square( self[:, 1] - coordinate[1] diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index 9758f7cdb..45ef0a80f 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -38,16 +38,12 @@ def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map(): noise_map = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) psf = aa.Kernel2D.ones(shape_native=(3, 3), pixel_scales=1.0) - dataset = aa.Imaging( - data=image, noise_map=noise_map, psf=psf, pad_for_psf=False - ) + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=False) assert dataset.data.shape_native == (3, 3) assert dataset.noise_map.shape_native == (3, 3) - dataset = aa.Imaging( - data=image, noise_map=noise_map, psf=psf, pad_for_psf=True - ) + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=True) assert dataset.data.shape_native == (5, 5) assert dataset.noise_map.shape_native == (5, 5) @@ -244,7 +240,8 @@ def test__noise_map_unmasked_has_zeros_or_negative__raises_exception(): with pytest.raises(aa.exc.DatasetException): aa.Imaging(data=array, noise_map=noise_map) + def test__psf_not_odd_x_odd_kernel__raises_error(): with pytest.raises(exc.KernelException): - aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) \ No newline at end of file + aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) diff --git a/test_autoarray/geometry/test_geometry_util.py b/test_autoarray/geometry/test_geometry_util.py index 4bc2706dc..e65608bf1 100644 --- a/test_autoarray/geometry/test_geometry_util.py +++ b/test_autoarray/geometry/test_geometry_util.py @@ -994,7 +994,8 @@ def test__transform_2d_grid_to_reference_frame(): [0.0, np.sqrt(2)], [np.sqrt(2) / 2.0, np.sqrt(2) / 2.0], ] - ), abs=1.0e-4 + ), + abs=1.0e-4, ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1071,7 +1072,6 @@ def test__transform_2d_grid_from_reference_frame(): grid_2d=transformed_grid_2d, centre=(8.0, 5.0), angle=137.0 ) - assert grid_2d == pytest.approx(original_grid_2d, abs=1.0e-4) diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index b616a95d4..e05e13350 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -282,9 +282,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): w_tilde=w_tilde, mapping_matrix=mapping_matrix ) - blurred_mapping_matrix = psf.convolve_mapping_matrix( - mapping_matrix=mapping_matrix - ) + blurred_mapping_matrix = psf.convolve_mapping_matrix(mapping_matrix=mapping_matrix) curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( mapping_matrix=blurred_mapping_matrix, diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index bd2ebd84a..f3db2938e 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -723,7 +723,6 @@ def test__mask_1d_indexes_from(): assert masked_slim[-1] == 48 - def test__edge_1d_indexes_from(): mask = np.array( [ diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 2941801ba..4cf8b92d7 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -362,9 +362,7 @@ def test__convolve_image(): kernel = np.arange(49).reshape(7, 7) image = np.arange(900).reshape(30, 30) - blurred_image_via_scipy = scipy.signal.convolve2d( - image, kernel, mode="same" - ) + blurred_image_via_scipy = scipy.signal.convolve2d(image, kernel, mode="same") blurred_image_via_scipy = aa.Array2D.no_mask( values=blurred_image_via_scipy, pixel_scales=1.0 ) @@ -389,7 +387,9 @@ def test__convolve_image(): image=masked_image, blurring_image=blurring_image ) - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) + assert blurred_masked_image_via_scipy == pytest.approx( + blurred_masked_im_1.array, 1e-4 + ) def test__convolve_image_no_blurring(): @@ -404,9 +404,7 @@ def test__convolve_image_no_blurring(): kernel = np.arange(49).reshape(7, 7) image = np.arange(900).reshape(30, 30) - blurring_mask = mask.derive_mask.blurring_from( - kernel_shape_native=kernel.shape - ) + blurring_mask = mask.derive_mask.blurring_from(kernel_shape_native=kernel.shape) blurred_image_via_scipy = scipy.signal.convolve2d( image * blurring_mask, kernel, mode="same" ) @@ -424,22 +422,29 @@ def test__convolve_image_no_blurring(): masked_image = aa.Array2D(values=image.native, mask=mask) - blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image, mask=mask) + blurred_masked_im_1 = kernel.convolve_image_no_blurring( + image=masked_image, mask=mask + ) - assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4) + assert blurred_masked_image_via_scipy == pytest.approx( + blurred_masked_im_1.array, 1e-4 + ) def test__convolve_mapping_matrix(): - mask = aa.Mask2D(mask=np.array( - [ - [True, True, True, True, True, True], - [True, False, False, False, False, True], - [True, False, False, False, False, True], - [True, False, False, False, False, True], - [True, False, False, False, False, True], - [True, True, True, True, True, True], - ] - ), pixel_scales=1.0) + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True, True, True, True], + [True, False, False, False, False, True], + [True, False, False, False, False, True], + [True, False, False, False, False, True], + [True, False, False, False, False, True], + [True, True, True, True, True, True], + ] + ), + pixel_scales=1.0, + ) kernel = aa.Kernel2D.no_mask( values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 @@ -474,27 +479,30 @@ def test__convolve_mapping_matrix(): assert ( blurred_mapping - == pytest.approx(np.array( - [ - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [0, 0.4, 0], - [0, 0.2, 0], - [0.4, 0, 0], - [0.2, 0, 0.4], - [0.3, 0, 0.2], - [0, 0.1, 0.3], - [0, 0, 0], - [0.1, 0, 0], - [0, 0, 0.1], - [0, 0, 0], - ] - ) - ), 1.0e-4) + == pytest.approx( + np.array( + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0.4, 0], + [0, 0.2, 0], + [0.4, 0, 0], + [0.2, 0, 0.4], + [0.3, 0, 0.2], + [0, 0.1, 0.3], + [0, 0, 0], + [0.1, 0, 0], + [0, 0, 0.1], + [0, 0, 0], + ] + ) + ), + 1.0e-4, + ) kernel = aa.Kernel2D.no_mask( values=[[0, 0.0, 0], [0.4, 0.2, 0.3], [0, 0.1, 0]], pixel_scales=1.0 @@ -527,8 +535,6 @@ def test__convolve_mapping_matrix(): blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) - print(blurred_mapping) - assert blurred_mapping == pytest.approx( np.array( [ diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index 686721c70..78813329e 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -575,7 +575,6 @@ def test__grid_2d_radial_projected_shape_slim_from(): pixel_scales=grid_2d.pixel_scales, ) - assert grid_radii == pytest.approx(grid_radii_util, 1.0e-4) assert grid_radial_shape_slim == grid_radii_util.shape[0] @@ -782,7 +781,6 @@ def test__grid_with_coordinates_within_distance_removed_from(): ).all() - def test__recursive_shape_storage(): grid_2d = aa.Grid2D.no_mask( values=[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], From ef5ba9efef1ea3c6234fcc5557c380e9b8e20db6 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 21:08:38 +0100 Subject: [PATCH 091/108] finish --- .../inversion/inversion/imaging/w_tilde.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 12 ++++++------ .../arrays/files/array/output_test/array.fits | Bin 5760 -> 5760 bytes 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index aef314586..199ba66b2 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -525,7 +525,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: values=mapped_reconstructed_image, mask=self.mask ) - mapped_reconstructed_image = self.convolver.convolve_image_no_blurring( + mapped_reconstructed_image = self.psf.convolve_image_no_blurring( image=mapped_reconstructed_image, mask=self.mask ) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 5938198cf..bb62656a9 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -487,12 +487,12 @@ def convolved_array_from(self, array: Array2D) -> Array2D: def convolve_image(self, image, blurring_image, jax_method="fft"): """ - For a given 1D array and blurring array, convolve the two using this convolver. + For a given 1D array and blurring array, convolve the two using this psf. Parameters ---------- image - 1D array of the values which are to be blurred with the convolver's PSF. + 1D array of the values which are to be blurred with the psf's PSF. blurring_image 1D array of the blurring values which blur into the array after PSF convolution. jax_method @@ -530,12 +530,12 @@ def convolve_image(self, image, blurring_image, jax_method="fft"): def convolve_image_no_blurring(self, image, mask, jax_method="fft"): """ - For a given 1D array and blurring array, convolve the two using this convolver. + For a given 1D array and blurring array, convolve the two using this psf. Parameters ---------- image - 1D array of the values which are to be blurred with the convolver's PSF. + 1D array of the values which are to be blurred with the psf's PSF. blurring_image 1D array of the blurring values which blur into the array after PSF convolution. jax_method @@ -562,12 +562,12 @@ def convolve_image_no_blurring(self, image, mask, jax_method="fft"): return Array2D(values=convolved_array_1d, mask=mask) def convolve_mapping_matrix(self, mapping_matrix, mask): - """For a given 1D array and blurring array, convolve the two using this convolver. + """For a given 1D array and blurring array, convolve the two using this psf. Parameters ---------- image - 1D array of the values which are to be blurred with the convolver's PSF. + 1D array of the values which are to be blurred with the psf's PSF. """ return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))( mapping_matrix, mask diff --git a/test_autoarray/structures/arrays/files/array/output_test/array.fits b/test_autoarray/structures/arrays/files/array/output_test/array.fits index dab9da2fb676efc023e53b73faa54fe78fc4000c..0cff0a8f7db90d3ded28c644cd66e3791627feeb 100644 GIT binary patch delta 84 hcmZqBZP1;N!(?W%G4C>$;|B&XuqT_|T*&>83jh#~5uX46 delta 56 kcmZqBZP1;N!(?o$Vgmz%Jzl(dBKLc)$qTsk0j$6e2mk;8 From 33a3ccb69b17af77457ca0d000539e942852a25f Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 14:50:01 +0100 Subject: [PATCH 092/108] fix some decorator unit tests --- autoarray/structures/decorators/abstract.py | 2 +- autoarray/structures/decorators/to_grid.py | 3 +-- autoarray/structures/mock/mock_decorators.py | 6 +++--- test_autoarray/structures/decorators/test_to_grid.py | 8 ++++---- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/autoarray/structures/decorators/abstract.py b/autoarray/structures/decorators/abstract.py index c9e5fca87..e033815b4 100644 --- a/autoarray/structures/decorators/abstract.py +++ b/autoarray/structures/decorators/abstract.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Union import numpy as np diff --git a/autoarray/structures/decorators/to_grid.py b/autoarray/structures/decorators/to_grid.py index 144137c69..4797c37ce 100644 --- a/autoarray/structures/decorators/to_grid.py +++ b/autoarray/structures/decorators/to_grid.py @@ -1,6 +1,5 @@ -from autoarray.numpy_wrapper import np from functools import wraps - +import numpy as np from typing import List, Union from autoarray.structures.decorators.abstract import AbstractMaker diff --git a/autoarray/structures/mock/mock_decorators.py b/autoarray/structures/mock/mock_decorators.py index 876b456d7..013b0a62a 100644 --- a/autoarray/structures/mock/mock_decorators.py +++ b/autoarray/structures/mock/mock_decorators.py @@ -116,7 +116,7 @@ def ndarray_2d_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ - return np.multiply(2.0, grid) + return np.multiply(2.0, grid.array) @decorators.to_vector_yx def ndarray_yx_2d_from(self, grid, *args, **kwargs): @@ -146,7 +146,7 @@ def ndarray_2d_list_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ - return [np.multiply(1.0, grid), np.multiply(2.0, grid)] + return [np.multiply(1.0, grid.array), np.multiply(2.0, grid.array)] @decorators.to_vector_yx def ndarray_yx_2d_list_from(self, grid, *args, **kwargs): @@ -156,7 +156,7 @@ def ndarray_yx_2d_list_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ - return [np.multiply(1.0, grid), np.multiply(2.0, grid)] + return [np.multiply(1.0, grid.array), np.multiply(2.0, grid.array)] class MockGridRadialMinimum: diff --git a/test_autoarray/structures/decorators/test_to_grid.py b/test_autoarray/structures/decorators/test_to_grid.py index 60c70d71b..2e8b1be2f 100644 --- a/test_autoarray/structures/decorators/test_to_grid.py +++ b/test_autoarray/structures/decorators/test_to_grid.py @@ -15,11 +15,11 @@ def test__in_grid_1d__out_ndarray_2d(): assert isinstance(ndarray_2d, aa.Grid2D) assert ndarray_2d.native == pytest.approx( - np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), 1.0e-4 + np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), abs=1.0e-4 ) -def test__in_grid_1d__out_ndarray_2d_list(): +def test__in_dgrid_1d__out_ndarray_2d_list(): mask = aa.Mask1D(mask=[True, False, False, True], pixel_scales=(1.0,)) grid_1d = aa.Grid1D.from_mask(mask=mask) @@ -30,12 +30,12 @@ def test__in_grid_1d__out_ndarray_2d_list(): assert isinstance(ndarray_2d_list[0], aa.Grid2D) assert ndarray_2d_list[0].native == pytest.approx( - np.array([[[0.0, 0.0], [0.0, -0.5], [0.0, 0.5], [0.0, 0.0]]]), 1.0e-4 + np.array([[[0.0, 0.0], [0.0, -0.5], [0.0, 0.5], [0.0, 0.0]]]), abs=1.0e-4 ) assert isinstance(ndarray_2d_list[1], aa.Grid2D) assert ndarray_2d_list[1].native == pytest.approx( - np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), 1.0e-4 + np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), abs=1.0e-4 ) From 8b8dc9e9479dc17d9e91e5631e1033a75766a2bd Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 14:59:45 +0100 Subject: [PATCH 093/108] removing numpy wrapper to do explicit impots --- autoarray/abstract_ndarray.py | 17 +++++++++-------- .../operators/over_sampling/over_sampler.py | 13 +++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index 8b5fbc00e..ded8c5452 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -4,10 +4,11 @@ from abc import ABC from abc import abstractmethod +import jax.numpy as jnp from autoconf.fitsable import output_to_fits -from autoarray.numpy_wrapper import np, register_pytree_node, Array +from autoarray.numpy_wrapper import register_pytree_node, Array from typing import TYPE_CHECKING @@ -82,7 +83,7 @@ def __init__(self, array): def invert(self): new = self.copy() - new._array = np.invert(new._array) + new._array = jnp.invert(new._array) return new @classmethod @@ -104,7 +105,7 @@ def instance_flatten(cls, instance): @staticmethod def flip_hdu_for_ds9(values): if conf.instance["general"]["fits"]["flip_for_ds9"]: - return np.flipud(values) + return jnp.flipud(values) return values @classmethod @@ -117,7 +118,7 @@ def instance_unflatten(cls, aux_data, children): setattr(instance, key, value) return instance - def with_new_array(self, array: np.ndarray) -> "AbstractNDArray": + def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray": """ Copy this object but give it a new array. @@ -164,7 +165,7 @@ def __iter__(self): @to_new_array def sqrt(self): - return np.sqrt(self._array) + return jnp.sqrt(self._array) @property def array(self): @@ -330,13 +331,13 @@ def __getitem__(self, item): result = self._array[item] if isinstance(item, slice): result = self.with_new_array(result) - if isinstance(result, np.ndarray): + if isinstance(result, jnp.ndarray): result = self.with_new_array(result) return result def __setitem__(self, key, value): - if isinstance(key, (np.ndarray, AbstractNDArray, Array)): - self._array = np.where(key, value, self._array) + if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)): + self._array = jnp.where(key, value, self._array) else: self._array[key] = value diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6f12f4b9f..65393709c 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -1,5 +1,6 @@ -from autoarray.numpy_wrapper import np -from typing import List, Tuple, Union +import numpy as np +import jax.numpy as jnp +from typing import Union from autoconf import conf from autoconf import cached_property @@ -184,7 +185,7 @@ def sub_pixel_areas(self) -> np.ndarray: """ The area of every sub-pixel in the mask. """ - sub_pixel_areas = np.zeros(self.sub_total) + sub_pixel_areas = jnp.zeros(self.sub_total) k = 0 @@ -221,9 +222,9 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": pass # binned_array_2d = over_sample_util.binned_array_2d_from( - # array_2d=np.array(array), - # mask_2d=np.array(self.mask), - # sub_size=np.array(self.sub_size).astype("int"), + # array_2d=jnp.array(array), + # mask_2d=jnp.array(self.mask), + # sub_size=jnp.array(self.sub_size).astype("int"), # ) binned_array_2d = array.reshape( From 7115f9cbc5893d768e52a08b4d1ff2c313b39e9f Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:19:22 +0100 Subject: [PATCH 094/108] move relocate radial --- autoarray/config/grids.yaml | 3 - autoarray/structures/decorators/__init__.py | 1 - .../structures/decorators/relocate_radial.py | 106 ------------------ autoarray/structures/mock/mock_decorators.py | 4 - test_autoarray/config/grids.yaml | 12 -- 5 files changed, 126 deletions(-) delete mode 100644 autoarray/config/grids.yaml delete mode 100644 autoarray/structures/decorators/relocate_radial.py delete mode 100644 test_autoarray/config/grids.yaml diff --git a/autoarray/config/grids.yaml b/autoarray/config/grids.yaml deleted file mode 100644 index f82eaa5df..000000000 --- a/autoarray/config/grids.yaml +++ /dev/null @@ -1,3 +0,0 @@ -radial_minimum: - function_name: - class_name: 1.0e-08 diff --git a/autoarray/structures/decorators/__init__.py b/autoarray/structures/decorators/__init__.py index d85cbe6ee..1efb9137e 100644 --- a/autoarray/structures/decorators/__init__.py +++ b/autoarray/structures/decorators/__init__.py @@ -4,4 +4,3 @@ from .to_grid import to_grid from .to_vector_yx import to_vector_yx from .transform import transform -from .relocate_radial import relocate_to_radial_minimum diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py deleted file mode 100644 index 58411714f..000000000 --- a/autoarray/structures/decorators/relocate_radial.py +++ /dev/null @@ -1,106 +0,0 @@ -from autoarray.numpy_wrapper import np, use_jax -from functools import wraps - -from typing import Union - -from autoconf.exc import ConfigException - -from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.structures.grids.uniform_2d import Grid2D -from autoconf import conf - - -def relocate_to_radial_minimum(func): - """ - Checks whether any coordinates in the grid are radially near (0.0, 0.0), which can lead to numerical faults in - the evaluation of a function (e.g. numerical integration reaching a singularity at (0.0, 0.0)). - - If any coordinates are radially within the radial minimum threshold, their (y,x) coordinates are shifted to that - value to ensure they are evaluated at that coordinate. - - The value the (y,x) coordinates are rounded to is set in the 'radial_minimum.yaml' config. - - Parameters - ---------- - func - A function that takes a grid of coordinates which may have a singularity as (0.0, 0.0) - - Returns - ------- - A function that has an input grid whose radial coordinates are relocated to the radial minimum. - """ - - @wraps(func) - def wrapper( - obj: object, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - *args, - **kwargs, - ) -> Union[np.ndarray, Grid2D, Grid2DIrregular]: - """ - Checks whether any coordinates in the grid are radially near (0.0, 0.0), which can lead to numerical faults in - the evaluation of a function (e.g. numerical integration reaching a singularity at (0.0, 0.0)). - - If any coordinates are radially within the radial minimum threshold, their (y,x) coordinates are shifted to that - value to ensure they are evaluated at that coordinate. - - The value the (y,x) coordinates are rounded to is set in the 'radial_minimum.yaml' config. - - Parameters - ---------- - obj - An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid. - grid - The (y, x) coordinates which are to be radially moved from (0.0, 0.0). - - Returns - ------- - The grid_like object whose coordinates are radially moved from (0.0, 0.0). - """ - if use_jax: - return func(obj, grid, *args, **kwargs) - - try: - grid_radial_minimum = conf.instance["grids"]["radial_minimum"][ - "radial_minimum" - ][obj.__class__.__name__] - - except KeyError as e: - raise ConfigException( - rf""" - The {obj.__class__.__name__} profile you are using does not have a corresponding - entry in the `config/grid.yaml` config file. - - When a profile is evaluated at (0.0, 0.0), they commonly break due to numericalinstabilities (e.g. - division by zero). To prevent this, the code relocates the (y,x) coordinates of the grid to a - minimum radial value, specified in the `config/grids.yaml` config file. - - For example, if the value in `grid.yaml` is `radial_minimum: 1e-6`, then any (y,x) coordinates - with a radial distance less than 1e-6 to (0.0, 0.0) are relocated to 1e-6. - - For a profile to be used it must have an entry in the `config/grids.yaml` config file. Go to this - file now and add your profile to the `radial_minimum` section. Adopting a value of 1e-6 is a good - default choice. - - If you are going to make a pull request to add your profile to the source code, you should also - add an entry to the `config/grids.yaml` config file of the source code itself - (e.g. `PyAutoGalaxy/autogalaxy/config/grids.yaml`). - """ - ) - - with np.errstate(all="ignore"): # Division by zero fixed via isnan - grid_radii = obj.radial_grid_from(grid=grid) - - grid_radial_scale = np.where( - grid_radii < grid_radial_minimum, grid_radial_minimum / grid_radii, 1.0 - ) - moved_grid = np.multiply(grid, grid_radial_scale[:, None]) - - if hasattr(grid, "with_new_array"): - moved_grid = grid.with_new_array(moved_grid) - - moved_grid[np.isnan(np.array(moved_grid))] = grid_radial_minimum - - return func(obj, moved_grid, *args, **kwargs) - - return wrapper diff --git a/autoarray/structures/mock/mock_decorators.py b/autoarray/structures/mock/mock_decorators.py index 013b0a62a..c02ebc0b8 100644 --- a/autoarray/structures/mock/mock_decorators.py +++ b/autoarray/structures/mock/mock_decorators.py @@ -165,7 +165,3 @@ def __init__(self): def radial_grid_from(self, grid): return np.sqrt(np.add(np.square(grid[:, 0]), np.square(grid[:, 1]))) - - @decorators.relocate_to_radial_minimum - def deflections_yx_2d_from(self, grid): - return grid diff --git a/test_autoarray/config/grids.yaml b/test_autoarray/config/grids.yaml deleted file mode 100644 index 61c268a27..000000000 --- a/test_autoarray/config/grids.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# Certain light and mass profile calculations become ill defined at (0.0, 0.0) or close to this value. This can lead -# to numerical issues in the calculation of the profile, for example a np.nan may arise, crashing the code. - -# To avoid this, we set a minimum value for the radial coordinate of the profile. If the radial coordinate is below -# this value, it is rounded up to this value. This ensures that the profile cannot receive a radial coordinate of 0.0. - -# For example, if an input grid coordinate has a radial coordinate of 1e-12, for most profiles this will be rounded up -# to radial_minimum=1e-08. This is a small enough value that it should not impact the results of the profile calculation. - -radial_minimum: - radial_minimum: - MockGridRadialMinimum: 2.5 \ No newline at end of file From 6f027158f823c43e3ab665285497b35ec424c5cf Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:26:25 +0100 Subject: [PATCH 095/108] more removal of numpy wrapper nps --- autoarray/mask/derive/indexes_2d.py | 3 ++- autoarray/mask/mask_2d_util.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 062c8e664..0d0a36b26 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging +import numpy as np -from autoarray.numpy_wrapper import np, register_pytree_node_class +from autoarray.numpy_wrapper import register_pytree_node_class from typing import TYPE_CHECKING if TYPE_CHECKING: diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 462448073..10a40b473 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -1,10 +1,10 @@ import numpy as np +import jax.numpy as jnp from scipy.ndimage import convolve from typing import Tuple import warnings from autoarray import exc -from autoarray.numpy_wrapper import np as jnp def native_index_for_slim_index_2d_from( From aa4c9e6e5b8868d35fcfb6d9dce206b63f38e956 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:28:07 +0100 Subject: [PATCH 096/108] remove all numpy wrappers --- autoarray/operators/contour.py | 15 ++++++--------- autoarray/structures/decorators/to_vector_yx.py | 3 +-- autoarray/structures/decorators/transform.py | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index c7da5c7f1..c352ea19f 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -1,6 +1,6 @@ from __future__ import annotations -from autoarray.numpy_wrapper import np, use_jax import numpy +import jax.numpy as jnp from skimage import measure from scipy.spatial import ConvexHull from scipy.spatial import QhullError @@ -42,16 +42,13 @@ def contour_array(self): return self._contour_array pixel_centres = geometry_util.grid_pixel_centres_2d_slim_from( - grid_scaled_2d_slim=np.array(self.grid), + grid_scaled_2d_slim=jnp.array(self.grid), shape_native=self.shape_native, pixel_scales=self.pixel_scales, ).astype("int") - arr = np.zeros(self.shape_native) - if use_jax: - arr = arr.at[tuple(np.array(pixel_centres).T)].set(1) - else: - arr[tuple(np.array(pixel_centres).T)] = 1 + arr = jnp.zeros(self.shape_native) + arr = arr.at[tuple(jnp.array(pixel_centres).T)].set(1) return arr @@ -74,7 +71,7 @@ def contour_list(self): pixel_scales=self.pixel_scales, ) - factor = 0.5 * np.array(self.pixel_scales) * np.array([-1.0, 1.0]) + factor = 0.5 * jnp.array(self.pixel_scales) * jnp.array([-1.0, 1.0]) grid_scaled_1d += factor contour_list.append(Grid2DIrregular(values=grid_scaled_1d)) @@ -104,7 +101,7 @@ def hull( hull_x = grid_convex[hull_vertices, 0] hull_y = grid_convex[hull_vertices, 1] - grid_hull = np.zeros((len(hull_vertices), 2)) + grid_hull = jnp.zeros((len(hull_vertices), 2)) grid_hull[:, 1] = hull_x grid_hull[:, 0] = hull_y diff --git a/autoarray/structures/decorators/to_vector_yx.py b/autoarray/structures/decorators/to_vector_yx.py index 1cf23346d..90aea99ea 100644 --- a/autoarray/structures/decorators/to_vector_yx.py +++ b/autoarray/structures/decorators/to_vector_yx.py @@ -1,6 +1,5 @@ -from autoarray.numpy_wrapper import np from functools import wraps - +import numpy as np from typing import List, Union from autoarray.structures.decorators.abstract import AbstractMaker diff --git a/autoarray/structures/decorators/transform.py b/autoarray/structures/decorators/transform.py index bd837a399..eca0d883b 100644 --- a/autoarray/structures/decorators/transform.py +++ b/autoarray/structures/decorators/transform.py @@ -1,5 +1,5 @@ -from autoarray.numpy_wrapper import np from functools import wraps +import numpy as np from typing import Union from autoarray.structures.grids.uniform_1d import Grid1D From 3b6ab48b21dff83bc144ec19d48d408fc604793b Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:36:33 +0100 Subject: [PATCH 097/108] remove warning for now --- autoarray/numba_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/numba_util.py b/autoarray/numba_util.py index db34f3e1a..9e0298b73 100644 --- a/autoarray/numba_util.py +++ b/autoarray/numba_util.py @@ -33,7 +33,7 @@ try: if os.environ.get("USE_JAX") == "1": - logger.warning("JAX and numba do not work together, so JAX is being used.") + 1 else: import numba From 44a2808e1761798ae0ab077a07a94408c43e843d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:48:11 +0100 Subject: [PATCH 098/108] fix structure plotters --- autoarray/operators/contour.py | 14 +++++++------- autoarray/plot/wrap/two_d/array_overlay.py | 4 +++- .../structures/plot/test_structure_plotters.py | 1 + 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index c352ea19f..2de247d3c 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy +import numpy as np import jax.numpy as jnp from skimage import measure from scipy.spatial import ConvexHull @@ -42,7 +42,7 @@ def contour_array(self): return self._contour_array pixel_centres = geometry_util.grid_pixel_centres_2d_slim_from( - grid_scaled_2d_slim=jnp.array(self.grid), + grid_scaled_2d_slim=np.array(self.grid), shape_native=self.shape_native, pixel_scales=self.pixel_scales, ).astype("int") @@ -56,7 +56,7 @@ def contour_array(self): def contour_list(self): # make sure to use base numpy to convert JAX array back to a normal array contour_indices_list = measure.find_contours( - numpy.array(self.contour_array.array), 0 + np.array(self.contour_array), 0 ) if len(contour_indices_list) == 0: @@ -71,7 +71,7 @@ def contour_list(self): pixel_scales=self.pixel_scales, ) - factor = 0.5 * jnp.array(self.pixel_scales) * jnp.array([-1.0, 1.0]) + factor = 0.5 * np.array(self.pixel_scales) * np.array([-1.0, 1.0]) grid_scaled_1d += factor contour_list.append(Grid2DIrregular(values=grid_scaled_1d)) @@ -86,10 +86,10 @@ def hull( return None # cast JAX arrays to base numpy arrays - grid_convex = numpy.zeros((len(self.grid), 2)) + grid_convex = np.zeros((len(self.grid), 2)) - grid_convex[:, 0] = numpy.array(self.grid[:, 1]) - grid_convex[:, 1] = numpy.array(self.grid[:, 0]) + grid_convex[:, 0] = np.array(self.grid[:, 1]) + grid_convex[:, 1] = np.array(self.grid[:, 0]) try: hull = ConvexHull(grid_convex) diff --git a/autoarray/plot/wrap/two_d/array_overlay.py b/autoarray/plot/wrap/two_d/array_overlay.py index 57652e8df..5de20b879 100644 --- a/autoarray/plot/wrap/two_d/array_overlay.py +++ b/autoarray/plot/wrap/two_d/array_overlay.py @@ -19,4 +19,6 @@ def overlay_array(self, array, figure): aspect = figure.aspect_from(shape_native=array.shape_native) extent = array.extent_of_zoomed_array(buffer=0) - plt.imshow(X=array.native, aspect=aspect, extent=extent, **self.config_dict) + print(type(array)) + + plt.imshow(X=array.native._array, aspect=aspect, extent=extent, **self.config_dict) diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index ad1ca0251..d455c86f4 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -3,6 +3,7 @@ from os import path import pytest import numpy as np +import jax.numpy as jnp import shutil directory = path.dirname(path.realpath(__file__)) From ea139fcd95bd3400b620df39041e40d0470beba0 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:56:03 +0100 Subject: [PATCH 099/108] clean up vectors_yx --- autoarray/structures/vectors/uniform.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index 89d589139..fc66b0e8d 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -1,7 +1,8 @@ import logging -# import numpy as np -from autofit.jax_wrapper import numpy as np, use_jax +import numpy as np +import jax.numpy as jnp +# from autofit.jax_wrapper import numpy as np, use_jax from typing import List, Optional, Tuple, Union from autoarray.structures.arrays.uniform_2d import Array2D @@ -396,11 +397,7 @@ def magnitudes(self) -> Array2D: """ Returns the magnitude of every vector which are computed as sqrt(y**2 + x**2). """ - if use_jax: - s = self.array - else: - s = self - return Array2D(values=np.sqrt(s[:, 0] ** 2.0 + s[:, 1] ** 2.0), mask=self.mask) + return Array2D(values=jnp.sqrt(self.array[:, 0] ** 2.0 + self.array[:, 1] ** 2.0), mask=self.mask) @property def y(self) -> Array2D: From ec1e81eb0066bd638c0d48dda33d95d14169e27d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:57:41 +0100 Subject: [PATCH 100/108] remove autofit imports --- autoarray/geometry/geometry_2d.py | 2 -- autoarray/operators/over_sampling/over_sampler.py | 2 +- autoarray/structures/vectors/uniform.py | 1 - test_autoarray/test_jax_changes.py | 8 +++++--- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index e78f0f75a..9eea7e9f2 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -13,8 +13,6 @@ from autoarray import type as ty from autoarray.geometry import geometry_util -from autofit.jax_wrapper import use_jax - logging.basicConfig() logger = logging.getLogger(__name__) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 65393709c..ae458e41b 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -10,7 +10,7 @@ from autoarray.operators.over_sampling import over_sample_util -from autofit.jax_wrapper import register_pytree_node_class +from autoarray.numpy_wrapper import register_pytree_node_class @register_pytree_node_class diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index fc66b0e8d..6213ecfbf 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -2,7 +2,6 @@ import numpy as np import jax.numpy as jnp -# from autofit.jax_wrapper import numpy as np, use_jax from typing import List, Optional, Tuple, Union from autoarray.structures.arrays.uniform_2d import Array2D diff --git a/test_autoarray/test_jax_changes.py b/test_autoarray/test_jax_changes.py index f5104a942..2b6289317 100644 --- a/test_autoarray/test_jax_changes.py +++ b/test_autoarray/test_jax_changes.py @@ -1,8 +1,10 @@ -import autoarray as aa +import jax.numpy as jnp import pytest + +import autoarray as aa + from autoarray import Grid2D, Mask2D -from autofit.jax_wrapper import numpy as np @pytest.fixture(name="array") @@ -33,4 +35,4 @@ def test_boolean_issue(): mask=Mask2D.all_false((10, 10), pixel_scales=1.0), ) values, keys = Grid2D.instance_flatten(grid) - np.array(Grid2D.instance_unflatten(keys, values)) + jnp.array(Grid2D.instance_unflatten(keys, values)) From 37e81f157a2acd3dff2f55ab546a834ece26ddd9 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 18:04:21 +0100 Subject: [PATCH 101/108] fix voronoi unit test in structures --- autoarray/inversion/pixelization/image_mesh/overlay.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autoarray/inversion/pixelization/image_mesh/overlay.py b/autoarray/inversion/pixelization/image_mesh/overlay.py index c5bc7eaef..de130ee6e 100644 --- a/autoarray/inversion/pixelization/image_mesh/overlay.py +++ b/autoarray/inversion/pixelization/image_mesh/overlay.py @@ -220,11 +220,11 @@ def image_plane_mesh_grid_from( origin=origin, ) - overlaid_centres = geometry_util.grid_pixel_centres_2d_slim_from( + overlaid_centres = np.array(geometry_util.grid_pixel_centres_2d_slim_from( grid_scaled_2d_slim=unmasked_overlay_grid, shape_native=mask.shape_native, pixel_scales=mask.pixel_scales, - ).astype("int") + )).astype("int") total_pixels = total_pixels_2d_from( mask_2d=mask.array, From b31c0fc24e248e2a82dbf40d7f5fdfed9d4b0a9e Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:24:39 +0100 Subject: [PATCH 102/108] fix test_preprocess --- autoarray/dataset/preprocess.py | 12 ++++++------ test_autoarray/dataset/test_preprocess.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/autoarray/dataset/preprocess.py b/autoarray/dataset/preprocess.py index 5c7338204..f13af3184 100644 --- a/autoarray/dataset/preprocess.py +++ b/autoarray/dataset/preprocess.py @@ -263,15 +263,15 @@ def edges_from(image, no_edges): edges = [] for edge_no in range(no_edges): - top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no] - bottom_edge = image.native[ + top_edge = image.native.array[edge_no, edge_no : image.shape_native[1] - edge_no] + bottom_edge = image.native.array[ image.shape_native[0] - 1 - edge_no, edge_no : image.shape_native[1] - edge_no, ] - left_edge = image.native[ + left_edge = image.native.array[ edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no ] - right_edge = image.native[ + right_edge = image.native.array[ edge_no + 1 : image.shape_native[0] - 1 - edge_no, image.shape_native[1] - 1 - edge_no, ] @@ -517,8 +517,8 @@ def noise_map_with_signal_to_noise_limit_from( noise_map_limit = np.where( (signal_to_noise_map.native > signal_to_noise_limit) & (noise_limit_mask == False), - np.abs(data.native) / signal_to_noise_limit, - noise_map.native, + np.abs(data.native.array) / signal_to_noise_limit, + noise_map.native.array, ) mask = Mask2D.all_false( diff --git a/test_autoarray/dataset/test_preprocess.py b/test_autoarray/dataset/test_preprocess.py index 74e8ef774..f484fd648 100644 --- a/test_autoarray/dataset/test_preprocess.py +++ b/test_autoarray/dataset/test_preprocess.py @@ -462,7 +462,7 @@ def test__background_noise_map_via_edges_of_image_from_4(): ) assert np.allclose( - background_noise_map.native, + background_noise_map.native.array, np.full(fill_value=np.std(np.arange(28)), shape=image.shape_native), ) @@ -486,7 +486,7 @@ def test__background_noise_map_via_edges_of_image_from_5(): ) assert np.allclose( - background_noise_map.native, + background_noise_map.native.array, np.full(fill_value=np.std(np.arange(48)), shape=image.shape_native), ) From 80fc8e8781d3d7b8bcb50ca2698bd5cbda54b8ae Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:25:38 +0100 Subject: [PATCH 103/108] fix test dataset abstract --- autoarray/dataset/imaging/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index e3ec74d3b..b5b5b73d7 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -166,8 +166,9 @@ def __init__( self.psf = psf - if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") + if psf is not None: + if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") @cached_property def grids(self): From cd276cd3146e4d03beede6a9ffe366a1ab0f93ba Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:38:09 +0100 Subject: [PATCH 104/108] fix test imaging --- autoarray/dataset/imaging/dataset.py | 18 +++++++++--------- test_autoarray/dataset/imaging/test_dataset.py | 9 +++++++-- .../dataset/imaging/test_simulator.py | 2 +- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index b5b5b73d7..8d84ee1b8 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -204,9 +204,9 @@ def w_tilde(self): indexes, lengths, ) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from( - noise_map_native=np.array(self.noise_map.native), - kernel_native=np.array(self.psf.native), - native_index_for_slim_index=self.mask.derive_indexes.native_for_slim, + noise_map_native=np.array(self.noise_map.native.array).astype("float64"), + kernel_native=np.array(self.psf.native.array).astype("float64"), + native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"), ) return WTildeImaging( @@ -409,20 +409,20 @@ def apply_noise_scaling( """ if signal_to_noise_value is None: - noise_map = self.noise_map.native - noise_map[mask == False] = noise_value + noise_map = np.array(self.noise_map.native.array) + noise_map[mask.array == False] = noise_value else: noise_map = np.where( mask == False, - np.median(self.data.native[mask.derive_mask.edge == False]) + np.median(self.data.native.array[mask.derive_mask.edge == False]) / signal_to_noise_value, - self.noise_map.native, + self.noise_map.native.array, ) if should_zero_data: - data = np.where(np.invert(mask), 0.0, self.data.native) + data = np.where(np.invert(mask.array), 0.0, self.data.native.array) else: - data = self.data.native + data = self.data.native.array data_unmasked = Array2D.no_mask( values=data, diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index 45ef0a80f..ca33f1b40 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -139,7 +139,7 @@ def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): == 2.0 * np.ones((7, 7)) * np.invert(mask_2d_7x7) ).all() - assert (masked_imaging_7x7.psf.slim == (1.0 / 3.0) * psf_3x3.slim).all() + assert masked_imaging_7x7.psf.slim == pytest.approx((1.0 / 3.0) * psf_3x3.slim, 1.0e-4) assert type(masked_imaging_7x7.psf) == aa.Kernel2D assert masked_imaging_7x7.w_tilde.curvature_preload.shape == (35,) @@ -244,4 +244,9 @@ def test__noise_map_unmasked_has_zeros_or_negative__raises_exception(): def test__psf_not_odd_x_odd_kernel__raises_error(): with pytest.raises(exc.KernelException): - aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) + image = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) + noise_map = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) + psf = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) + + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=False) + diff --git a/test_autoarray/dataset/imaging/test_simulator.py b/test_autoarray/dataset/imaging/test_simulator.py index 3cc4182d3..54dc1f6ed 100644 --- a/test_autoarray/dataset/imaging/test_simulator.py +++ b/test_autoarray/dataset/imaging/test_simulator.py @@ -70,7 +70,7 @@ def test__via_image_from__psf_off__noise_off_value_is_noise_value( == np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]) ).all() - assert np.allclose(dataset.noise_map.native, 0.2 * np.ones((3, 3))) + assert np.allclose(dataset.noise_map.native.array, 0.2 * np.ones((3, 3))) def test__via_image_from__psf_off__background_sky_on(image_central_delta_3x3): From f6dfda50b5c2a03db875d0beda7bee7a39ebdbb5 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:42:00 +0100 Subject: [PATCH 105/108] fix layout --- test_autoarray/layout/test_region.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test_autoarray/layout/test_region.py b/test_autoarray/layout/test_region.py index 690cfa9bb..643b8532a 100644 --- a/test_autoarray/layout/test_region.py +++ b/test_autoarray/layout/test_region.py @@ -152,7 +152,7 @@ def test__slice_2d__addition(): image = np.ones((2, 2)) region = aa.Region2D(region=(0, 1, 0, 1)) - array[region.slice] += image[region.slice] + array = array.at[region.slice].add(image[region.slice]) assert (array == np.array([[1.0, 0.0], [0.0, 0.0]])).all() @@ -161,7 +161,7 @@ def test__slice_2d__addition(): image = np.ones((2, 2)) region = aa.Region2D(region=(0, 1, 0, 1)) - array[region.slice] += image[region.slice] + array = array.at[region.slice].add(image[region.slice]) assert (array == np.array([[2.0, 1.0], [1.0, 1.0]])).all() @@ -170,7 +170,7 @@ def test__slice_2d__addition(): image = np.ones((3, 3)) region = aa.Region2D(region=(1, 3, 2, 3)) - array[region.slice] += image[region.slice] + array = array.at[region.slice].add(image[region.slice]) assert ( array == np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [1.0, 1.0, 2.0]]) @@ -183,7 +183,7 @@ def test__slice_2d__set_to_zerose(): region = aa.Region2D(region=(0, 1, 0, 1)) - array[region.slice] = 0 + array = array.at[region.slice].set(0) assert (array == np.array([[0.0, 1.0], [1.0, 1.0]])).all() @@ -192,7 +192,7 @@ def test__slice_2d__set_to_zerose(): region = aa.Region2D(region=(1, 3, 2, 3)) - array[region.slice] = 0 + array = array.at[region.slice].set(0) assert ( array == np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) From 5430647dbf983d9bac7e7de0107b855162f0d117 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:46:46 +0100 Subject: [PATCH 106/108] fix plot unit tests --- autoarray/plot/wrap/two_d/contour.py | 4 ++-- test_autoarray/plot/include/test_include.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/plot/wrap/two_d/contour.py b/autoarray/plot/wrap/two_d/contour.py index c80159813..164fe86be 100644 --- a/autoarray/plot/wrap/two_d/contour.py +++ b/autoarray/plot/wrap/two_d/contour.py @@ -93,10 +93,10 @@ def set( config_dict.pop("use_log10") config_dict.pop("include_values") - levels = self.levels_from(array) + levels = self.levels_from(array.array) ax = plt.contour( - array.native[::-1], levels=levels, extent=extent, **config_dict + array.native.array[::-1], levels=levels, extent=extent, **config_dict ) if self.include_values: try: diff --git a/test_autoarray/plot/include/test_include.py b/test_autoarray/plot/include/test_include.py index 6f4d29c77..b32616e9d 100644 --- a/test_autoarray/plot/include/test_include.py +++ b/test_autoarray/plot/include/test_include.py @@ -6,7 +6,7 @@ def test__loads_default_values_from_config_if_not_input(): assert include.origin is True assert include.mask == True - assert include.border is True + assert include.border is False assert include.parallel_overscan is True assert include.serial_prescan is True assert include.serial_overscan is False From 083ed0bb2f08eed82eb1520b929347036075690c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 15:50:11 +0100 Subject: [PATCH 107/108] over sampling unit tests --- .../operators/over_sampling/over_sampler.py | 37 ++++++++++++++----- .../over_sample/test_over_sampler.py | 10 +++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index ae458e41b..9fda67bb7 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,6 +147,16 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): over_sample_size=sub_size, mask=mask ) + + @property + def sub_is_uniform(self) -> bool: + """ + Returns True if the sub_size is uniform across all pixels in the mask. + """ + return np.all( + np.isclose(self.sub_size.array, self.sub_size.array[0]) + ) + def tree_flatten(self): return (self.mask, self.sub_size), () @@ -185,7 +195,7 @@ def sub_pixel_areas(self) -> np.ndarray: """ The area of every sub-pixel in the mask. """ - sub_pixel_areas = jnp.zeros(self.sub_total) + sub_pixel_areas = np.zeros(self.sub_total) k = 0 @@ -221,15 +231,24 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": except AttributeError: pass - # binned_array_2d = over_sample_util.binned_array_2d_from( - # array_2d=jnp.array(array), - # mask_2d=jnp.array(self.mask), - # sub_size=jnp.array(self.sub_size).astype("int"), - # ) + if self.sub_is_uniform: + binned_array_2d = array.reshape( + self.mask.shape_slim, self.sub_size[0] ** 2 + ).mean(axis=1) + else: + + # Define group sizes + group_sizes = jnp.array(self.sub_size.array.astype("int") ** 2) + + # Compute the cumulative sum of group sizes to get split points + split_indices = jnp.cumsum(group_sizes) + + # Ensure correct concatenation by making 0 a JAX array + start_indices = jnp.concatenate((jnp.array([0]), split_indices[:-1])) - binned_array_2d = array.reshape( - self.mask.shape_slim, self.sub_size[0] ** 2 - ).mean(axis=1) + # Compute the group means + binned_array_2d = jnp.array( + [array[start:end].mean() for start, end in zip(start_indices, split_indices)]) return Array2D( values=binned_array_2d, diff --git a/test_autoarray/operators/over_sample/test_over_sampler.py b/test_autoarray/operators/over_sample/test_over_sampler.py index 7da24d79a..a32b11e8f 100644 --- a/test_autoarray/operators/over_sample/test_over_sampler.py +++ b/test_autoarray/operators/over_sample/test_over_sampler.py @@ -70,6 +70,16 @@ def test__binned_array_2d_from(): pixel_scales=1.0, ) + over_sampling = aa.OverSampler( + mask=mask, sub_size=aa.Array2D(values=[2, 2], mask=mask) + ) + + arr = np.array([1.0, 5.0, 7.0, 10.0, 10.0, 10.0, 10.0, 10.0]) + + binned_array_2d = over_sampling.binned_array_2d_from(array=arr) + + assert binned_array_2d.slim == pytest.approx(np.array([5.75, 10.0]), 1.0e-4) + over_sampling = aa.OverSampler( mask=mask, sub_size=aa.Array2D(values=[1, 2], mask=mask) ) From 72af86b04b70f94ae3514f8c28ff38d85abfe133 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 16:07:46 +0100 Subject: [PATCH 108/108] fix all fit tests --- autoarray/fit/fit_util.py | 87 +++++------- test_autoarray/fit/test_fit_util.py | 199 ++++++++++++++-------------- 2 files changed, 135 insertions(+), 151 deletions(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index d40f55d1c..10f24f9a7 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -1,5 +1,6 @@ from functools import wraps -import jax.numpy as np +import jax.numpy as jnp +import numpy as np from autoarray.mask.abstract_mask import Mask @@ -83,7 +84,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return np.sum(chi_squared_map._array) + return jnp.sum(np.array(chi_squared_map)) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: @@ -97,12 +98,12 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return np.sum(np.log(2 * np.pi * noise_map._array**2.0)) + return jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map)**2.0)) def normalized_residual_map_complex_from( - *, residual_map: np.ndarray, noise_map: np.ndarray -) -> np.ndarray: + *, residual_map: jnp.ndarray, noise_map: jnp.ndarray +) -> jnp.ndarray: """ Returns the normalized residual-map of the fit of complex model-data to a dataset, where: @@ -126,8 +127,8 @@ def normalized_residual_map_complex_from( def chi_squared_map_complex_from( - *, residual_map: np.ndarray, noise_map: np.ndarray -) -> np.ndarray: + *, residual_map: jnp.ndarray, noise_map: jnp.ndarray +) -> jnp.ndarray: """ Returnss the chi-squared-map of the fit of complex model-data to a dataset, where: @@ -145,7 +146,7 @@ def chi_squared_map_complex_from( return chi_squared_map_real + 1j * chi_squared_map_imag -def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float: +def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float: """ Returns the chi-squared terms of each complex model data's fit to a masked dataset, by summing the masked chi-squared-map of the fit. @@ -157,12 +158,12 @@ def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - chi_squared_real = np.sum(chi_squared_map.real) - chi_squared_imag = np.sum(chi_squared_map.imag) + chi_squared_real = jnp.sum(chi_squared_map.real) + chi_squared_imag = jnp.sum(chi_squared_map.imag) return chi_squared_real + chi_squared_imag -def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float: +def noise_normalization_complex_from(*, noise_map: jnp.ndarray) -> float: """ Returns the noise-map normalization terms of a complex noise-map, summing the noise_map value in every pixel as: @@ -173,8 +174,8 @@ def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float: noise_map The masked noise-map of the dataset. """ - noise_normalization_real = np.sum(np.log(2 * np.pi * noise_map.real**2.0)) - noise_normalization_imag = np.sum(np.log(2 * np.pi * noise_map.imag**2.0)) + noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * noise_map.real**2.0)) + noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * noise_map.imag**2.0)) return noise_normalization_real + noise_normalization_imag @@ -198,9 +199,7 @@ def residual_map_with_mask_from( model_data The model data used to fit the data. """ - return np.subtract( - data, model_data, out=np.zeros_like(data), where=np.asarray(mask) == 0 - ) + return jnp.where(jnp.asarray(mask) == 0, jnp.subtract(data, model_data), 0) @to_new_array @@ -223,13 +222,7 @@ def normalized_residual_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return np.divide( - residual_map, - noise_map, - out=np.zeros_like(residual_map), - where=np.asarray(mask) == 0, - ) - + return jnp.where(jnp.asarray(mask) == 0, jnp.divide(residual_map, noise_map), 0) @to_new_array def chi_squared_map_with_mask_from( @@ -251,13 +244,10 @@ def chi_squared_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return np.square( - np.divide( - residual_map, - noise_map, - out=np.zeros_like(residual_map), - where=np.asarray(mask) == 0, - ) + return jnp.where( + jnp.asarray(mask) == 0, + jnp.square(residual_map / noise_map), + 0 ) @@ -275,7 +265,7 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask) -> f mask The mask applied to the chi-squared-map, where `False` entries are included in the calculation. """ - return float(np.sum(chi_squared_map[np.asarray(mask) == 0])) + return float(jnp.sum(chi_squared_map[jnp.asarray(mask) == 0])) def chi_squared_with_mask_fast_from( @@ -302,14 +292,14 @@ def chi_squared_with_mask_fast_from( The mask applied to the chi-squared-map, where `False` entries are included in the calculation. """ return float( - np.sum( - np.square( - np.divide( - np.subtract( + jnp.sum( + jnp.square( + jnp.divide( + jnp.subtract( data, model_data, - )[np.asarray(mask) == 0], - noise_map[np.asarray(mask) == 0], + )[jnp.asarray(mask) == 0], + noise_map[jnp.asarray(mask) == 0], ) ) ) @@ -331,11 +321,11 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask) -> mask The mask applied to the noise-map, where `False` entries are included in the calculation. """ - return float(np.sum(np.log(2 * np.pi * noise_map[np.asarray(mask) == 0] ** 2.0))) + return float(jnp.sum(jnp.log(2 * jnp.pi * noise_map[jnp.asarray(mask) == 0] ** 2.0))) def chi_squared_with_noise_covariance_from( - *, residual_map: ty.DataLike, noise_covariance_matrix_inv: np.ndarray + *, residual_map: ty.DataLike, noise_covariance_matrix_inv: jnp.ndarray ) -> float: """ Returns the chi-squared value of the fit of model-data to a masked dataset, where @@ -351,7 +341,7 @@ def chi_squared_with_noise_covariance_from( The inverse of the noise covariance matrix. """ - return residual_map @ noise_covariance_matrix_inv @ residual_map + return residual_map.array @ noise_covariance_matrix_inv @ residual_map.array def log_likelihood_from(*, chi_squared: float, noise_normalization: float) -> float: @@ -431,8 +421,8 @@ def log_evidence_from( def residual_flux_fraction_map_from( - *, residual_map: np.ndarray, data: np.ndarray -) -> np.ndarray: + *, residual_map: jnp.ndarray, data: jnp.ndarray +) -> jnp.ndarray: """ Returns the residual flux fraction map of the fit of model-data to a masked dataset, where: @@ -445,12 +435,12 @@ def residual_flux_fraction_map_from( data The data of the dataset. """ - return np.divide(residual_map, data, out=np.zeros_like(residual_map)) + return jnp.where(data != 0, residual_map / data, 0) def residual_flux_fraction_map_with_mask_from( - *, residual_map: np.ndarray, data: np.ndarray, mask: Mask -) -> np.ndarray: + *, residual_map: jnp.ndarray, data: jnp.ndarray, mask: Mask +) -> jnp.ndarray: """ Returnss the residual flux fraction map of the fit of model-data to a masked dataset, where: @@ -467,9 +457,4 @@ def residual_flux_fraction_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return np.divide( - residual_map, - data, - out=np.zeros_like(residual_map), - where=np.asarray(mask) == 0, - ) + return jnp.where(mask == 0, residual_map / data, 0) \ No newline at end of file diff --git a/test_autoarray/fit/test_fit_util.py b/test_autoarray/fit/test_fit_util.py index 641dbf52e..6bbb5f871 100644 --- a/test_autoarray/fit/test_fit_util.py +++ b/test_autoarray/fit/test_fit_util.py @@ -1,41 +1,41 @@ import autoarray as aa -import numpy as np +import jax.numpy as jnp import pytest def test__residual_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) - assert (residual_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert (residual_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() - data = np.array([10.0, 10.0, 10.0, 10.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) - assert (residual_map == np.array([-1.0, 0.0, 1.0, 2.0])).all() + assert (residual_map == jnp.array([-1.0, 0.0, 1.0, 2.0])).all() def test__residual_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data ) - assert (residual_map == np.array([0.0, 0.0, 1.0, 0.0])).all() + assert (residual_map == jnp.array([0.0, 0.0, 1.0, 0.0])).all() def test__normalized_residual_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -43,9 +43,9 @@ def test__normalized_residual_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert (normalized_residual_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert normalized_residual_map == pytest.approx(jnp.array([0.0, 0.0, 0.0, 0.0]), 1.0e-4) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -53,17 +53,14 @@ def test__normalized_residual_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert ( - normalized_residual_map - == np.array([-(1.0 / 2.0), 0.0, (1.0 / 2.0), (2.0 / 2.0)]) - ).all() + assert normalized_residual_map == pytest.approx(jnp.array([-(1.0 / 2.0), 0.0, (1.0 / 2.0), (2.0 / 2.0)]), 1.0e-4) def test__normalized_residual_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -73,13 +70,15 @@ def test__normalized_residual_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (normalized_residual_map == np.array([0.0, 0.0, (1.0 / 2.0), 0.0])).all() + print(normalized_residual_map) + + assert normalized_residual_map == pytest.approx(jnp.array([0.0, 0.0, (1.0 / 2.0), 0.0]), abs=1.0e-4) def test__normalized_residual_map_complex_from(): - data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = np.array([2.0 + 2.0j, 2.0 + 2.0j]) - model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = jnp.array([2.0 + 2.0j, 2.0 + 2.0j]) + model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -87,13 +86,13 @@ def test__normalized_residual_map_complex_from(): residual_map=residual_map, noise_map=noise_map ) - assert (normalized_residual_map == np.array([0.5 - 1.0j, 0.5 - 1.0j])).all() + assert (normalized_residual_map == jnp.array([0.5 - 1.0j, 0.5 - 1.0j])).all() def test__chi_squared_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -101,9 +100,9 @@ def test__chi_squared_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert (chi_squared_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -113,15 +112,15 @@ def test__chi_squared_map_from(): assert ( chi_squared_map - == np.array([(1.0 / 2.0) ** 2.0, 0.0, (1.0 / 2.0) ** 2.0, (2.0 / 2.0) ** 2.0]) + == jnp.array([(1.0 / 2.0) ** 2.0, 0.0, (1.0 / 2.0) ** 2.0, (2.0 / 2.0) ** 2.0]) ).all() def test__chi_squared_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -131,9 +130,9 @@ def test__chi_squared_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() + assert (chi_squared_map == jnp.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -143,13 +142,13 @@ def test__chi_squared_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() + assert (chi_squared_map == jnp.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() def test__chi_squared_map_complex_from(): - data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = np.array([2.0 + 2.0j, 2.0 + 2.0j]) - model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = jnp.array([2.0 + 2.0j, 2.0 + 2.0j]) + model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -157,13 +156,13 @@ def test__chi_squared_map_complex_from(): residual_map=residual_map, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.25 + 1.0j, 0.25 + 1.0j])).all() + assert (chi_squared_map == jnp.array([0.25 + 1.0j, 0.25 + 1.0j])).all() def test__chi_squared_with_noise_covariance_from(): resdiual_map = aa.Array2D.no_mask([[1.0, 1.0], [2.0, 2.0]], pixel_scales=1.0) - noise_covariance_matrix_inv = np.array( + noise_covariance_matrix_inv = jnp.array( [ [1.0, 1.0, 4.0, 0.0], [0.0, 1.0, 9.0, 0.0], @@ -181,10 +180,10 @@ def test__chi_squared_with_noise_covariance_from(): def test__chi_squared_with_mask_fast_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([1.0, 2.0, 3.0, 4.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -207,10 +206,10 @@ def test__chi_squared_with_mask_fast_from(): assert chi_squared == pytest.approx(chi_squared_fast, 1.0e-4) - data = np.array([[10.0, 10.0], [10.0, 10.0]]) - mask = np.array([[True, False], [False, True]]) - noise_map = np.array([[1.0, 2.0], [3.0, 4.0]]) - model_data = np.array([[11.0, 10.0], [9.0, 8.0]]) + data = jnp.array([[10.0, 10.0], [10.0, 10.0]]) + mask = jnp.array([[True, False], [False, True]]) + noise_map = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + model_data = jnp.array([[11.0, 10.0], [9.0, 8.0]]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -235,9 +234,9 @@ def test__chi_squared_with_mask_fast_from(): def test__log_likelihood_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -255,17 +254,17 @@ def test__log_likelihood_from(): chi_squared = 0.0 noise_normalization = ( - np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1.0e-4 ) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -288,17 +287,17 @@ def test__log_likelihood_from(): ((1.0 / 2.0) ** 2.0) + 0.0 + ((1.0 / 2.0) ** 2.0) + ((2.0 / 2.0) ** 2.0) ) noise_normalization = ( - np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1.0e-4 ) - noise_map = np.array([1.0, 2.0, 3.0, 4.0]) + noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -318,10 +317,10 @@ def test__log_likelihood_from(): chi_squared = 1.0 + (1.0 / (3.0**2.0)) + 0.25 noise_normalization = ( - np.log(2 * np.pi * (1.0**2.0)) - + np.log(2 * np.pi * (2.0**2.0)) - + np.log(2 * np.pi * (3.0**2.0)) - + np.log(2 * np.pi * (4.0**2.0)) + jnp.log(2 * jnp.pi * (1.0**2.0)) + + jnp.log(2 * jnp.pi * (2.0**2.0)) + + jnp.log(2 * jnp.pi * (3.0**2.0)) + + jnp.log(2 * jnp.pi * (4.0**2.0)) ) assert log_likelihood == pytest.approx( @@ -330,10 +329,10 @@ def test__log_likelihood_from(): def test__log_likelihood_from__with_mask(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([1.0, 2.0, 3.0, 4.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -358,18 +357,18 @@ def test__log_likelihood_from__with_mask(): # chi squared = 0, 0.25, (0.25 and 1.0 are masked) chi_squared = 0.0 + (1.0 / 3.0) ** 2.0 - noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( - 2 * np.pi * (3.0**2.0) + noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( + 2 * jnp.pi * (3.0**2.0) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1e-4 ) - data = np.array([[10.0, 10.0], [10.0, 10.0]]) - mask = np.array([[True, False], [False, True]]) - noise_map = np.array([[1.0, 2.0], [3.0, 4.0]]) - model_data = np.array([[11.0, 10.0], [9.0, 8.0]]) + data = jnp.array([[10.0, 10.0], [10.0, 10.0]]) + mask = jnp.array([[True, False], [False, True]]) + noise_map = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + model_data = jnp.array([[11.0, 10.0], [9.0, 8.0]]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -394,8 +393,8 @@ def test__log_likelihood_from__with_mask(): # chi squared = 0, 0.25, (0.25 and 1.0 are masked) chi_squared = 0.0 + (1.0 / 3.0) ** 2.0 - noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( - 2 * np.pi * (3.0**2.0) + noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( + 2 * jnp.pi * (3.0**2.0) ) assert log_likelihood == pytest.approx( @@ -404,9 +403,9 @@ def test__log_likelihood_from__with_mask(): def test__log_likelihood_from__complex_data(): - data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = np.array([2.0 + 1.0j, 2.0 + 1.0j]) - model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = jnp.array([2.0 + 1.0j, 2.0 + 1.0j]) + model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -427,8 +426,8 @@ def test__log_likelihood_from__complex_data(): # chi squared = 0.25 and 4.0 chi_squared = 4.25 - noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( - 2 * np.pi * (1.0**2.0) + noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( + 2 * jnp.pi * (1.0**2.0) ) assert log_likelihood == pytest.approx( @@ -457,8 +456,8 @@ def test__log_evidence_from(): def test__residual_flux_fraction_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -466,9 +465,9 @@ def test__residual_flux_fraction_map_from(): residual_map=residual_map, data=data ) - assert (residual_flux_fraction_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert (residual_flux_fraction_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -476,13 +475,13 @@ def test__residual_flux_fraction_map_from(): residual_map=residual_map, data=data ) - assert (residual_flux_fraction_map == np.array([-0.1, 0.0, 0.1, 0.2])).all() + assert (residual_flux_fraction_map == jnp.array([-0.1, 0.0, 0.1, 0.2])).all() def test__residual_flux_fraction_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -492,9 +491,9 @@ def test__residual_flux_fraction_map_with_mask_from(): residual_map=residual_map, mask=mask, data=data ) - assert (residual_flux_fraction_map == np.array([0.0, 0.0, 0.1, 0.0])).all() + assert (residual_flux_fraction_map == jnp.array([0.0, 0.0, 0.1, 0.0])).all() - model_data = np.array([11.0, 9.0, 8.0, 8.0]) + model_data = jnp.array([11.0, 9.0, 8.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -504,4 +503,4 @@ def test__residual_flux_fraction_map_with_mask_from(): residual_map=residual_map, mask=mask, data=data ) - assert (residual_flux_fraction_map == np.array([0.0, 0.1, 0.2, 0.0])).all() + assert (residual_flux_fraction_map == jnp.array([0.0, 0.1, 0.2, 0.0])).all()