From 33a3ccb69b17af77457ca0d000539e942852a25f Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 14:50:01 +0100 Subject: [PATCH 01/19] fix some decorator unit tests --- autoarray/structures/decorators/abstract.py | 2 +- autoarray/structures/decorators/to_grid.py | 3 +-- autoarray/structures/mock/mock_decorators.py | 6 +++--- test_autoarray/structures/decorators/test_to_grid.py | 8 ++++---- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/autoarray/structures/decorators/abstract.py b/autoarray/structures/decorators/abstract.py index c9e5fca87..e033815b4 100644 --- a/autoarray/structures/decorators/abstract.py +++ b/autoarray/structures/decorators/abstract.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Union import numpy as np diff --git a/autoarray/structures/decorators/to_grid.py b/autoarray/structures/decorators/to_grid.py index 144137c69..4797c37ce 100644 --- a/autoarray/structures/decorators/to_grid.py +++ b/autoarray/structures/decorators/to_grid.py @@ -1,6 +1,5 @@ -from autoarray.numpy_wrapper import np from functools import wraps - +import numpy as np from typing import List, Union from autoarray.structures.decorators.abstract import AbstractMaker diff --git a/autoarray/structures/mock/mock_decorators.py b/autoarray/structures/mock/mock_decorators.py index 876b456d7..013b0a62a 100644 --- a/autoarray/structures/mock/mock_decorators.py +++ b/autoarray/structures/mock/mock_decorators.py @@ -116,7 +116,7 @@ def ndarray_2d_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ - return np.multiply(2.0, grid) + return np.multiply(2.0, grid.array) @decorators.to_vector_yx def ndarray_yx_2d_from(self, grid, *args, **kwargs): @@ -146,7 +146,7 @@ def ndarray_2d_list_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ - return [np.multiply(1.0, grid), np.multiply(2.0, grid)] + return [np.multiply(1.0, grid.array), np.multiply(2.0, grid.array)] @decorators.to_vector_yx def ndarray_yx_2d_list_from(self, grid, *args, **kwargs): @@ -156,7 +156,7 @@ def ndarray_yx_2d_list_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ - return [np.multiply(1.0, grid), np.multiply(2.0, grid)] + return [np.multiply(1.0, grid.array), np.multiply(2.0, grid.array)] class MockGridRadialMinimum: diff --git a/test_autoarray/structures/decorators/test_to_grid.py b/test_autoarray/structures/decorators/test_to_grid.py index 60c70d71b..2e8b1be2f 100644 --- a/test_autoarray/structures/decorators/test_to_grid.py +++ b/test_autoarray/structures/decorators/test_to_grid.py @@ -15,11 +15,11 @@ def test__in_grid_1d__out_ndarray_2d(): assert isinstance(ndarray_2d, aa.Grid2D) assert ndarray_2d.native == pytest.approx( - np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), 1.0e-4 + np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), abs=1.0e-4 ) -def test__in_grid_1d__out_ndarray_2d_list(): +def test__in_dgrid_1d__out_ndarray_2d_list(): mask = aa.Mask1D(mask=[True, False, False, True], pixel_scales=(1.0,)) grid_1d = aa.Grid1D.from_mask(mask=mask) @@ -30,12 +30,12 @@ def test__in_grid_1d__out_ndarray_2d_list(): assert isinstance(ndarray_2d_list[0], aa.Grid2D) assert ndarray_2d_list[0].native == pytest.approx( - np.array([[[0.0, 0.0], [0.0, -0.5], [0.0, 0.5], [0.0, 0.0]]]), 1.0e-4 + np.array([[[0.0, 0.0], [0.0, -0.5], [0.0, 0.5], [0.0, 0.0]]]), abs=1.0e-4 ) assert isinstance(ndarray_2d_list[1], aa.Grid2D) assert ndarray_2d_list[1].native == pytest.approx( - np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), 1.0e-4 + np.array([[[0.0, 0.0], [0.0, -1.0], [0.0, 1.0], [0.0, 0.0]]]), abs=1.0e-4 ) From 8b8dc9e9479dc17d9e91e5631e1033a75766a2bd Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 14:59:45 +0100 Subject: [PATCH 02/19] removing numpy wrapper to do explicit impots --- autoarray/abstract_ndarray.py | 17 +++++++++-------- .../operators/over_sampling/over_sampler.py | 13 +++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index 8b5fbc00e..ded8c5452 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -4,10 +4,11 @@ from abc import ABC from abc import abstractmethod +import jax.numpy as jnp from autoconf.fitsable import output_to_fits -from autoarray.numpy_wrapper import np, register_pytree_node, Array +from autoarray.numpy_wrapper import register_pytree_node, Array from typing import TYPE_CHECKING @@ -82,7 +83,7 @@ def __init__(self, array): def invert(self): new = self.copy() - new._array = np.invert(new._array) + new._array = jnp.invert(new._array) return new @classmethod @@ -104,7 +105,7 @@ def instance_flatten(cls, instance): @staticmethod def flip_hdu_for_ds9(values): if conf.instance["general"]["fits"]["flip_for_ds9"]: - return np.flipud(values) + return jnp.flipud(values) return values @classmethod @@ -117,7 +118,7 @@ def instance_unflatten(cls, aux_data, children): setattr(instance, key, value) return instance - def with_new_array(self, array: np.ndarray) -> "AbstractNDArray": + def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray": """ Copy this object but give it a new array. @@ -164,7 +165,7 @@ def __iter__(self): @to_new_array def sqrt(self): - return np.sqrt(self._array) + return jnp.sqrt(self._array) @property def array(self): @@ -330,13 +331,13 @@ def __getitem__(self, item): result = self._array[item] if isinstance(item, slice): result = self.with_new_array(result) - if isinstance(result, np.ndarray): + if isinstance(result, jnp.ndarray): result = self.with_new_array(result) return result def __setitem__(self, key, value): - if isinstance(key, (np.ndarray, AbstractNDArray, Array)): - self._array = np.where(key, value, self._array) + if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)): + self._array = jnp.where(key, value, self._array) else: self._array[key] = value diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6f12f4b9f..65393709c 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -1,5 +1,6 @@ -from autoarray.numpy_wrapper import np -from typing import List, Tuple, Union +import numpy as np +import jax.numpy as jnp +from typing import Union from autoconf import conf from autoconf import cached_property @@ -184,7 +185,7 @@ def sub_pixel_areas(self) -> np.ndarray: """ The area of every sub-pixel in the mask. """ - sub_pixel_areas = np.zeros(self.sub_total) + sub_pixel_areas = jnp.zeros(self.sub_total) k = 0 @@ -221,9 +222,9 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": pass # binned_array_2d = over_sample_util.binned_array_2d_from( - # array_2d=np.array(array), - # mask_2d=np.array(self.mask), - # sub_size=np.array(self.sub_size).astype("int"), + # array_2d=jnp.array(array), + # mask_2d=jnp.array(self.mask), + # sub_size=jnp.array(self.sub_size).astype("int"), # ) binned_array_2d = array.reshape( From 7115f9cbc5893d768e52a08b4d1ff2c313b39e9f Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:19:22 +0100 Subject: [PATCH 03/19] move relocate radial --- autoarray/config/grids.yaml | 3 - autoarray/structures/decorators/__init__.py | 1 - .../structures/decorators/relocate_radial.py | 106 ------------------ autoarray/structures/mock/mock_decorators.py | 4 - test_autoarray/config/grids.yaml | 12 -- 5 files changed, 126 deletions(-) delete mode 100644 autoarray/config/grids.yaml delete mode 100644 autoarray/structures/decorators/relocate_radial.py delete mode 100644 test_autoarray/config/grids.yaml diff --git a/autoarray/config/grids.yaml b/autoarray/config/grids.yaml deleted file mode 100644 index f82eaa5df..000000000 --- a/autoarray/config/grids.yaml +++ /dev/null @@ -1,3 +0,0 @@ -radial_minimum: - function_name: - class_name: 1.0e-08 diff --git a/autoarray/structures/decorators/__init__.py b/autoarray/structures/decorators/__init__.py index d85cbe6ee..1efb9137e 100644 --- a/autoarray/structures/decorators/__init__.py +++ b/autoarray/structures/decorators/__init__.py @@ -4,4 +4,3 @@ from .to_grid import to_grid from .to_vector_yx import to_vector_yx from .transform import transform -from .relocate_radial import relocate_to_radial_minimum diff --git a/autoarray/structures/decorators/relocate_radial.py b/autoarray/structures/decorators/relocate_radial.py deleted file mode 100644 index 58411714f..000000000 --- a/autoarray/structures/decorators/relocate_radial.py +++ /dev/null @@ -1,106 +0,0 @@ -from autoarray.numpy_wrapper import np, use_jax -from functools import wraps - -from typing import Union - -from autoconf.exc import ConfigException - -from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.structures.grids.uniform_2d import Grid2D -from autoconf import conf - - -def relocate_to_radial_minimum(func): - """ - Checks whether any coordinates in the grid are radially near (0.0, 0.0), which can lead to numerical faults in - the evaluation of a function (e.g. numerical integration reaching a singularity at (0.0, 0.0)). - - If any coordinates are radially within the radial minimum threshold, their (y,x) coordinates are shifted to that - value to ensure they are evaluated at that coordinate. - - The value the (y,x) coordinates are rounded to is set in the 'radial_minimum.yaml' config. - - Parameters - ---------- - func - A function that takes a grid of coordinates which may have a singularity as (0.0, 0.0) - - Returns - ------- - A function that has an input grid whose radial coordinates are relocated to the radial minimum. - """ - - @wraps(func) - def wrapper( - obj: object, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - *args, - **kwargs, - ) -> Union[np.ndarray, Grid2D, Grid2DIrregular]: - """ - Checks whether any coordinates in the grid are radially near (0.0, 0.0), which can lead to numerical faults in - the evaluation of a function (e.g. numerical integration reaching a singularity at (0.0, 0.0)). - - If any coordinates are radially within the radial minimum threshold, their (y,x) coordinates are shifted to that - value to ensure they are evaluated at that coordinate. - - The value the (y,x) coordinates are rounded to is set in the 'radial_minimum.yaml' config. - - Parameters - ---------- - obj - An object whose function uses grid_like inputs to compute quantities at every coordinate on the grid. - grid - The (y, x) coordinates which are to be radially moved from (0.0, 0.0). - - Returns - ------- - The grid_like object whose coordinates are radially moved from (0.0, 0.0). - """ - if use_jax: - return func(obj, grid, *args, **kwargs) - - try: - grid_radial_minimum = conf.instance["grids"]["radial_minimum"][ - "radial_minimum" - ][obj.__class__.__name__] - - except KeyError as e: - raise ConfigException( - rf""" - The {obj.__class__.__name__} profile you are using does not have a corresponding - entry in the `config/grid.yaml` config file. - - When a profile is evaluated at (0.0, 0.0), they commonly break due to numericalinstabilities (e.g. - division by zero). To prevent this, the code relocates the (y,x) coordinates of the grid to a - minimum radial value, specified in the `config/grids.yaml` config file. - - For example, if the value in `grid.yaml` is `radial_minimum: 1e-6`, then any (y,x) coordinates - with a radial distance less than 1e-6 to (0.0, 0.0) are relocated to 1e-6. - - For a profile to be used it must have an entry in the `config/grids.yaml` config file. Go to this - file now and add your profile to the `radial_minimum` section. Adopting a value of 1e-6 is a good - default choice. - - If you are going to make a pull request to add your profile to the source code, you should also - add an entry to the `config/grids.yaml` config file of the source code itself - (e.g. `PyAutoGalaxy/autogalaxy/config/grids.yaml`). - """ - ) - - with np.errstate(all="ignore"): # Division by zero fixed via isnan - grid_radii = obj.radial_grid_from(grid=grid) - - grid_radial_scale = np.where( - grid_radii < grid_radial_minimum, grid_radial_minimum / grid_radii, 1.0 - ) - moved_grid = np.multiply(grid, grid_radial_scale[:, None]) - - if hasattr(grid, "with_new_array"): - moved_grid = grid.with_new_array(moved_grid) - - moved_grid[np.isnan(np.array(moved_grid))] = grid_radial_minimum - - return func(obj, moved_grid, *args, **kwargs) - - return wrapper diff --git a/autoarray/structures/mock/mock_decorators.py b/autoarray/structures/mock/mock_decorators.py index 013b0a62a..c02ebc0b8 100644 --- a/autoarray/structures/mock/mock_decorators.py +++ b/autoarray/structures/mock/mock_decorators.py @@ -165,7 +165,3 @@ def __init__(self): def radial_grid_from(self, grid): return np.sqrt(np.add(np.square(grid[:, 0]), np.square(grid[:, 1]))) - - @decorators.relocate_to_radial_minimum - def deflections_yx_2d_from(self, grid): - return grid diff --git a/test_autoarray/config/grids.yaml b/test_autoarray/config/grids.yaml deleted file mode 100644 index 61c268a27..000000000 --- a/test_autoarray/config/grids.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# Certain light and mass profile calculations become ill defined at (0.0, 0.0) or close to this value. This can lead -# to numerical issues in the calculation of the profile, for example a np.nan may arise, crashing the code. - -# To avoid this, we set a minimum value for the radial coordinate of the profile. If the radial coordinate is below -# this value, it is rounded up to this value. This ensures that the profile cannot receive a radial coordinate of 0.0. - -# For example, if an input grid coordinate has a radial coordinate of 1e-12, for most profiles this will be rounded up -# to radial_minimum=1e-08. This is a small enough value that it should not impact the results of the profile calculation. - -radial_minimum: - radial_minimum: - MockGridRadialMinimum: 2.5 \ No newline at end of file From 6f027158f823c43e3ab665285497b35ec424c5cf Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:26:25 +0100 Subject: [PATCH 04/19] more removal of numpy wrapper nps --- autoarray/mask/derive/indexes_2d.py | 3 ++- autoarray/mask/mask_2d_util.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 062c8e664..0d0a36b26 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging +import numpy as np -from autoarray.numpy_wrapper import np, register_pytree_node_class +from autoarray.numpy_wrapper import register_pytree_node_class from typing import TYPE_CHECKING if TYPE_CHECKING: diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 462448073..10a40b473 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -1,10 +1,10 @@ import numpy as np +import jax.numpy as jnp from scipy.ndimage import convolve from typing import Tuple import warnings from autoarray import exc -from autoarray.numpy_wrapper import np as jnp def native_index_for_slim_index_2d_from( From aa4c9e6e5b8868d35fcfb6d9dce206b63f38e956 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:28:07 +0100 Subject: [PATCH 05/19] remove all numpy wrappers --- autoarray/operators/contour.py | 15 ++++++--------- autoarray/structures/decorators/to_vector_yx.py | 3 +-- autoarray/structures/decorators/transform.py | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index c7da5c7f1..c352ea19f 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -1,6 +1,6 @@ from __future__ import annotations -from autoarray.numpy_wrapper import np, use_jax import numpy +import jax.numpy as jnp from skimage import measure from scipy.spatial import ConvexHull from scipy.spatial import QhullError @@ -42,16 +42,13 @@ def contour_array(self): return self._contour_array pixel_centres = geometry_util.grid_pixel_centres_2d_slim_from( - grid_scaled_2d_slim=np.array(self.grid), + grid_scaled_2d_slim=jnp.array(self.grid), shape_native=self.shape_native, pixel_scales=self.pixel_scales, ).astype("int") - arr = np.zeros(self.shape_native) - if use_jax: - arr = arr.at[tuple(np.array(pixel_centres).T)].set(1) - else: - arr[tuple(np.array(pixel_centres).T)] = 1 + arr = jnp.zeros(self.shape_native) + arr = arr.at[tuple(jnp.array(pixel_centres).T)].set(1) return arr @@ -74,7 +71,7 @@ def contour_list(self): pixel_scales=self.pixel_scales, ) - factor = 0.5 * np.array(self.pixel_scales) * np.array([-1.0, 1.0]) + factor = 0.5 * jnp.array(self.pixel_scales) * jnp.array([-1.0, 1.0]) grid_scaled_1d += factor contour_list.append(Grid2DIrregular(values=grid_scaled_1d)) @@ -104,7 +101,7 @@ def hull( hull_x = grid_convex[hull_vertices, 0] hull_y = grid_convex[hull_vertices, 1] - grid_hull = np.zeros((len(hull_vertices), 2)) + grid_hull = jnp.zeros((len(hull_vertices), 2)) grid_hull[:, 1] = hull_x grid_hull[:, 0] = hull_y diff --git a/autoarray/structures/decorators/to_vector_yx.py b/autoarray/structures/decorators/to_vector_yx.py index 1cf23346d..90aea99ea 100644 --- a/autoarray/structures/decorators/to_vector_yx.py +++ b/autoarray/structures/decorators/to_vector_yx.py @@ -1,6 +1,5 @@ -from autoarray.numpy_wrapper import np from functools import wraps - +import numpy as np from typing import List, Union from autoarray.structures.decorators.abstract import AbstractMaker diff --git a/autoarray/structures/decorators/transform.py b/autoarray/structures/decorators/transform.py index bd837a399..eca0d883b 100644 --- a/autoarray/structures/decorators/transform.py +++ b/autoarray/structures/decorators/transform.py @@ -1,5 +1,5 @@ -from autoarray.numpy_wrapper import np from functools import wraps +import numpy as np from typing import Union from autoarray.structures.grids.uniform_1d import Grid1D From 3b6ab48b21dff83bc144ec19d48d408fc604793b Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:36:33 +0100 Subject: [PATCH 06/19] remove warning for now --- autoarray/numba_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/numba_util.py b/autoarray/numba_util.py index db34f3e1a..9e0298b73 100644 --- a/autoarray/numba_util.py +++ b/autoarray/numba_util.py @@ -33,7 +33,7 @@ try: if os.environ.get("USE_JAX") == "1": - logger.warning("JAX and numba do not work together, so JAX is being used.") + 1 else: import numba From 44a2808e1761798ae0ab077a07a94408c43e843d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:48:11 +0100 Subject: [PATCH 07/19] fix structure plotters --- autoarray/operators/contour.py | 14 +++++++------- autoarray/plot/wrap/two_d/array_overlay.py | 4 +++- .../structures/plot/test_structure_plotters.py | 1 + 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index c352ea19f..2de247d3c 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy +import numpy as np import jax.numpy as jnp from skimage import measure from scipy.spatial import ConvexHull @@ -42,7 +42,7 @@ def contour_array(self): return self._contour_array pixel_centres = geometry_util.grid_pixel_centres_2d_slim_from( - grid_scaled_2d_slim=jnp.array(self.grid), + grid_scaled_2d_slim=np.array(self.grid), shape_native=self.shape_native, pixel_scales=self.pixel_scales, ).astype("int") @@ -56,7 +56,7 @@ def contour_array(self): def contour_list(self): # make sure to use base numpy to convert JAX array back to a normal array contour_indices_list = measure.find_contours( - numpy.array(self.contour_array.array), 0 + np.array(self.contour_array), 0 ) if len(contour_indices_list) == 0: @@ -71,7 +71,7 @@ def contour_list(self): pixel_scales=self.pixel_scales, ) - factor = 0.5 * jnp.array(self.pixel_scales) * jnp.array([-1.0, 1.0]) + factor = 0.5 * np.array(self.pixel_scales) * np.array([-1.0, 1.0]) grid_scaled_1d += factor contour_list.append(Grid2DIrregular(values=grid_scaled_1d)) @@ -86,10 +86,10 @@ def hull( return None # cast JAX arrays to base numpy arrays - grid_convex = numpy.zeros((len(self.grid), 2)) + grid_convex = np.zeros((len(self.grid), 2)) - grid_convex[:, 0] = numpy.array(self.grid[:, 1]) - grid_convex[:, 1] = numpy.array(self.grid[:, 0]) + grid_convex[:, 0] = np.array(self.grid[:, 1]) + grid_convex[:, 1] = np.array(self.grid[:, 0]) try: hull = ConvexHull(grid_convex) diff --git a/autoarray/plot/wrap/two_d/array_overlay.py b/autoarray/plot/wrap/two_d/array_overlay.py index 57652e8df..5de20b879 100644 --- a/autoarray/plot/wrap/two_d/array_overlay.py +++ b/autoarray/plot/wrap/two_d/array_overlay.py @@ -19,4 +19,6 @@ def overlay_array(self, array, figure): aspect = figure.aspect_from(shape_native=array.shape_native) extent = array.extent_of_zoomed_array(buffer=0) - plt.imshow(X=array.native, aspect=aspect, extent=extent, **self.config_dict) + print(type(array)) + + plt.imshow(X=array.native._array, aspect=aspect, extent=extent, **self.config_dict) diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index ad1ca0251..d455c86f4 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -3,6 +3,7 @@ from os import path import pytest import numpy as np +import jax.numpy as jnp import shutil directory = path.dirname(path.realpath(__file__)) From ea139fcd95bd3400b620df39041e40d0470beba0 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:56:03 +0100 Subject: [PATCH 08/19] clean up vectors_yx --- autoarray/structures/vectors/uniform.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index 89d589139..fc66b0e8d 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -1,7 +1,8 @@ import logging -# import numpy as np -from autofit.jax_wrapper import numpy as np, use_jax +import numpy as np +import jax.numpy as jnp +# from autofit.jax_wrapper import numpy as np, use_jax from typing import List, Optional, Tuple, Union from autoarray.structures.arrays.uniform_2d import Array2D @@ -396,11 +397,7 @@ def magnitudes(self) -> Array2D: """ Returns the magnitude of every vector which are computed as sqrt(y**2 + x**2). """ - if use_jax: - s = self.array - else: - s = self - return Array2D(values=np.sqrt(s[:, 0] ** 2.0 + s[:, 1] ** 2.0), mask=self.mask) + return Array2D(values=jnp.sqrt(self.array[:, 0] ** 2.0 + self.array[:, 1] ** 2.0), mask=self.mask) @property def y(self) -> Array2D: From ec1e81eb0066bd638c0d48dda33d95d14169e27d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 17:57:41 +0100 Subject: [PATCH 09/19] remove autofit imports --- autoarray/geometry/geometry_2d.py | 2 -- autoarray/operators/over_sampling/over_sampler.py | 2 +- autoarray/structures/vectors/uniform.py | 1 - test_autoarray/test_jax_changes.py | 8 +++++--- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index e78f0f75a..9eea7e9f2 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -13,8 +13,6 @@ from autoarray import type as ty from autoarray.geometry import geometry_util -from autofit.jax_wrapper import use_jax - logging.basicConfig() logger = logging.getLogger(__name__) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 65393709c..ae458e41b 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -10,7 +10,7 @@ from autoarray.operators.over_sampling import over_sample_util -from autofit.jax_wrapper import register_pytree_node_class +from autoarray.numpy_wrapper import register_pytree_node_class @register_pytree_node_class diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index fc66b0e8d..6213ecfbf 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -2,7 +2,6 @@ import numpy as np import jax.numpy as jnp -# from autofit.jax_wrapper import numpy as np, use_jax from typing import List, Optional, Tuple, Union from autoarray.structures.arrays.uniform_2d import Array2D diff --git a/test_autoarray/test_jax_changes.py b/test_autoarray/test_jax_changes.py index f5104a942..2b6289317 100644 --- a/test_autoarray/test_jax_changes.py +++ b/test_autoarray/test_jax_changes.py @@ -1,8 +1,10 @@ -import autoarray as aa +import jax.numpy as jnp import pytest + +import autoarray as aa + from autoarray import Grid2D, Mask2D -from autofit.jax_wrapper import numpy as np @pytest.fixture(name="array") @@ -33,4 +35,4 @@ def test_boolean_issue(): mask=Mask2D.all_false((10, 10), pixel_scales=1.0), ) values, keys = Grid2D.instance_flatten(grid) - np.array(Grid2D.instance_unflatten(keys, values)) + jnp.array(Grid2D.instance_unflatten(keys, values)) From 37e81f157a2acd3dff2f55ab546a834ece26ddd9 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 3 Apr 2025 18:04:21 +0100 Subject: [PATCH 10/19] fix voronoi unit test in structures --- autoarray/inversion/pixelization/image_mesh/overlay.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autoarray/inversion/pixelization/image_mesh/overlay.py b/autoarray/inversion/pixelization/image_mesh/overlay.py index c5bc7eaef..de130ee6e 100644 --- a/autoarray/inversion/pixelization/image_mesh/overlay.py +++ b/autoarray/inversion/pixelization/image_mesh/overlay.py @@ -220,11 +220,11 @@ def image_plane_mesh_grid_from( origin=origin, ) - overlaid_centres = geometry_util.grid_pixel_centres_2d_slim_from( + overlaid_centres = np.array(geometry_util.grid_pixel_centres_2d_slim_from( grid_scaled_2d_slim=unmasked_overlay_grid, shape_native=mask.shape_native, pixel_scales=mask.pixel_scales, - ).astype("int") + )).astype("int") total_pixels = total_pixels_2d_from( mask_2d=mask.array, From b31c0fc24e248e2a82dbf40d7f5fdfed9d4b0a9e Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:24:39 +0100 Subject: [PATCH 11/19] fix test_preprocess --- autoarray/dataset/preprocess.py | 12 ++++++------ test_autoarray/dataset/test_preprocess.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/autoarray/dataset/preprocess.py b/autoarray/dataset/preprocess.py index 5c7338204..f13af3184 100644 --- a/autoarray/dataset/preprocess.py +++ b/autoarray/dataset/preprocess.py @@ -263,15 +263,15 @@ def edges_from(image, no_edges): edges = [] for edge_no in range(no_edges): - top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no] - bottom_edge = image.native[ + top_edge = image.native.array[edge_no, edge_no : image.shape_native[1] - edge_no] + bottom_edge = image.native.array[ image.shape_native[0] - 1 - edge_no, edge_no : image.shape_native[1] - edge_no, ] - left_edge = image.native[ + left_edge = image.native.array[ edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no ] - right_edge = image.native[ + right_edge = image.native.array[ edge_no + 1 : image.shape_native[0] - 1 - edge_no, image.shape_native[1] - 1 - edge_no, ] @@ -517,8 +517,8 @@ def noise_map_with_signal_to_noise_limit_from( noise_map_limit = np.where( (signal_to_noise_map.native > signal_to_noise_limit) & (noise_limit_mask == False), - np.abs(data.native) / signal_to_noise_limit, - noise_map.native, + np.abs(data.native.array) / signal_to_noise_limit, + noise_map.native.array, ) mask = Mask2D.all_false( diff --git a/test_autoarray/dataset/test_preprocess.py b/test_autoarray/dataset/test_preprocess.py index 74e8ef774..f484fd648 100644 --- a/test_autoarray/dataset/test_preprocess.py +++ b/test_autoarray/dataset/test_preprocess.py @@ -462,7 +462,7 @@ def test__background_noise_map_via_edges_of_image_from_4(): ) assert np.allclose( - background_noise_map.native, + background_noise_map.native.array, np.full(fill_value=np.std(np.arange(28)), shape=image.shape_native), ) @@ -486,7 +486,7 @@ def test__background_noise_map_via_edges_of_image_from_5(): ) assert np.allclose( - background_noise_map.native, + background_noise_map.native.array, np.full(fill_value=np.std(np.arange(48)), shape=image.shape_native), ) From 80fc8e8781d3d7b8bcb50ca2698bd5cbda54b8ae Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:25:38 +0100 Subject: [PATCH 12/19] fix test dataset abstract --- autoarray/dataset/imaging/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index e3ec74d3b..b5b5b73d7 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -166,8 +166,9 @@ def __init__( self.psf = psf - if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") + if psf is not None: + if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") @cached_property def grids(self): From cd276cd3146e4d03beede6a9ffe366a1ab0f93ba Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:38:09 +0100 Subject: [PATCH 13/19] fix test imaging --- autoarray/dataset/imaging/dataset.py | 18 +++++++++--------- test_autoarray/dataset/imaging/test_dataset.py | 9 +++++++-- .../dataset/imaging/test_simulator.py | 2 +- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index b5b5b73d7..8d84ee1b8 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -204,9 +204,9 @@ def w_tilde(self): indexes, lengths, ) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from( - noise_map_native=np.array(self.noise_map.native), - kernel_native=np.array(self.psf.native), - native_index_for_slim_index=self.mask.derive_indexes.native_for_slim, + noise_map_native=np.array(self.noise_map.native.array).astype("float64"), + kernel_native=np.array(self.psf.native.array).astype("float64"), + native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"), ) return WTildeImaging( @@ -409,20 +409,20 @@ def apply_noise_scaling( """ if signal_to_noise_value is None: - noise_map = self.noise_map.native - noise_map[mask == False] = noise_value + noise_map = np.array(self.noise_map.native.array) + noise_map[mask.array == False] = noise_value else: noise_map = np.where( mask == False, - np.median(self.data.native[mask.derive_mask.edge == False]) + np.median(self.data.native.array[mask.derive_mask.edge == False]) / signal_to_noise_value, - self.noise_map.native, + self.noise_map.native.array, ) if should_zero_data: - data = np.where(np.invert(mask), 0.0, self.data.native) + data = np.where(np.invert(mask.array), 0.0, self.data.native.array) else: - data = self.data.native + data = self.data.native.array data_unmasked = Array2D.no_mask( values=data, diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index 45ef0a80f..ca33f1b40 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -139,7 +139,7 @@ def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): == 2.0 * np.ones((7, 7)) * np.invert(mask_2d_7x7) ).all() - assert (masked_imaging_7x7.psf.slim == (1.0 / 3.0) * psf_3x3.slim).all() + assert masked_imaging_7x7.psf.slim == pytest.approx((1.0 / 3.0) * psf_3x3.slim, 1.0e-4) assert type(masked_imaging_7x7.psf) == aa.Kernel2D assert masked_imaging_7x7.w_tilde.curvature_preload.shape == (35,) @@ -244,4 +244,9 @@ def test__noise_map_unmasked_has_zeros_or_negative__raises_exception(): def test__psf_not_odd_x_odd_kernel__raises_error(): with pytest.raises(exc.KernelException): - aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) + image = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) + noise_map = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) + psf = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) + + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=False) + diff --git a/test_autoarray/dataset/imaging/test_simulator.py b/test_autoarray/dataset/imaging/test_simulator.py index 3cc4182d3..54dc1f6ed 100644 --- a/test_autoarray/dataset/imaging/test_simulator.py +++ b/test_autoarray/dataset/imaging/test_simulator.py @@ -70,7 +70,7 @@ def test__via_image_from__psf_off__noise_off_value_is_noise_value( == np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]) ).all() - assert np.allclose(dataset.noise_map.native, 0.2 * np.ones((3, 3))) + assert np.allclose(dataset.noise_map.native.array, 0.2 * np.ones((3, 3))) def test__via_image_from__psf_off__background_sky_on(image_central_delta_3x3): From f6dfda50b5c2a03db875d0beda7bee7a39ebdbb5 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:42:00 +0100 Subject: [PATCH 14/19] fix layout --- test_autoarray/layout/test_region.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test_autoarray/layout/test_region.py b/test_autoarray/layout/test_region.py index 690cfa9bb..643b8532a 100644 --- a/test_autoarray/layout/test_region.py +++ b/test_autoarray/layout/test_region.py @@ -152,7 +152,7 @@ def test__slice_2d__addition(): image = np.ones((2, 2)) region = aa.Region2D(region=(0, 1, 0, 1)) - array[region.slice] += image[region.slice] + array = array.at[region.slice].add(image[region.slice]) assert (array == np.array([[1.0, 0.0], [0.0, 0.0]])).all() @@ -161,7 +161,7 @@ def test__slice_2d__addition(): image = np.ones((2, 2)) region = aa.Region2D(region=(0, 1, 0, 1)) - array[region.slice] += image[region.slice] + array = array.at[region.slice].add(image[region.slice]) assert (array == np.array([[2.0, 1.0], [1.0, 1.0]])).all() @@ -170,7 +170,7 @@ def test__slice_2d__addition(): image = np.ones((3, 3)) region = aa.Region2D(region=(1, 3, 2, 3)) - array[region.slice] += image[region.slice] + array = array.at[region.slice].add(image[region.slice]) assert ( array == np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [1.0, 1.0, 2.0]]) @@ -183,7 +183,7 @@ def test__slice_2d__set_to_zerose(): region = aa.Region2D(region=(0, 1, 0, 1)) - array[region.slice] = 0 + array = array.at[region.slice].set(0) assert (array == np.array([[0.0, 1.0], [1.0, 1.0]])).all() @@ -192,7 +192,7 @@ def test__slice_2d__set_to_zerose(): region = aa.Region2D(region=(1, 3, 2, 3)) - array[region.slice] = 0 + array = array.at[region.slice].set(0) assert ( array == np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) From 5430647dbf983d9bac7e7de0107b855162f0d117 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 14:46:46 +0100 Subject: [PATCH 15/19] fix plot unit tests --- autoarray/plot/wrap/two_d/contour.py | 4 ++-- test_autoarray/plot/include/test_include.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/plot/wrap/two_d/contour.py b/autoarray/plot/wrap/two_d/contour.py index c80159813..164fe86be 100644 --- a/autoarray/plot/wrap/two_d/contour.py +++ b/autoarray/plot/wrap/two_d/contour.py @@ -93,10 +93,10 @@ def set( config_dict.pop("use_log10") config_dict.pop("include_values") - levels = self.levels_from(array) + levels = self.levels_from(array.array) ax = plt.contour( - array.native[::-1], levels=levels, extent=extent, **config_dict + array.native.array[::-1], levels=levels, extent=extent, **config_dict ) if self.include_values: try: diff --git a/test_autoarray/plot/include/test_include.py b/test_autoarray/plot/include/test_include.py index 6f4d29c77..b32616e9d 100644 --- a/test_autoarray/plot/include/test_include.py +++ b/test_autoarray/plot/include/test_include.py @@ -6,7 +6,7 @@ def test__loads_default_values_from_config_if_not_input(): assert include.origin is True assert include.mask == True - assert include.border is True + assert include.border is False assert include.parallel_overscan is True assert include.serial_prescan is True assert include.serial_overscan is False From 083ed0bb2f08eed82eb1520b929347036075690c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 15:50:11 +0100 Subject: [PATCH 16/19] over sampling unit tests --- .../operators/over_sampling/over_sampler.py | 37 ++++++++++++++----- .../over_sample/test_over_sampler.py | 10 +++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index ae458e41b..9fda67bb7 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,6 +147,16 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): over_sample_size=sub_size, mask=mask ) + + @property + def sub_is_uniform(self) -> bool: + """ + Returns True if the sub_size is uniform across all pixels in the mask. + """ + return np.all( + np.isclose(self.sub_size.array, self.sub_size.array[0]) + ) + def tree_flatten(self): return (self.mask, self.sub_size), () @@ -185,7 +195,7 @@ def sub_pixel_areas(self) -> np.ndarray: """ The area of every sub-pixel in the mask. """ - sub_pixel_areas = jnp.zeros(self.sub_total) + sub_pixel_areas = np.zeros(self.sub_total) k = 0 @@ -221,15 +231,24 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": except AttributeError: pass - # binned_array_2d = over_sample_util.binned_array_2d_from( - # array_2d=jnp.array(array), - # mask_2d=jnp.array(self.mask), - # sub_size=jnp.array(self.sub_size).astype("int"), - # ) + if self.sub_is_uniform: + binned_array_2d = array.reshape( + self.mask.shape_slim, self.sub_size[0] ** 2 + ).mean(axis=1) + else: + + # Define group sizes + group_sizes = jnp.array(self.sub_size.array.astype("int") ** 2) + + # Compute the cumulative sum of group sizes to get split points + split_indices = jnp.cumsum(group_sizes) + + # Ensure correct concatenation by making 0 a JAX array + start_indices = jnp.concatenate((jnp.array([0]), split_indices[:-1])) - binned_array_2d = array.reshape( - self.mask.shape_slim, self.sub_size[0] ** 2 - ).mean(axis=1) + # Compute the group means + binned_array_2d = jnp.array( + [array[start:end].mean() for start, end in zip(start_indices, split_indices)]) return Array2D( values=binned_array_2d, diff --git a/test_autoarray/operators/over_sample/test_over_sampler.py b/test_autoarray/operators/over_sample/test_over_sampler.py index 7da24d79a..a32b11e8f 100644 --- a/test_autoarray/operators/over_sample/test_over_sampler.py +++ b/test_autoarray/operators/over_sample/test_over_sampler.py @@ -70,6 +70,16 @@ def test__binned_array_2d_from(): pixel_scales=1.0, ) + over_sampling = aa.OverSampler( + mask=mask, sub_size=aa.Array2D(values=[2, 2], mask=mask) + ) + + arr = np.array([1.0, 5.0, 7.0, 10.0, 10.0, 10.0, 10.0, 10.0]) + + binned_array_2d = over_sampling.binned_array_2d_from(array=arr) + + assert binned_array_2d.slim == pytest.approx(np.array([5.75, 10.0]), 1.0e-4) + over_sampling = aa.OverSampler( mask=mask, sub_size=aa.Array2D(values=[1, 2], mask=mask) ) From 72af86b04b70f94ae3514f8c28ff38d85abfe133 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 16:07:46 +0100 Subject: [PATCH 17/19] fix all fit tests --- autoarray/fit/fit_util.py | 87 +++++------- test_autoarray/fit/test_fit_util.py | 199 ++++++++++++++-------------- 2 files changed, 135 insertions(+), 151 deletions(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index d40f55d1c..10f24f9a7 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -1,5 +1,6 @@ from functools import wraps -import jax.numpy as np +import jax.numpy as jnp +import numpy as np from autoarray.mask.abstract_mask import Mask @@ -83,7 +84,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return np.sum(chi_squared_map._array) + return jnp.sum(np.array(chi_squared_map)) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: @@ -97,12 +98,12 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return np.sum(np.log(2 * np.pi * noise_map._array**2.0)) + return jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map)**2.0)) def normalized_residual_map_complex_from( - *, residual_map: np.ndarray, noise_map: np.ndarray -) -> np.ndarray: + *, residual_map: jnp.ndarray, noise_map: jnp.ndarray +) -> jnp.ndarray: """ Returns the normalized residual-map of the fit of complex model-data to a dataset, where: @@ -126,8 +127,8 @@ def normalized_residual_map_complex_from( def chi_squared_map_complex_from( - *, residual_map: np.ndarray, noise_map: np.ndarray -) -> np.ndarray: + *, residual_map: jnp.ndarray, noise_map: jnp.ndarray +) -> jnp.ndarray: """ Returnss the chi-squared-map of the fit of complex model-data to a dataset, where: @@ -145,7 +146,7 @@ def chi_squared_map_complex_from( return chi_squared_map_real + 1j * chi_squared_map_imag -def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float: +def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float: """ Returns the chi-squared terms of each complex model data's fit to a masked dataset, by summing the masked chi-squared-map of the fit. @@ -157,12 +158,12 @@ def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - chi_squared_real = np.sum(chi_squared_map.real) - chi_squared_imag = np.sum(chi_squared_map.imag) + chi_squared_real = jnp.sum(chi_squared_map.real) + chi_squared_imag = jnp.sum(chi_squared_map.imag) return chi_squared_real + chi_squared_imag -def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float: +def noise_normalization_complex_from(*, noise_map: jnp.ndarray) -> float: """ Returns the noise-map normalization terms of a complex noise-map, summing the noise_map value in every pixel as: @@ -173,8 +174,8 @@ def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float: noise_map The masked noise-map of the dataset. """ - noise_normalization_real = np.sum(np.log(2 * np.pi * noise_map.real**2.0)) - noise_normalization_imag = np.sum(np.log(2 * np.pi * noise_map.imag**2.0)) + noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * noise_map.real**2.0)) + noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * noise_map.imag**2.0)) return noise_normalization_real + noise_normalization_imag @@ -198,9 +199,7 @@ def residual_map_with_mask_from( model_data The model data used to fit the data. """ - return np.subtract( - data, model_data, out=np.zeros_like(data), where=np.asarray(mask) == 0 - ) + return jnp.where(jnp.asarray(mask) == 0, jnp.subtract(data, model_data), 0) @to_new_array @@ -223,13 +222,7 @@ def normalized_residual_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return np.divide( - residual_map, - noise_map, - out=np.zeros_like(residual_map), - where=np.asarray(mask) == 0, - ) - + return jnp.where(jnp.asarray(mask) == 0, jnp.divide(residual_map, noise_map), 0) @to_new_array def chi_squared_map_with_mask_from( @@ -251,13 +244,10 @@ def chi_squared_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return np.square( - np.divide( - residual_map, - noise_map, - out=np.zeros_like(residual_map), - where=np.asarray(mask) == 0, - ) + return jnp.where( + jnp.asarray(mask) == 0, + jnp.square(residual_map / noise_map), + 0 ) @@ -275,7 +265,7 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask) -> f mask The mask applied to the chi-squared-map, where `False` entries are included in the calculation. """ - return float(np.sum(chi_squared_map[np.asarray(mask) == 0])) + return float(jnp.sum(chi_squared_map[jnp.asarray(mask) == 0])) def chi_squared_with_mask_fast_from( @@ -302,14 +292,14 @@ def chi_squared_with_mask_fast_from( The mask applied to the chi-squared-map, where `False` entries are included in the calculation. """ return float( - np.sum( - np.square( - np.divide( - np.subtract( + jnp.sum( + jnp.square( + jnp.divide( + jnp.subtract( data, model_data, - )[np.asarray(mask) == 0], - noise_map[np.asarray(mask) == 0], + )[jnp.asarray(mask) == 0], + noise_map[jnp.asarray(mask) == 0], ) ) ) @@ -331,11 +321,11 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask) -> mask The mask applied to the noise-map, where `False` entries are included in the calculation. """ - return float(np.sum(np.log(2 * np.pi * noise_map[np.asarray(mask) == 0] ** 2.0))) + return float(jnp.sum(jnp.log(2 * jnp.pi * noise_map[jnp.asarray(mask) == 0] ** 2.0))) def chi_squared_with_noise_covariance_from( - *, residual_map: ty.DataLike, noise_covariance_matrix_inv: np.ndarray + *, residual_map: ty.DataLike, noise_covariance_matrix_inv: jnp.ndarray ) -> float: """ Returns the chi-squared value of the fit of model-data to a masked dataset, where @@ -351,7 +341,7 @@ def chi_squared_with_noise_covariance_from( The inverse of the noise covariance matrix. """ - return residual_map @ noise_covariance_matrix_inv @ residual_map + return residual_map.array @ noise_covariance_matrix_inv @ residual_map.array def log_likelihood_from(*, chi_squared: float, noise_normalization: float) -> float: @@ -431,8 +421,8 @@ def log_evidence_from( def residual_flux_fraction_map_from( - *, residual_map: np.ndarray, data: np.ndarray -) -> np.ndarray: + *, residual_map: jnp.ndarray, data: jnp.ndarray +) -> jnp.ndarray: """ Returns the residual flux fraction map of the fit of model-data to a masked dataset, where: @@ -445,12 +435,12 @@ def residual_flux_fraction_map_from( data The data of the dataset. """ - return np.divide(residual_map, data, out=np.zeros_like(residual_map)) + return jnp.where(data != 0, residual_map / data, 0) def residual_flux_fraction_map_with_mask_from( - *, residual_map: np.ndarray, data: np.ndarray, mask: Mask -) -> np.ndarray: + *, residual_map: jnp.ndarray, data: jnp.ndarray, mask: Mask +) -> jnp.ndarray: """ Returnss the residual flux fraction map of the fit of model-data to a masked dataset, where: @@ -467,9 +457,4 @@ def residual_flux_fraction_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return np.divide( - residual_map, - data, - out=np.zeros_like(residual_map), - where=np.asarray(mask) == 0, - ) + return jnp.where(mask == 0, residual_map / data, 0) \ No newline at end of file diff --git a/test_autoarray/fit/test_fit_util.py b/test_autoarray/fit/test_fit_util.py index 641dbf52e..6bbb5f871 100644 --- a/test_autoarray/fit/test_fit_util.py +++ b/test_autoarray/fit/test_fit_util.py @@ -1,41 +1,41 @@ import autoarray as aa -import numpy as np +import jax.numpy as jnp import pytest def test__residual_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) - assert (residual_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert (residual_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() - data = np.array([10.0, 10.0, 10.0, 10.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) - assert (residual_map == np.array([-1.0, 0.0, 1.0, 2.0])).all() + assert (residual_map == jnp.array([-1.0, 0.0, 1.0, 2.0])).all() def test__residual_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data ) - assert (residual_map == np.array([0.0, 0.0, 1.0, 0.0])).all() + assert (residual_map == jnp.array([0.0, 0.0, 1.0, 0.0])).all() def test__normalized_residual_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -43,9 +43,9 @@ def test__normalized_residual_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert (normalized_residual_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert normalized_residual_map == pytest.approx(jnp.array([0.0, 0.0, 0.0, 0.0]), 1.0e-4) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -53,17 +53,14 @@ def test__normalized_residual_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert ( - normalized_residual_map - == np.array([-(1.0 / 2.0), 0.0, (1.0 / 2.0), (2.0 / 2.0)]) - ).all() + assert normalized_residual_map == pytest.approx(jnp.array([-(1.0 / 2.0), 0.0, (1.0 / 2.0), (2.0 / 2.0)]), 1.0e-4) def test__normalized_residual_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -73,13 +70,15 @@ def test__normalized_residual_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (normalized_residual_map == np.array([0.0, 0.0, (1.0 / 2.0), 0.0])).all() + print(normalized_residual_map) + + assert normalized_residual_map == pytest.approx(jnp.array([0.0, 0.0, (1.0 / 2.0), 0.0]), abs=1.0e-4) def test__normalized_residual_map_complex_from(): - data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = np.array([2.0 + 2.0j, 2.0 + 2.0j]) - model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = jnp.array([2.0 + 2.0j, 2.0 + 2.0j]) + model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -87,13 +86,13 @@ def test__normalized_residual_map_complex_from(): residual_map=residual_map, noise_map=noise_map ) - assert (normalized_residual_map == np.array([0.5 - 1.0j, 0.5 - 1.0j])).all() + assert (normalized_residual_map == jnp.array([0.5 - 1.0j, 0.5 - 1.0j])).all() def test__chi_squared_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -101,9 +100,9 @@ def test__chi_squared_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert (chi_squared_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -113,15 +112,15 @@ def test__chi_squared_map_from(): assert ( chi_squared_map - == np.array([(1.0 / 2.0) ** 2.0, 0.0, (1.0 / 2.0) ** 2.0, (2.0 / 2.0) ** 2.0]) + == jnp.array([(1.0 / 2.0) ** 2.0, 0.0, (1.0 / 2.0) ** 2.0, (2.0 / 2.0) ** 2.0]) ).all() def test__chi_squared_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -131,9 +130,9 @@ def test__chi_squared_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() + assert (chi_squared_map == jnp.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -143,13 +142,13 @@ def test__chi_squared_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() + assert (chi_squared_map == jnp.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() def test__chi_squared_map_complex_from(): - data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = np.array([2.0 + 2.0j, 2.0 + 2.0j]) - model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = jnp.array([2.0 + 2.0j, 2.0 + 2.0j]) + model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -157,13 +156,13 @@ def test__chi_squared_map_complex_from(): residual_map=residual_map, noise_map=noise_map ) - assert (chi_squared_map == np.array([0.25 + 1.0j, 0.25 + 1.0j])).all() + assert (chi_squared_map == jnp.array([0.25 + 1.0j, 0.25 + 1.0j])).all() def test__chi_squared_with_noise_covariance_from(): resdiual_map = aa.Array2D.no_mask([[1.0, 1.0], [2.0, 2.0]], pixel_scales=1.0) - noise_covariance_matrix_inv = np.array( + noise_covariance_matrix_inv = jnp.array( [ [1.0, 1.0, 4.0, 0.0], [0.0, 1.0, 9.0, 0.0], @@ -181,10 +180,10 @@ def test__chi_squared_with_noise_covariance_from(): def test__chi_squared_with_mask_fast_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([1.0, 2.0, 3.0, 4.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -207,10 +206,10 @@ def test__chi_squared_with_mask_fast_from(): assert chi_squared == pytest.approx(chi_squared_fast, 1.0e-4) - data = np.array([[10.0, 10.0], [10.0, 10.0]]) - mask = np.array([[True, False], [False, True]]) - noise_map = np.array([[1.0, 2.0], [3.0, 4.0]]) - model_data = np.array([[11.0, 10.0], [9.0, 8.0]]) + data = jnp.array([[10.0, 10.0], [10.0, 10.0]]) + mask = jnp.array([[True, False], [False, True]]) + noise_map = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + model_data = jnp.array([[11.0, 10.0], [9.0, 8.0]]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -235,9 +234,9 @@ def test__chi_squared_with_mask_fast_from(): def test__log_likelihood_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - noise_map = np.array([2.0, 2.0, 2.0, 2.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -255,17 +254,17 @@ def test__log_likelihood_from(): chi_squared = 0.0 noise_normalization = ( - np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1.0e-4 ) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -288,17 +287,17 @@ def test__log_likelihood_from(): ((1.0 / 2.0) ** 2.0) + 0.0 + ((1.0 / 2.0) ** 2.0) + ((2.0 / 2.0) ** 2.0) ) noise_normalization = ( - np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) - + np.log(2.0 * np.pi * (2.0**2.0)) + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + + jnp.log(2.0 * jnp.pi * (2.0**2.0)) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1.0e-4 ) - noise_map = np.array([1.0, 2.0, 3.0, 4.0]) + noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -318,10 +317,10 @@ def test__log_likelihood_from(): chi_squared = 1.0 + (1.0 / (3.0**2.0)) + 0.25 noise_normalization = ( - np.log(2 * np.pi * (1.0**2.0)) - + np.log(2 * np.pi * (2.0**2.0)) - + np.log(2 * np.pi * (3.0**2.0)) - + np.log(2 * np.pi * (4.0**2.0)) + jnp.log(2 * jnp.pi * (1.0**2.0)) + + jnp.log(2 * jnp.pi * (2.0**2.0)) + + jnp.log(2 * jnp.pi * (3.0**2.0)) + + jnp.log(2 * jnp.pi * (4.0**2.0)) ) assert log_likelihood == pytest.approx( @@ -330,10 +329,10 @@ def test__log_likelihood_from(): def test__log_likelihood_from__with_mask(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - noise_map = np.array([1.0, 2.0, 3.0, 4.0]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -358,18 +357,18 @@ def test__log_likelihood_from__with_mask(): # chi squared = 0, 0.25, (0.25 and 1.0 are masked) chi_squared = 0.0 + (1.0 / 3.0) ** 2.0 - noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( - 2 * np.pi * (3.0**2.0) + noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( + 2 * jnp.pi * (3.0**2.0) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1e-4 ) - data = np.array([[10.0, 10.0], [10.0, 10.0]]) - mask = np.array([[True, False], [False, True]]) - noise_map = np.array([[1.0, 2.0], [3.0, 4.0]]) - model_data = np.array([[11.0, 10.0], [9.0, 8.0]]) + data = jnp.array([[10.0, 10.0], [10.0, 10.0]]) + mask = jnp.array([[True, False], [False, True]]) + noise_map = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + model_data = jnp.array([[11.0, 10.0], [9.0, 8.0]]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -394,8 +393,8 @@ def test__log_likelihood_from__with_mask(): # chi squared = 0, 0.25, (0.25 and 1.0 are masked) chi_squared = 0.0 + (1.0 / 3.0) ** 2.0 - noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( - 2 * np.pi * (3.0**2.0) + noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( + 2 * jnp.pi * (3.0**2.0) ) assert log_likelihood == pytest.approx( @@ -404,9 +403,9 @@ def test__log_likelihood_from__with_mask(): def test__log_likelihood_from__complex_data(): - data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = np.array([2.0 + 1.0j, 2.0 + 1.0j]) - model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = jnp.array([2.0 + 1.0j, 2.0 + 1.0j]) + model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -427,8 +426,8 @@ def test__log_likelihood_from__complex_data(): # chi squared = 0.25 and 4.0 chi_squared = 4.25 - noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( - 2 * np.pi * (1.0**2.0) + noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( + 2 * jnp.pi * (1.0**2.0) ) assert log_likelihood == pytest.approx( @@ -457,8 +456,8 @@ def test__log_evidence_from(): def test__residual_flux_fraction_map_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - model_data = np.array([10.0, 10.0, 10.0, 10.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -466,9 +465,9 @@ def test__residual_flux_fraction_map_from(): residual_map=residual_map, data=data ) - assert (residual_flux_fraction_map == np.array([0.0, 0.0, 0.0, 0.0])).all() + assert (residual_flux_fraction_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -476,13 +475,13 @@ def test__residual_flux_fraction_map_from(): residual_map=residual_map, data=data ) - assert (residual_flux_fraction_map == np.array([-0.1, 0.0, 0.1, 0.2])).all() + assert (residual_flux_fraction_map == jnp.array([-0.1, 0.0, 0.1, 0.2])).all() def test__residual_flux_fraction_map_with_mask_from(): - data = np.array([10.0, 10.0, 10.0, 10.0]) - mask = np.array([True, False, False, True]) - model_data = np.array([11.0, 10.0, 9.0, 8.0]) + data = jnp.array([10.0, 10.0, 10.0, 10.0]) + mask = jnp.array([True, False, False, True]) + model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -492,9 +491,9 @@ def test__residual_flux_fraction_map_with_mask_from(): residual_map=residual_map, mask=mask, data=data ) - assert (residual_flux_fraction_map == np.array([0.0, 0.0, 0.1, 0.0])).all() + assert (residual_flux_fraction_map == jnp.array([0.0, 0.0, 0.1, 0.0])).all() - model_data = np.array([11.0, 9.0, 8.0, 8.0]) + model_data = jnp.array([11.0, 9.0, 8.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -504,4 +503,4 @@ def test__residual_flux_fraction_map_with_mask_from(): residual_map=residual_map, mask=mask, data=data ) - assert (residual_flux_fraction_map == np.array([0.0, 0.1, 0.2, 0.0])).all() + assert (residual_flux_fraction_map == jnp.array([0.0, 0.1, 0.2, 0.0])).all() From 340c34d5faea26adb8616df689ba09034b76fdca Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 16:13:21 +0100 Subject: [PATCH 18/19] pylops removed --- autoarray/__init__.py | 1 - autoarray/inversion/inversion/factory.py | 11 -- .../inversion/inversion/interferometer/lop.py | 146 ------------------ .../inversion/regularization/abstract.py | 23 --- autoarray/operators/transformer.py | 29 +--- autoarray/preloads.py | 14 -- autoarray/structures/visibilities.py | 4 +- optional_requirements.txt | 1 - .../test_inversion_interferometer_util.py | 17 -- 9 files changed, 3 insertions(+), 243 deletions(-) delete mode 100644 autoarray/inversion/inversion/interferometer/lop.py diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 1a6d79c53..2300fb5d4 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -44,7 +44,6 @@ from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping -from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops from .inversion.linear_obj.linear_obj import LinearObj from .inversion.linear_obj.func_list import AbstractLinearObjFuncList from .mask.derive.indexes_2d import DeriveIndexes2D diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index 327262786..350f19e65 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -9,9 +9,6 @@ from autoarray.inversion.inversion.interferometer.w_tilde import ( InversionInterferometerWTilde, ) -from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, -) from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList @@ -212,11 +209,3 @@ def inversion_interferometer_from( settings=settings, run_time_dict=run_time_dict, ) - - else: - return InversionInterferometerMappingPyLops( - dataset=dataset, - linear_obj_list=linear_obj_list, - settings=settings, - run_time_dict=run_time_dict, - ) diff --git a/autoarray/inversion/inversion/interferometer/lop.py b/autoarray/inversion/inversion/interferometer/lop.py deleted file mode 100644 index fdd6b8adf..000000000 --- a/autoarray/inversion/inversion/interferometer/lop.py +++ /dev/null @@ -1,146 +0,0 @@ -from scipy import sparse - -import numpy as np -from typing import Dict - -from autoconf import cached_property - -from autoarray.inversion.inversion.interferometer.abstract import ( - AbstractInversionInterferometer, -) -from autoarray.inversion.linear_obj.linear_obj import LinearObj -from autoarray.structures.visibilities import Visibilities - -from autoarray.numba_util import profile_func - - -class InversionInterferometerMappingPyLops(AbstractInversionInterferometer): - """ - Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations - to be solved (see `inversion.inversion.abstract.AbstractInversion` for a full description). - - A linear object describes the mappings between values in observed `data` and the linear object's model via its - `mapping_matrix`. This class constructs linear equations for `Interferometer` objects, where the data is an - an array of visibilities and the mappings include a non-uniform fast Fourier transform operation described by - the interferometer dataset's transformer. - - This class uses the mapping formalism, which constructs the simultaneous linear equations using the - `mapping_matrix` of every linear object. This is performed using the library PyLops, which uses linear - operators to avoid these matrices being created explicitly in memory, making the calculation more efficient. - """ - - @cached_property - @profile_func - def reconstruction(self): - """ - Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12) - of https://arxiv.org/pdf/astro-ph/0302587.pdf - - S is the vector of reconstructed inversion values. - """ - - import pylops - - Aop = pylops.MatrixMult( - sparse.bsr_matrix(self.linear_obj_list[0].mapping_matrix) - ) - - Fop = self.transformer - - Op = Fop * Aop - - MOp = pylops.MatrixMult(sparse.bsr_matrix(self.preconditioner_matrix_inverse)) - - try: - return pylops.NormalEquationsInversion( - Op=Op, - Regs=None, - epsNRs=[1.0], - data=self.data.ordered_1d, - Weight=pylops.Diagonal(diag=self.noise_map.weight_list_ordered_1d), - NRegs=[ - pylops.MatrixMult(sparse.bsr_matrix(self.regularization_matrix)) - ], - M=MOp, - tol=self.settings.tolerance, - atol=self.settings.tolerance, - **dict(maxiter=self.settings.maxiter), - ) - except AttributeError: - return pylops.normal_equations_inversion( - Op=Op, - Regs=None, - epsNRs=[1.0], - y=self.data.ordered_1d, - Weight=pylops.Diagonal(diag=self.noise_map.weight_list_ordered_1d), - NRegs=[ - pylops.MatrixMult(sparse.bsr_matrix(self.regularization_matrix)) - ], - M=MOp, - tol=self.settings.tolerance, - atol=self.settings.tolerance, - **dict(maxiter=self.settings.maxiter), - )[0] - - @property - @profile_func - def mapped_reconstructed_data_dict( - self, - ) -> Dict[LinearObj, Visibilities]: - """ - When constructing the simultaneous linear equations (via vectors and matrices) the quantities of each individual - linear object (e.g. their `mapping_matrix`) are combined into single ndarrays. This does not track which - quantities belong to which linear objects, therefore the linear equation's solutions (which are returned as - ndarrays) do not contain information on which linear object(s) they correspond to. - - For example, consider if two `Mapper` objects with 50 and 100 source pixels are used in an `Inversion`. - The `reconstruction` (which contains the solved for source pixels values) is an ndarray of shape [150], but - the ndarray itself does not track which values belong to which `Mapper`. - - This function converts an ndarray of a `reconstruction` to a dictionary of ndarrays containing each linear - object's reconstructed images, where the keys are the instances of each mapper in the inversion. - - The PyLops calculation bypasses the calculation of the `mapping_matrix` and it therefore cannot be used to map - the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization mappings are used, - including the 2D non-uniform fast Fourier transform operation after mapping is complete. - - Parameters - ---------- - reconstruction - The reconstruction (in the source frame) whose values are mapped to a dictionary of values for each - individual mapper (in the image-plane). - """ - - mapped_reconstructed_image_dict = self.mapped_reconstructed_image_dict - - return { - linear_obj: self.transformer.visibilities_from(image=image) - for linear_obj, image in mapped_reconstructed_image_dict.items() - } - - @cached_property - @profile_func - def preconditioner_matrix(self): - curvature_matrix_approx = np.multiply( - np.sum(self.noise_map.weight_list_ordered_1d), - self.linear_obj_list[0].mapping_matrix.T - @ self.linear_obj_list[0].mapping_matrix, - ) - - return np.add(curvature_matrix_approx, self.regularization_matrix) - - @cached_property - @profile_func - def preconditioner_matrix_inverse(self): - return np.linalg.inv(self.preconditioner_matrix) - - @cached_property - @profile_func - def log_det_curvature_reg_matrix_term(self): - return 2.0 * np.sum( - np.log(np.diag(np.linalg.cholesky(self.preconditioner_matrix))) - ) - - @property - def reconstruction_noise_map(self): - return None diff --git a/autoarray/inversion/regularization/abstract.py b/autoarray/inversion/regularization/abstract.py index 47cadc302..838eaf942 100644 --- a/autoarray/inversion/regularization/abstract.py +++ b/autoarray/inversion/regularization/abstract.py @@ -5,13 +5,6 @@ if TYPE_CHECKING: from autoarray.inversion.linear_obj.linear_obj import LinearObj -try: - import pylops - - PyLopsOperator = pylops.LinearOperator -except ModuleNotFoundError: - PyLopsOperator = object - class AbstractRegularization: def __init__(self): @@ -174,19 +167,3 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ raise NotImplementedError - - -class RegularizationLop(PyLopsOperator): - def __init__(self, regularization_matrix): - self.regularization_matrix = regularization_matrix - self.pixels = regularization_matrix.shape[0] - self.dims = self.pixels - self.shape = (self.pixels, self.pixels) - self.dtype = dtype - self.explicit = False - - def _matvec(self, x): - return np.dot(self.regularization_matrix, x) - - def _rmatvec(self, x): - return np.dot(self.regularization_matrix.T, x) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index acbe6bb73..3b19f7f7c 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -8,21 +8,11 @@ class NUFFTPlaceholder: pass -class PyLopsPlaceholder: - pass - - try: from pynufft.linalg.nufft_cpu import NUFFT_cpu except ModuleNotFoundError: NUFFT_cpu = NUFFTPlaceholder -try: - import pylops - - PyLopsOperator = pylops.LinearOperator -except ModuleNotFoundError: - PyLopsOperator = PyLopsPlaceholder from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D @@ -42,20 +32,8 @@ def pynufft_exception(): ) -def pylops_exception(): - raise ModuleNotFoundError( - "\n--------------------\n" - "You are attempting to perform interferometer analysis.\n\n" - "However, the optional library PyLops (https://github.com/PyLops/pylops) is not installed.\n\n" - "Install it via the command `pip install pylops==2.3.1`.\n\n" - "----------------------" - ) - - -class TransformerDFT(PyLopsOperator): +class TransformerDFT: def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): - if isinstance(self, PyLopsPlaceholder): - pylops_exception() super().__init__() @@ -146,14 +124,11 @@ def transform_mapping_matrix(self, mapping_matrix): ) -class TransformerNUFFT(NUFFT_cpu, PyLopsOperator): +class TransformerNUFFT(NUFFT_cpu): def __init__(self, uv_wavelengths, real_space_mask): if isinstance(self, NUFFTPlaceholder): pynufft_exception() - if isinstance(self, PyLopsPlaceholder): - pylops_exception() - super(TransformerNUFFT, self).__init__() self.uv_wavelengths = uv_wavelengths diff --git a/autoarray/preloads.py b/autoarray/preloads.py index 6808f0f6c..90c1deae9 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -198,13 +198,6 @@ def set_mapper_list(self, fit_0, fit_1): if fit_0.inversion.total(cls=AbstractMapper) == 0: return - from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, - ) - - if isinstance(fit_0.inversion, InversionInterferometerMappingPyLops): - return - inversion_0 = fit_0.inversion inversion_1 = fit_1.inversion @@ -238,13 +231,6 @@ def set_operated_mapping_matrix_with_preloads(self, fit_0, fit_1): self.operated_mapping_matrix = None - from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, - ) - - if isinstance(fit_0.inversion, InversionInterferometerMappingPyLops): - return - inversion_0 = fit_0.inversion inversion_1 = fit_1.inversion diff --git a/autoarray/structures/visibilities.py b/autoarray/structures/visibilities.py index 2b4113c7c..8cc94dca5 100644 --- a/autoarray/structures/visibilities.py +++ b/autoarray/structures/visibilities.py @@ -213,9 +213,7 @@ def __init__(self, visibilities: Union[np.ndarray, List[complex]], *args, **kwar A collection of (real, imag) visibilities noise-map values which are used to represent the noise-map in an `Interferometer` dataset. - This data structure behaves the same as the `Visibilities` structure (see `AbstractVisibilities.__new__`). The - only difference is that it includes a `WeightOperator` used by `Inversion`'s which use `LinearOperators` and - the library `PyLops` to fit `Interferometer` data. + This data structure behaves the same as the `Visibilities` structure (see `AbstractVisibilities.__new__`). Parameters ---------- diff --git a/optional_requirements.txt b/optional_requirements.txt index cbb64e17e..4ba479740 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,4 +1,3 @@ -pylops>=1.10.0,<=2.3.1 pynufft #jax==0.4.3 #jaxlib==0.4.3 diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index 8875ea378..cb7d27673 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -68,23 +68,6 @@ def test__data_vector_via_transformed_mapping_matrix_from(): assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() -def test__inversion_interferometer__via_mapper( - interferometer_7_no_fft, - rectangular_mapper_7x7_3x3, - delaunay_mapper_9_3x3, - voronoi_mapper_9_3x3, - regularization_constant, -): - inversion = aa.Inversion( - dataset=interferometer_7_no_fft, - linear_obj_list=[rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_linear_operators=True), - ) - - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) - assert isinstance(inversion, aa.InversionInterferometerMappingPyLops) - - def test__w_tilde_curvature_interferometer_from(): noise_map = np.array([1.0, 2.0, 3.0]) uv_wavelengths = np.array([[0.0001, 2.0, 3000.0], [3000.0, 2.0, 0.0001]]) From 39fbf72a9776d25ccdadb0ed0b9825e49b20f1dd Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 4 Apr 2025 16:28:34 +0100 Subject: [PATCH 19/19] unit tests fixed --- autoarray/fit/fit_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index 10f24f9a7..f28a64e22 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -158,8 +158,8 @@ def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - chi_squared_real = jnp.sum(chi_squared_map.real) - chi_squared_imag = jnp.sum(chi_squared_map.imag) + chi_squared_real = jnp.sum(np.array(chi_squared_map.real)) + chi_squared_imag = jnp.sum(np.array(chi_squared_map.imag)) return chi_squared_real + chi_squared_imag @@ -174,8 +174,8 @@ def noise_normalization_complex_from(*, noise_map: jnp.ndarray) -> float: noise_map The masked noise-map of the dataset. """ - noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * noise_map.real**2.0)) - noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * noise_map.imag**2.0)) + noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map).real**2.0)) + noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map).imag**2.0)) return noise_normalization_real + noise_normalization_imag