From b317010a74eebf1331cb8d0662624f82fa3638c4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 20:47:52 +0100 Subject: [PATCH 01/33] fix mask tests by using == --- autoarray/mask/abstract_mask.py | 9 ++------- autoarray/mask/mask_2d.py | 2 +- autoarray/structures/grids/uniform_2d.py | 2 +- test_autoarray/mask/test_mask_1d.py | 18 ++++++++---------- 4 files changed, 12 insertions(+), 19 deletions(-) diff --git a/autoarray/mask/abstract_mask.py b/autoarray/mask/abstract_mask.py index 30a778aed..d6ed67c2c 100644 --- a/autoarray/mask/abstract_mask.py +++ b/autoarray/mask/abstract_mask.py @@ -2,13 +2,8 @@ from abc import ABC import logging - -from autoarray.numpy_wrapper import np, use_jax - -if use_jax: - import jax -from pathlib import Path -from typing import Dict, Union +import numpy as np +from typing import Dict from autoconf.fitsable import output_to_fits diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 60559e9d4..b6dbf7977 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -793,7 +793,7 @@ def resized_from(self, new_shape, pad_value: int = 0.0) -> Mask2D: """ resized_mask = array_2d_util.resized_array_2d_from( - array_2d=np.array(self), + array_2d=np.array(self._array), resized_shape=new_shape, pad_value=pad_value, ).astype("bool") diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 41559e083..64e2738f9 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -184,7 +184,7 @@ def __init__( over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=np.array(self.mask), pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.over_sampler.sub_size).astype("int"), + sub_size=np.array(self.over_sampler.sub_size._array).astype("int"), origin=self.mask.origin, ) ) diff --git a/test_autoarray/mask/test_mask_1d.py b/test_autoarray/mask/test_mask_1d.py index 41f02d3b7..cb5fe2c56 100644 --- a/test_autoarray/mask/test_mask_1d.py +++ b/test_autoarray/mask/test_mask_1d.py @@ -1,6 +1,4 @@ -from astropy.io import fits import numpy as np -import os from os import path import pytest @@ -53,34 +51,34 @@ def test__constructor__input_is_2d_mask__raises_exception(): def test__is_all_true(): mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[False, True, False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[True, True, True, True], pixel_scales=1.0) - assert mask.is_all_true is True + assert mask.is_all_true == True def test__is_all_false(): mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask1D(mask=[False, False], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask1D(mask=[False, True, False, False], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False mask = aa.Mask1D(mask=[True, True, False, False], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False From e1899b40cc0313a29a2cefd1ab07ee8d9af24e0a Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 20:49:03 +0100 Subject: [PATCH 02/33] fix more tests using == --- test_autoarray/mask/test_mask_2d.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index 7e12e5b82..399510a0f 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -366,37 +366,37 @@ def test__mask__input_is_1d_mask__no_shape_native__raises_exception(): def test__is_all_true(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[False, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[True, True], [True, True]], pixel_scales=1.0) - assert mask.is_all_true is True + assert mask.is_all_true == True def test__is_all_false(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask2D(mask=[[False, False]], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask2D(mask=[[False, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False mask = aa.Mask2D(mask=[[True, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False def test__shape_native_masked_pixels(): From cb5f13b72ec5770405b41920e6cbf5427eae5454 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 31 Mar 2025 21:33:03 +0100 Subject: [PATCH 03/33] fixes to get basic func_grad to work --- autoarray/dataset/imaging/dataset.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index bc529ef96..7ded4d409 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -162,7 +162,7 @@ def __init__( if psf is not None and use_normalized_psf: psf = Kernel2D.no_mask( - values=psf.native, pixel_scales=psf.pixel_scales, normalize=True + values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True ) self.psf = psf diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index f21cf5a80..c03ceaf32 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -53,7 +53,7 @@ def __init__( ) if normalize: - self._array[:] = np.divide(self._array, np.sum(self._array)) + self._array = np.divide(self._array, np.sum(self._array)) @classmethod def no_mask( From 15bf0dbdbb9a7a76980cd0babac4cac5b3b87c98 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:25:02 +0100 Subject: [PATCH 04/33] progress stopped at convolver --- autoarray/dataset/imaging/dataset.py | 2 +- autoarray/operators/convolver.py | 2 +- autoarray/structures/grids/uniform_2d.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 7ded4d409..20bb8b0be 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -193,7 +193,7 @@ def convolver(self): The convolver given the masked imaging data's mask and PSF. """ - return Convolver(mask=self.mask, kernel=self.psf) + return Convolver(mask=self.mask, kernel=Kernel2D(values=self.psf._array, mask=self.psf.mask, header=self.psf.header)) @cached_property def w_tilde(self): diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index 5a52329b0..73d3767d5 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -221,7 +221,7 @@ def __init__(self, mask, kernel): coordinates=(x, y), mask=np.array(mask), mask_index_array=self.mask_index_array, - kernel_2d=np.array(self.kernel.native[:, :]), + kernel_2d=self.kernel.native, ) self.image_frame_1d_indexes[mask_1d_index, :] = ( image_frame_1d_indexes diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 64e2738f9..782298eb9 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -545,7 +545,7 @@ def from_mask( """ grid_1d = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=np.array(mask), + mask_2d=mask._array, pixel_scales=mask.pixel_scales, origin=mask.origin, ) From d3649ff116d77dda5a54134b796dbc10923f9f26 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:39:14 +0100 Subject: [PATCH 05/33] updated grid_2d_slim_via_mask_from to be JAX implementation --- autoarray/geometry/geometry_util.py | 2 -- autoarray/structures/grids/grid_2d_util.py | 35 +++++++--------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index a795d42ee..d09089ef6 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -180,7 +180,6 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales -@numba_util.jit() def central_pixel_coordinates_2d_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: @@ -205,7 +204,6 @@ def central_pixel_coordinates_2d_from( return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) -@numba_util.jit() def central_scaled_coordinate_2d_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 44c75c7c5..e83eaab14 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,4 +1,7 @@ from __future__ import annotations + +import jax.numpy as jnp + from autoarray.numpy_wrapper import np, use_jax if use_jax: @@ -233,7 +236,6 @@ def grid_2d_centre_from(grid_2d_slim: np.ndarray) -> Tuple[float, float]: return centre_y, centre_x -@numba_util.jit() def grid_2d_slim_via_mask_from( mask_2d: np.ndarray, pixel_scales: ty.PixelScales, @@ -273,33 +275,18 @@ def grid_2d_slim_via_mask_from( grid_slim = grid_2d_slim_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - total_pixels = np.sum(~mask_2d) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_slim = ( - (np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled) - * sign - * pixel_scales - ) - else: - index = 0 - grid_slim = np.zeros(shape=(total_pixels, 2)) - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] - grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] - index += 1 - - return grid_slim + centres_scaled = jnp.array(centres_scaled) + pixel_scales = jnp.array(pixel_scales) + sign = jnp.array([-1.0, 1.0]) + return ( + (jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled) + * sign + * pixel_scales + ) def grid_2d_via_mask_from( From adf5eadac1554872326ee158309cda0ef8546bc4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:41:19 +0100 Subject: [PATCH 06/33] remove numba from grid_2d_centre_from --- autoarray/structures/grids/grid_2d_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index e83eaab14..522bb0525 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -216,7 +216,6 @@ def convert_grid_2d_to_native( ) -@numba_util.jit() def grid_2d_centre_from(grid_2d_slim: np.ndarray) -> Tuple[float, float]: """ Returns the centre of a grid from a 1D grid. From 31cdd33d54942bcd276d45ac873333e7576dfb56 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:42:18 +0100 Subject: [PATCH 07/33] remove numba from pixel_coordinates_2d_from -> fixes is circular --- autoarray/geometry/geometry_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index d09089ef6..0477fdbc9 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -242,7 +242,6 @@ def central_scaled_coordinate_2d_from( return (y_pixel, x_pixel) -@numba_util.jit() def pixel_coordinates_2d_from( scaled_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], From ff1e811e9cc8cb26faeb0f77e8fc3bbae6018744 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 18:57:12 +0100 Subject: [PATCH 08/33] fixing grid_2d_slim_over_sampled_via_mask_from to use numba --- .../over_sampling/over_sample_util.py | 86 +++++++++---------- autoarray/structures/grids/uniform_2d.py | 4 +- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index a98276896..70dc220c4 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -1,8 +1,6 @@ from __future__ import annotations import numpy as np from typing import TYPE_CHECKING, Union, List, Tuple -from autoarray.numpy_wrapper import np, register_pytree_node_class, use_jax, jit - from typing import List, Tuple from autoarray.structures.arrays.uniform_2d import Array2D @@ -364,11 +362,11 @@ def sub_size_radial_bins_from( for i in range(radial_grid.shape[0]): for j in range(len(radial_list)): if radial_grid[i] < radial_list[j]: - if use_jax: - # while this makes it run, it is very, very slow - sub_size = sub_size.at[i].set(sub_size_list[j]) - else: - sub_size[i] = sub_size_list[j] + # if use_jax: + # # while this makes it run, it is very, very slow + # sub_size = sub_size.at[i].set(sub_size_list[j]) + # else: + sub_size[i] = sub_size_list[j] break return sub_size @@ -447,35 +445,35 @@ def grid_2d_slim_over_sampled_via_mask_from( for y1 in range(sub): for x1 in range(sub): - if use_jax: - # while this makes it run, it is very, very slow - grid_slim = grid_slim.at[sub_index, 0].set( - -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) - ) - ) - grid_slim = grid_slim.at[sub_index, 1].set( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) - ) - else: - grid_slim[sub_index, 0] = -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) - ) - grid_slim[sub_index, 1] = ( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) - ) + # if use_jax: + # # while this makes it run, it is very, very slow + # grid_slim = grid_slim.at[sub_index, 0].set( + # -( + # y_scaled + # - y_sub_half + # + y1 * y_sub_step + # + (y_sub_step / 2.0) + # ) + # ) + # grid_slim = grid_slim.at[sub_index, 1].set( + # x_scaled + # - x_sub_half + # + x1 * x_sub_step + # + (x_sub_step / 2.0) + # ) + # else: + grid_slim[sub_index, 0] = -( + y_scaled + - y_sub_half + + y1 * y_sub_step + + (y_sub_step / 2.0) + ) + grid_slim[sub_index, 1] = ( + x_scaled + - x_sub_half + + x1 * x_sub_step + + (x_sub_step / 2.0) + ) sub_index += 1 index += 1 @@ -544,14 +542,14 @@ def binned_array_2d_from( for y1 in range(sub): for x1 in range(sub): - if use_jax: - binned_array_2d_slim = binned_array_2d_slim.at[index].add( - array_2d[sub_index] * sub_fraction[index] - ) - else: - binned_array_2d_slim[index] += ( - array_2d[sub_index] * sub_fraction[index] - ) + # if use_jax: + # binned_array_2d_slim = binned_array_2d_slim.at[index].add( + # array_2d[sub_index] * sub_fraction[index] + # ) + # else: + binned_array_2d_slim[index] += ( + array_2d[sub_index] * sub_fraction[index] + ) sub_index += 1 index += 1 diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 782298eb9..846d15cc8 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1,8 +1,10 @@ from __future__ import annotations -from autoarray.numpy_wrapper import np, use_jax +import numpy as np from pathlib import Path from typing import List, Optional, Tuple, Union +from autoarray.numpy_wrapper import use_jax + from autoconf import conf from autoconf import cached_property from autoconf.fitsable import ndarray_via_fits_from From b322a3faf36ea820d5f878cef30d542698778957 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:00:32 +0100 Subject: [PATCH 09/33] removed use of use_jax in one function --- autoarray/geometry/geometry_util.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 0477fdbc9..bc99dfbd7 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,4 +1,6 @@ +import jax.numpy as jnp from typing import Tuple, Union + from autoarray.numpy_wrapper import np, use_jax from autoarray import numba_util @@ -179,7 +181,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales - +@numba_util.jit() def central_pixel_coordinates_2d_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: @@ -203,7 +205,7 @@ def central_pixel_coordinates_2d_from( """ return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) - +@numba_util.jit() def central_scaled_coordinate_2d_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, @@ -379,18 +381,21 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - if use_jax: - shifted_grid_2d = grid_2d.array - np.array(centre) - else: - shifted_grid_2d = grid_2d - np.array(centre) - radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1)) - theta_coordinate_to_profile = np.arctan2( + # if use_jax: + # shifted_grid_2d = grid_2d.array - np.array(centre) + # else: + # shifted_grid_2d = grid_2d - np.array(centre) + + shifted_grid_2d = grid_2d.array - jnp.array(centre) + + radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) + theta_coordinate_to_profile = jnp.arctan2( shifted_grid_2d[:, 0], shifted_grid_2d[:, 1] - ) - np.radians(angle) - return np.vstack( + ) - jnp.radians(angle) + return jnp.vstack( [ - radius * np.sin(theta_coordinate_to_profile), - radius * np.cos(theta_coordinate_to_profile), + radius * jnp.sin(theta_coordinate_to_profile), + radius * jnp.cos(theta_coordinate_to_profile), ] ).T From 9e3c76c3b974d808972d4358505088798456ea3b Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:04:04 +0100 Subject: [PATCH 10/33] grid_pixels_2d_slim_from now uses native numpy, could support JAX --- autoarray/geometry/geometry_util.py | 93 +++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 25 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index bc99dfbd7..48a271341 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -182,7 +182,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales @numba_util.jit() -def central_pixel_coordinates_2d_from( +def central_pixel_coordinates_2d_numba_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: """ @@ -206,7 +206,7 @@ def central_pixel_coordinates_2d_from( return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) @numba_util.jit() -def central_scaled_coordinate_2d_from( +def central_scaled_coordinate_2d_numba_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, origin: Tuple[float, float] = (0.0, 0.0), @@ -234,7 +234,7 @@ def central_scaled_coordinate_2d_from( The central coordinates of the 2D data structure in scaled units. """ - central_pixel_coordinates = central_pixel_coordinates_2d_from( + central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( shape_native=shape_native ) @@ -244,6 +244,67 @@ def central_scaled_coordinate_2d_from( return (y_pixel, x_pixel) +def central_pixel_coordinates_2d_from( + shape_native: Tuple[int, int], +) -> Tuple[float, float]: + """ + Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) + from the shape of that data structure. + + Examples of the central pixels are as follows: + + - For a 3x3 image, the central pixel is pixel [1, 1]. + - For a 4x4 image, the central pixel is [1.5, 1.5]. + + Parameters + ---------- + shape_native + The dimensions of the data structure, which can be in 1D, 2D or higher dimensions. + + Returns + ------- + The central pixel coordinates of the data structure. + """ + return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) + + +def central_scaled_coordinate_2d_from( + shape_native: Tuple[int, int], + pixel_scales: ty.PixelScales, + origin: Tuple[float, float] = (0.0, 0.0), +) -> Tuple[float, float]: + """ + Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) + from the shape of that data structure. + + This is computed by using the data structure's shape and converting it to scaled units using an input + pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`. + + The origin of the scaled grid can also be input and moved from (0.0, 0.0). + + Parameters + ---------- + shape_native + The 2D shape of the data structure whose central scaled coordinates are computed. + pixel_scales + The (y,x) scaled units to pixel units conversion factor of the 2D data structure. + origin + The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on. + + Returns + ------- + The central coordinates of the 2D data structure in scaled units. + """ + + central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( + shape_native=shape_native + ) + + y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0]) + x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1]) + + return (y_pixel, x_pixel) + def pixel_coordinates_2d_from( scaled_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], @@ -437,7 +498,6 @@ def transform_grid_2d_from_reference_frame( return np.vstack((y, x)).T -@numba_util.jit() def grid_pixels_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -478,30 +538,13 @@ def grid_pixels_2d_slim_from( grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2), pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1, 1]) - return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - else: - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = ( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = ( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d_slim + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 @numba_util.jit() From ead617ee95c9fd7b54b0c4d2ed8b28f1af255388 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:05:13 +0100 Subject: [PATCH 11/33] grid_pixel_centres_2d_slim_from, could support JAX --- autoarray/geometry/geometry_util.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 48a271341..a39cba235 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -547,7 +547,6 @@ def grid_pixels_2d_slim_from( return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 -@numba_util.jit() def grid_pixel_centres_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -592,29 +591,13 @@ def grid_pixel_centres_2d_slim_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_pixels_2d_slim = ( - (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - ).astype(int) - else: - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = int( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = int( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + return ( + (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 + ).astype(int) - return grid_pixels_2d_slim @numba_util.jit() From 2769aaf1a77ddfe6834a2e15b93e49051ae947c8 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:06:40 +0100 Subject: [PATCH 12/33] grid_pixel_indexes_2d_slim_from, could support JAX --- autoarray/geometry/geometry_util.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index a39cba235..0d7cf82a9 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -600,7 +600,6 @@ def grid_pixel_centres_2d_slim_from( -@numba_util.jit() def grid_pixel_indexes_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -653,22 +652,12 @@ def grid_pixel_indexes_2d_slim_from( origin=origin, ) - if use_jax: - grid_pixel_indexes_2d_slim = ( - (grid_pixels_2d_slim * np.array([shape_native[1], 1])) - .sum(axis=1) - .astype(int) - ) - else: - grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) - - for slim_index in range(grid_pixels_2d_slim.shape[0]): - grid_pixel_indexes_2d_slim[slim_index] = int( - grid_pixels_2d_slim[slim_index, 0] * shape_native[1] - + grid_pixels_2d_slim[slim_index, 1] - ) + return ( + (grid_pixels_2d_slim * np.array([shape_native[1], 1])) + .sum(axis=1) + .astype(int) + ) - return grid_pixel_indexes_2d_slim @numba_util.jit() From b2ba6bd45017bc6c8325a96209c0439180925836 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:07:02 +0100 Subject: [PATCH 13/33] grid_scaled_2d_slim_from, could support JAX --- autoarray/geometry/geometry_util.py | 31 ++++++++--------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 0d7cf82a9..aab4c94aa 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -599,7 +599,6 @@ def grid_pixel_centres_2d_slim_from( ).astype(int) - def grid_pixel_indexes_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -660,7 +659,6 @@ def grid_pixel_indexes_2d_slim_from( -@numba_util.jit() def grid_scaled_2d_slim_from( grid_pixels_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -699,30 +697,17 @@ def grid_scaled_2d_slim_from( grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2), pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1, 1]) - grid_scaled_2d_slim = ( - (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign - ) - else: - grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_scaled_2d_slim[slim_index, 0] = ( - -(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5) - * pixel_scales[0] - ) - grid_scaled_2d_slim[slim_index, 1] = ( - grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5 - ) * pixel_scales[1] - - return grid_scaled_2d_slim + + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return ( + (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign + ) + @numba_util.jit() From 05321044405cacfa803d033b45201df15c672c51 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:07:48 +0100 Subject: [PATCH 14/33] grid_pixel_centres_2d_from, could support JAX --- autoarray/geometry/geometry_util.py | 32 ++++++----------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index aab4c94aa..5dca66c87 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -709,8 +709,6 @@ def grid_scaled_2d_slim_from( ) - -@numba_util.jit() def grid_pixel_centres_2d_from( grid_scaled_2d: np.ndarray, shape_native: Tuple[int, int], @@ -755,30 +753,12 @@ def grid_pixel_centres_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_pixels_2d = ( - (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 - ).astype(int) - else: - grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) - - for y in range(grid_scaled_2d.shape[0]): - for x in range(grid_scaled_2d.shape[1]): - grid_pixels_2d[y, x, 0] = int( - (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d[y, x, 1] = int( - (grid_scaled_2d[y, x, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + return ( + (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 + ).astype(int) def extent_symmetric_from( From d90ff2ebb130e86b8cd9b97c3868fc95562cb06c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:08:10 +0100 Subject: [PATCH 15/33] explciit separate imports --- autoarray/geometry/geometry_util.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 5dca66c87..5da893fac 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,7 +1,7 @@ import jax.numpy as jnp +import numpy as np from typing import Tuple, Union -from autoarray.numpy_wrapper import np, use_jax from autoarray import numba_util from autoarray import type as ty @@ -442,11 +442,6 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - # if use_jax: - # shifted_grid_2d = grid_2d.array - np.array(centre) - # else: - # shifted_grid_2d = grid_2d - np.array(centre) - shifted_grid_2d = grid_2d.array - jnp.array(centre) radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) From 59b21e999aa1879d1ff0fe8818f6d195027d8671 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:14:36 +0100 Subject: [PATCH 16/33] fix unit test in test__transform_2d_grid_from_reference_frame --- autoarray/geometry/geometry_util.py | 3 +-- test_autoarray/geometry/test_geometry_util.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 5da893fac..405262f96 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -442,7 +442,7 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - shifted_grid_2d = grid_2d.array - jnp.array(centre) + shifted_grid_2d = np.array(grid_2d) - jnp.array(centre) radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) theta_coordinate_to_profile = jnp.arctan2( @@ -653,7 +653,6 @@ def grid_pixel_indexes_2d_slim_from( ) - def grid_scaled_2d_slim_from( grid_pixels_2d_slim: np.ndarray, shape_native: Tuple[int, int], diff --git a/test_autoarray/geometry/test_geometry_util.py b/test_autoarray/geometry/test_geometry_util.py index b6b8b38d6..2fa8ce63c 100644 --- a/test_autoarray/geometry/test_geometry_util.py +++ b/test_autoarray/geometry/test_geometry_util.py @@ -1071,7 +1071,8 @@ def test__transform_2d_grid_from_reference_frame(): grid_2d=transformed_grid_2d, centre=(8.0, 5.0), angle=137.0 ) - assert grid_2d == pytest.approx(original_grid_2d, 1.0e-4) + + assert (np.abs(grid_2d - original_grid_2d) < 1e-4).all() def test__grid_pixels_2d_slim_from(): From c453a3c1951a0e3e3c9416e6f041c16a38bd61ac Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:17:43 +0100 Subject: [PATCH 17/33] use absolute tolerance to fix geomtry util unit tests --- test_autoarray/geometry/test_geometry_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test_autoarray/geometry/test_geometry_util.py b/test_autoarray/geometry/test_geometry_util.py index 2fa8ce63c..4bc2706dc 100644 --- a/test_autoarray/geometry/test_geometry_util.py +++ b/test_autoarray/geometry/test_geometry_util.py @@ -980,7 +980,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]) + np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -994,7 +994,7 @@ def test__transform_2d_grid_to_reference_frame(): [0.0, np.sqrt(2)], [np.sqrt(2) / 2.0, np.sqrt(2) / 2.0], ] - ) + ), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1002,7 +1002,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[-1.0, 0.0], [-1.0, 1.0], [0.0, 1.0]]) + np.array([[-1.0, 0.0], [-1.0, 1.0], [0.0, 1.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1010,7 +1010,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]) + np.array([[0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1072,7 +1072,7 @@ def test__transform_2d_grid_from_reference_frame(): ) - assert (np.abs(grid_2d - original_grid_2d) < 1e-4).all() + assert grid_2d == pytest.approx(original_grid_2d, abs=1.0e-4) def test__grid_pixels_2d_slim_from(): From 0c4bb3094f4398fe1c69d606f1961d01e0f4c5a0 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:19:09 +0100 Subject: [PATCH 18/33] fix test__pixel_coordinates_2d_from --- autoarray/geometry/geometry_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index 405262f96..b646c7d08 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -412,7 +412,7 @@ def scaled_coordinates_2d_from( origin=(0.0, 0.0) ) """ - central_scaled_coordinates = central_scaled_coordinate_2d_from( + central_scaled_coordinates = central_scaled_coordinate_2d_numba_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origins ) From d891947e6040d52f3057836d8f62e6557e95e0f6 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:26:06 +0100 Subject: [PATCH 19/33] cleaned up jax imports of array_2d_util to make more tests pass --- autoarray/structures/arrays/array_2d_util.py | 103 +++++-------------- 1 file changed, 28 insertions(+), 75 deletions(-) diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index c75cc9750..796c5d965 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -1,5 +1,7 @@ from __future__ import annotations - +import jax +import jax.numpy as jnp +import numpy as np from typing import TYPE_CHECKING, List, Tuple, Union if TYPE_CHECKING: @@ -9,11 +11,8 @@ from autoarray.mask import mask_2d_util from autoarray import exc -from autoarray.numpy_wrapper import use_jax, np, jit from functools import partial -if use_jax: - import jax def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: @@ -25,13 +24,9 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - if use_jax: - array = jax.lax.cond( + array = jax.lax.cond( type(array) is list, lambda _: np.asarray(array), lambda _: array, None ) - elif type(array) is list: - array = np.asarray(array) - return array @@ -42,13 +37,9 @@ def exception_message(): ) cond = len(array_2d.shape) != 1 - if use_jax: - jax.lax.cond( - cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None - ) - elif cond: - exception_message() - + jax.lax.cond( + cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None + ) def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ @@ -88,15 +79,12 @@ def exception_message_1(): array_2d.shape[0] != mask_2d.pixels_in_mask ) - if use_jax: - jax.lax.cond( - cond_1, - lambda _: jax.debug.callback(exception_message_1), - lambda _: None, - None, - ) - elif cond_1: - exception_message_1() + jax.lax.cond( + cond_1, + lambda _: jax.debug.callback(exception_message_1), + lambda _: None, + None, + ) def exception_message_2(): raise exc.ArrayException( @@ -114,16 +102,12 @@ def exception_message_2(): cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) - if use_jax: - jax.lax.cond( - cond_2, - lambda _: jax.debug.callback(exception_message_2), - lambda _: None, - None, - ) - elif cond_2: - exception_message_2() - + jax.lax.cond( + cond_2, + lambda _: jax.debug.callback(exception_message_2), + lambda _: None, + None, + ) def convert_array_2d( array_2d: Union[np.ndarray, List], @@ -159,8 +143,7 @@ def convert_array_2d( check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 - if use_jax: - mask_2d = mask_2d.array + mask_2d = np.array(mask_2d) if is_native and not skip_mask: array_2d *= np.invert(mask_2d) @@ -526,8 +509,6 @@ def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.nda return index_slim_for_index_native_2d - -@numba_util.jit() def array_2d_slim_from( array_2d_native: np.ndarray, mask_2d: np.ndarray, @@ -571,23 +552,7 @@ def array_2d_slim_from( array_2d_slim = array_2d_slim_from(mask=mask, array_2d=array_2d) """ - - if use_jax: - array_2d_slim = array_2d_native[~mask_2d.astype(bool)] - else: - total_pixels = np.sum(~mask_2d) - - array_2d_slim = np.zeros(shape=total_pixels) - index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - array_2d_slim[index] = array_2d_native[y, x] - index += 1 - - return array_2d_slim - + return array_2d_native[~mask_2d.astype(bool)] def array_2d_native_from( array_2d_slim: np.ndarray, @@ -641,8 +606,7 @@ def array_2d_native_from( ) -@partial(jit, static_argnums=(1,)) -@numba_util.jit() +@partial(jax.jit, static_argnums=(1,)) def array_2d_via_indexes_from( array_2d_slim: np.ndarray, shape: Tuple[int, int], @@ -675,22 +639,11 @@ def array_2d_via_indexes_from( ndarray The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ - if use_jax: - array_native_2d = ( - np.zeros(shape) - .at[tuple(native_index_for_slim_index_2d.T)] - .set(array_2d_slim) - ) - else: - array_native_2d = np.zeros(shape) - - for slim_index in range(len(native_index_for_slim_index_2d)): - array_native_2d[ - native_index_for_slim_index_2d[slim_index, 0], - native_index_for_slim_index_2d[slim_index, 1], - ] = array_2d_slim[slim_index] - - return array_native_2d + return ( + jnp.zeros(shape) + .at[tuple(native_index_for_slim_index_2d.T)] + .set(array_2d_slim) + ) @numba_util.jit() @@ -725,7 +678,7 @@ def array_2d_slim_complex_from( A 1D array of values mapped from the 2D array with dimensions (total_unmasked_pixels). """ - total_pixels = np.sum(~mask_2d) + total_pixels = np.sum(~mask) array_1d = 0 + 0j * np.zeros(shape=total_pixels) index = 0 From ea7aa9d7a6402c7c2dc3eff460b3ea6c335295fc Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:30:04 +0100 Subject: [PATCH 20/33] cleanup imports of grid_2d_util --- autoarray/structures/grids/grid_2d_util.py | 62 +++++++--------------- 1 file changed, 20 insertions(+), 42 deletions(-) diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 522bb0525..574c8898d 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,11 +1,7 @@ from __future__ import annotations - +import numpy as np import jax.numpy as jnp - -from autoarray.numpy_wrapper import np, use_jax - -if use_jax: - import jax +import jax from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -16,7 +12,6 @@ from autoarray.structures.arrays import array_2d_util from autoarray.geometry import geometry_util from autoarray import numba_util -from autoarray.mask import mask_2d_util from autoarray import type as ty @@ -73,15 +68,12 @@ def exception_message(): """ ) - if use_jax: - jax.lax.cond( - grid_2d.shape[0] != mask_2d.pixels_in_mask, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif grid_2d.shape[0] != mask_2d.pixels_in_mask: - exception_message() + jax.lax.cond( + grid_2d.shape[0] != mask_2d.pixels_in_mask, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None, + ) elif len(grid_2d.shape) == 3: @@ -96,15 +88,12 @@ def exception_message(): """ ) - if use_jax: - jax.lax.cond( - (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: - exception_message() + jax.lax.cond( + (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, + lambda _: jax.debug.callback(exception_message), + lambda _: None, + None, + ) def convert_grid_2d( @@ -140,12 +129,8 @@ def convert_grid_2d( is_native = len(grid_2d.shape) == 3 if is_native: - if use_jax: - grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) - grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) - else: - grid_2d[:, :, 0] *= np.invert(mask_2d) - grid_2d[:, :, 1] *= np.invert(mask_2d) + grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) + grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) if is_native == store_native: return grid_2d @@ -154,17 +139,10 @@ def convert_grid_2d( grid_2d_native=np.array(grid_2d), mask=np.array(mask_2d), ) - if use_jax: - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d.array), - mask_2d=np.array(mask_2d), - ) - else: - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d), - mask_2d=np.array(mask_2d), - ) - + return grid_2d_native_from( + grid_2d_slim=np.array(grid_2d.array), + mask_2d=np.array(mask_2d), + ) def convert_grid_2d_to_slim( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D From 4014d0378048796028a15fd99fcf2d12e2325a4d Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:38:26 +0100 Subject: [PATCH 21/33] convert methods in grid_2d_util assume ndarray --- .../operators/over_sampling/over_sampler.py | 4 +-- autoarray/structures/grids/grid_2d_util.py | 29 +++++-------------- autoarray/structures/grids/uniform_2d.py | 4 +-- 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6492c00b7..aa80392b4 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,11 +147,11 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): ) def tree_flatten(self): - return (self.mask,), () + return (self.mask, self.sub_size), () @classmethod def tree_unflatten(cls, aux_data, children): - return cls(mask=children[0]) + return cls(mask=children[0], sub_size=children[1]) @property def sub_total(self): diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 574c8898d..6c72c00ef 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -55,29 +55,20 @@ def check_grid_2d(grid_2d: np.ndarray): def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): if len(grid_2d.shape) == 2: - - def exception_message(): + if grid_2d.shape[0] != mask_2d.pixels_in_mask: raise exc.GridException( f""" The input 2D grid does not have the same number of values as pixels in the mask. - + The shape of the input grid_2d is {grid_2d.shape}. The mask shape_native is {mask_2d.shape_native}. The mask number of pixels is {mask_2d.pixels_in_mask}. """ ) - jax.lax.cond( - grid_2d.shape[0] != mask_2d.pixels_in_mask, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif len(grid_2d.shape) == 3: - - def exception_message(): + if (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: raise exc.GridException( f""" The input 2D grid is not the same dimensions as the mask @@ -88,13 +79,6 @@ def exception_message(): """ ) - jax.lax.cond( - (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - def convert_grid_2d( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D, store_native: bool = False @@ -129,8 +113,8 @@ def convert_grid_2d( is_native = len(grid_2d.shape) == 3 if is_native: - grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) - grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) + grid_2d[:, :, 0] *= np.invert(mask_2d) + grid_2d[:, :, 1] *= np.invert(mask_2d) if is_native == store_native: return grid_2d @@ -140,10 +124,11 @@ def convert_grid_2d( mask=np.array(mask_2d), ) return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d.array), + grid_2d_slim=np.array(grid_2d), mask_2d=np.array(mask_2d), ) + def convert_grid_2d_to_slim( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D ) -> np.ndarray: diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 846d15cc8..f0d20042c 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1,4 +1,5 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np from pathlib import Path from typing import List, Optional, Tuple, Union @@ -14,7 +15,6 @@ from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.structures.arrays import array_2d_util from autoarray.structures.grids import grid_2d_util from autoarray.geometry import geometry_util from autoarray.operators.over_sampling import over_sample_util @@ -162,7 +162,7 @@ def __init__( over sampled grid is not passed in it is computed assuming uniformity. """ values = grid_2d_util.convert_grid_2d( - grid_2d=values, + grid_2d=np.array(values), mask_2d=mask, store_native=store_native, ) From 075654fc7fe33d3dcc0b3db0acc595239bb35dec Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:42:57 +0100 Subject: [PATCH 22/33] more simlpifying of convert functions --- autoarray/structures/arrays/array_2d_util.py | 83 +++++++------------- autoarray/structures/arrays/uniform_2d.py | 7 +- 2 files changed, 30 insertions(+), 60 deletions(-) diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index 796c5d965..f269d0106 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -15,6 +15,7 @@ + def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: """ If the input array input a convert is of type list, convert it to type NumPy array. @@ -24,23 +25,17 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - array = jax.lax.cond( - type(array) is list, lambda _: np.asarray(array), lambda _: array, None - ) + array = np.asarray(array) + return array def check_array_2d(array_2d: np.ndarray): - def exception_message(): + if len(array_2d.shape) != 1: raise exc.ArrayException( "An array input into the Array2D.__new__ method is not of shape 1." ) - cond = len(array_2d.shape) != 1 - jax.lax.cond( - cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None - ) - def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an @@ -57,57 +52,38 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): mask_2d The mask of the output Array2D. """ + if len(array_2d.shape) == 1: + if array_2d.shape[0] != mask_2d.pixels_in_mask: + raise exc.ArrayException( + f""" + The input array is a slim 1D array, but it does not have the same number of entries as pixels in + the mask. - def exception_message_1(): - raise exc.ArrayException( - f""" - The input array is a slim 1D array, but it does not have the same number of entries as pixels in - the mask. - - This indicates that the number of unmaksed pixels in the mask is different to the input slim array - shape. - - The shapes of the two arrays (which this exception is raised because they are different) are as follows: - - Input array_2d_slim.shape = {array_2d.shape[0]} - Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} - Input mask_2d.shape_native = {mask_2d.shape_native} - """ - ) - - cond_1 = (len(array_2d.shape) == 1) and ( - array_2d.shape[0] != mask_2d.pixels_in_mask - ) - - jax.lax.cond( - cond_1, - lambda _: jax.debug.callback(exception_message_1), - lambda _: None, - None, - ) + This indicates that the number of unmaksed pixels in the mask is different to the input slim array + shape. - def exception_message_2(): - raise exc.ArrayException( - f""" - The input array is 2D but not the same dimensions as the mask. + The shapes of the two arrays (which this exception is raised because they are different) are as follows: - This indicates the mask's shape is different to the input array shape. + Input array_2d_slim.shape = {array_2d.shape[0]} + Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} + Input mask_2d.shape_native = {mask_2d.shape_native} + """ + ) - The shapes of the two arrays (which this exception is raised because they are different) are as follows: + if len(array_2d.shape) == 2: + if array_2d.shape != mask_2d.shape_native: + raise exc.ArrayException( + f""" + The input array is 2D but not the same dimensions as the mask. - Input array_2d shape = {array_2d.shape} - Input mask_2d shape_native = {mask_2d.shape_native} - """ - ) + This indicates the mask's shape is different to the input array shape. - cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) + The shapes of the two arrays (which this exception is raised because they are different) are as follows: - jax.lax.cond( - cond_2, - lambda _: jax.debug.callback(exception_message_2), - lambda _: None, - None, - ) + Input array_2d shape = {array_2d.shape} + Input mask_2d shape_native = {mask_2d.shape_native} + """ + ) def convert_array_2d( array_2d: Union[np.ndarray, List], @@ -143,7 +119,6 @@ def convert_array_2d( check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 - mask_2d = np.array(mask_2d) if is_native and not skip_mask: array_2d *= np.invert(mask_2d) diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 4406db105..9c01a8a1e 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -229,16 +229,11 @@ def __init__( print(array_2d.native) # masked 2D data representation. """ - try: - values = values._array - except AttributeError: - pass - if conf.instance["general"]["structures"]["native_binned_only"]: store_native = True values = array_2d_util.convert_array_2d( - array_2d=values, + array_2d=np.array(values), mask_2d=mask, store_native=store_native, skip_mask=skip_mask, From 17817b81c843ea8949bd1e49fafa054821c9eef8 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:52:55 +0100 Subject: [PATCH 23/33] mask derive fixed --- autoarray/mask/derive/indexes_2d.py | 4 ++-- autoarray/structures/arrays/array_2d_util.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index d8cf61a6e..4807799d9 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -198,7 +198,7 @@ def edge_slim(self) -> np.ndarray: print(derive_indexes_2d.edge_slim) """ - return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask)).astype( + return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask).astype("bool")).astype( "int" ) @@ -301,7 +301,7 @@ def border_slim(self) -> np.ndarray: print(derive_indexes_2d.border_slim) """ return mask_2d_util.border_slim_indexes_from( - mask_2d=np.array(self.mask) + mask_2d=np.array(self.mask).astype("bool") ).astype("int") @property diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index f269d0106..7048c07a0 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -14,8 +14,6 @@ from functools import partial - - def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: """ If the input array input a convert is of type list, convert it to type NumPy array. From b76cc9aac5e78c2d88b218e4a9b49b0c499981d7 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 19:59:50 +0100 Subject: [PATCH 24/33] another way to make hecks only use ndarray --- autoarray/structures/arrays/uniform_2d.py | 7 ++++++- .../arrays/files/array/output_test/array.fits | Bin 5760 -> 5760 bytes 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 9c01a8a1e..12ed86b0d 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -232,8 +232,13 @@ def __init__( if conf.instance["general"]["structures"]["native_binned_only"]: store_native = True + try: + values = values._array + except AttributeError: + values = values + values = array_2d_util.convert_array_2d( - array_2d=np.array(values), + array_2d=values, mask_2d=mask, store_native=store_native, skip_mask=skip_mask, diff --git a/test_autoarray/structures/arrays/files/array/output_test/array.fits b/test_autoarray/structures/arrays/files/array/output_test/array.fits index 0cff0a8f7db90d3ded28c644cd66e3791627feeb..dab9da2fb676efc023e53b73faa54fe78fc4000c 100644 GIT binary patch delta 56 kcmZqBZP1;N!(?o$Vgmz%Jzl(dBKLc)$qTsk0j$6e2mk;8 delta 84 hcmZqBZP1;N!(?W%G4C>$;|B&XuqT_|T*&>83jh#~5uX46 From c9e275ddd8ad0887bab644366d47e9343338ed95 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:23:29 +0100 Subject: [PATCH 25/33] fixes which ensure grad works on real LH function --- autoarray/fit/fit_util.py | 6 +++--- .../operators/over_sampling/over_sample_util.py | 2 +- autoarray/operators/over_sampling/over_sampler.py | 12 +++++++----- autoarray/structures/arrays/array_2d_util.py | 11 +++++++++-- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index e5d34675b..d40f55d1c 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -1,6 +1,6 @@ from functools import wraps +import jax.numpy as np -from autoarray.numpy_wrapper import np from autoarray.mask.abstract_mask import Mask from autoarray import type as ty @@ -83,7 +83,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return np.sum(chi_squared_map) + return np.sum(chi_squared_map._array) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: @@ -97,7 +97,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return np.sum(np.log(2 * np.pi * noise_map**2.0)) + return np.sum(np.log(2 * np.pi * noise_map._array**2.0)) def normalized_residual_map_complex_from( diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index 70dc220c4..8966e785e 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -422,7 +422,7 @@ def grid_2d_slim_over_sampled_via_mask_from( grid_slim = np.zeros(shape=(total_sub_pixels, 2)) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( + centres_scaled = geometry_util.central_scaled_coordinate_2d_numba_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index aa80392b4..6440a2ee7 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -220,11 +220,13 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": except AttributeError: pass - binned_array_2d = over_sample_util.binned_array_2d_from( - array_2d=np.array(array), - mask_2d=np.array(self.mask), - sub_size=np.array(self.sub_size).astype("int"), - ) + # binned_array_2d = over_sample_util.binned_array_2d_from( + # array_2d=np.array(array), + # mask_2d=np.array(self.mask), + # sub_size=np.array(self.sub_size).astype("int"), + # ) + + binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]).mean(axis=1) return Array2D( values=binned_array_2d, diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index 7048c07a0..0e694846d 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -23,8 +23,15 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - array = np.asarray(array) - + if isinstance(array, np.ndarray) or isinstance(array, list): + array = np.asarray(array) + elif isinstance(array, jnp.ndarray): + array = jax.lax.cond( + type(array) is list, + lambda _: jnp.asarray(array), + lambda _: array, + None + ) return array From 70c021242f51da937340ac221a1618474ee492d2 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:32:26 +0100 Subject: [PATCH 26/33] fix all uniform_2d unit tests --- autoarray/geometry/geometry_2d.py | 3 ++- autoarray/structures/arrays/uniform_2d.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index 29c604405..e78f0f75a 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -184,8 +184,9 @@ def scaled_coordinates_2d_from( ------- A 2D (y,x) pixel-value coordinate. """ + return geometry_util.scaled_coordinates_2d_from( - pixel_coordinates_2d=pixel_coordinates_2d, + pixel_coordinates_2d=np.array(pixel_coordinates_2d), shape_native=self.shape_native, pixel_scales=self.pixel_scales, origins=self.origin, diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 12ed86b0d..11c478ad5 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -464,7 +464,7 @@ def zoomed_around_mask(self, buffer: int = 1) -> "Array2D": """ extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native), + array_2d=np.array(self.native._array), y0=self.mask.zoom_region[0] - buffer, y1=self.mask.zoom_region[1] + buffer, x0=self.mask.zoom_region[2] - buffer, @@ -498,7 +498,7 @@ def extent_of_zoomed_array(self, buffer: int = 1) -> np.ndarray: The number pixels around the extracted array used as a buffer. """ extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native), + array_2d=np.array(self.native._array), y0=self.mask.zoom_region[0] - buffer, y1=self.mask.zoom_region[1] + buffer, x0=self.mask.zoom_region[2] - buffer, @@ -532,7 +532,7 @@ def resized_from( """ resized_array_2d = array_2d_util.resized_array_2d_from( - array_2d=np.array(self.native), resized_shape=new_shape + array_2d=np.array(self.native._array), resized_shape=new_shape ) resized_mask = self.mask.resized_from( @@ -599,7 +599,7 @@ def trimmed_after_convolution_from( resized_mask = self.mask.resized_from(new_shape=trimmed_array_2d.shape) array = array_2d_util.convert_array_2d( - array_2d=trimmed_array_2d, mask_2d=resized_mask + array_2d=trimmed_array_2d._array, mask_2d=resized_mask ) return Array2D( From c417511eb43fbda0925c968128a17d0b2fa235b2 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:36:29 +0100 Subject: [PATCH 27/33] fix all of kernel 2d --- autoarray/structures/arrays/kernel_2d.py | 6 +++--- test_autoarray/structures/arrays/test_kernel_2d.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index c03ceaf32..f80e3f6e3 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -383,7 +383,7 @@ def rescaled_with_odd_dimensions_from( try: kernel_rescaled = rescale( - self.native, + np.array(self.native._array), rescale_factor, anti_aliasing=False, mode="constant", @@ -391,7 +391,7 @@ def rescaled_with_odd_dimensions_from( ) except TypeError: kernel_rescaled = rescale( - self.native, + np.array(self.native._array), rescale_factor, anti_aliasing=False, mode="constant", @@ -472,7 +472,7 @@ def convolved_array_from(self, array: Array2D) -> Array2D: array_2d = array.native - convolved_array_2d = scipy.signal.convolve2d(array_2d, self.native, mode="same") + convolved_array_2d = scipy.signal.convolve2d(array_2d._array, np.array(self.native._array), mode="same") convolved_array_1d = array_2d_util.array_2d_slim_from( mask_2d=np.array(array_2d.mask), diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 292d628f7..9106b43a1 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -457,4 +457,4 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy normalize=True, ) - assert kernel_astropy == pytest.approx(kernel_2d.native, 1e-4) + assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) From db9cfb7aefe1bcd83afc24b31d249e5ef98df844 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:37:37 +0100 Subject: [PATCH 28/33] fix repr --- test_autoarray/structures/arrays/test_repr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_autoarray/structures/arrays/test_repr.py b/test_autoarray/structures/arrays/test_repr.py index cf0bfa964..8ea8c6b9d 100644 --- a/test_autoarray/structures/arrays/test_repr.py +++ b/test_autoarray/structures/arrays/test_repr.py @@ -3,4 +3,4 @@ def test_repr(): array = aa.Array2D.no_mask([[1, 2], [3, 4]], pixel_scales=1) - assert repr(array) == "Array2D([1., 2., 3., 4.])" + assert repr(array) == "Array2D([1, 2, 3, 4])" From 3cb3f7622d1030349825cba5d7dd8665f9e527f3 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:41:11 +0100 Subject: [PATCH 29/33] remove relocate_to_radial_minimum test as all functionality is to be removed --- .../structures/grids/test_uniform_2d.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index 4084908ba..d11457416 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -780,37 +780,6 @@ def test__grid_with_coordinates_within_distance_removed_from(): ).all() -def test__grid_radial_minimum(): - grid_2d = np.array([[2.5, 0.0], [4.0, 0.0], [6.0, 0.0]]) - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - assert (deflections == grid_2d).all() - - grid_2d = np.array([[2.0, 0.0], [1.0, 0.0], [6.0, 0.0]]) - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - - assert (deflections == np.array([[2.5, 0.0], [2.5, 0.0], [6.0, 0.0]])).all() - - grid_2d = np.array( - [ - [np.sqrt(2.0), np.sqrt(2.0)], - [1.0, np.sqrt(8.0)], - [np.sqrt(8.0), np.sqrt(8.0)], - ] - ) - - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - - assert deflections == pytest.approx( - np.array([[1.7677, 1.7677], [1.0, np.sqrt(8.0)], [np.sqrt(8), np.sqrt(8.0)]]), - 1.0e-4, - ) - def test__recursive_shape_storage(): grid_2d = aa.Grid2D.no_mask( From 467d1ea614eecafdb70d5741d1390ff32a9a124e Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:48:36 +0100 Subject: [PATCH 30/33] fix Grid2D test_unifrom --- autoarray/structures/grids/uniform_2d.py | 36 ++++++++++--------- .../structures/grids/test_uniform_2d.py | 6 ++-- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index f0d20042c..4d565dd27 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -4,8 +4,6 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -from autoarray.numpy_wrapper import use_jax - from autoconf import conf from autoconf import cached_property from autoconf.fitsable import ndarray_via_fits_from @@ -820,7 +818,7 @@ def grid_with_coordinates_within_distance_removed_from( distance_mask += distances.native < distance mask = Mask2D( - mask=distance_mask, + mask=np.array(distance_mask), pixel_scales=self.pixel_scales, origin=self.origin, ) @@ -841,8 +839,8 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ - if use_jax: - squared_distances = np.square(self.array[:, 0] - coordinate[0]) + np.square( + if isinstance(self, jnp.ndarray): + squared_distances = jnp.square(self.array[:, 0] - coordinate[0]) + jnp.square( self.array[:, 1] - coordinate[1] ) else: @@ -1015,10 +1013,16 @@ def shape_native_scaled_interior(self) -> Tuple[float, float]: of the grid's (y,x) values, whereas the `shape_native_scaled` uses the uniform geometry of the grid and its ``pixel_scales``, which means it has a buffer at each edge of half a ``pixel_scale``. """ - return ( - np.amax(self[:, 0]) - np.amin(self[:, 0]), - np.amax(self[:, 1]) - np.amin(self[:, 1]), - ) + if isinstance(self, jnp.ndarray): + return ( + np.amax(self.array[:, 0]) - np.amin(self.array[:, 0]), + np.amax(self.array[:, 1]) - np.amin(self.array[:, 1]), + ) + else: + return ( + np.amax(self[:, 0]) - np.amin(self[:, 0]), + np.amax(self[:, 1]) - np.amin(self[:, 1]), + ) @property def scaled_minima(self) -> Tuple: @@ -1026,10 +1030,10 @@ def scaled_minima(self) -> Tuple: The (y,x) minimum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - if use_jax: + if isinstance(self, jnp.ndarray): return ( - np.amin(self.array[:, 0]).astype("float"), - np.amin(self.array[:, 1]).astype("float"), + jnp.amin(self.array[:, 0]).astype("float"), + jnp.amin(self.array[:, 1]).astype("float"), ) else: return ( @@ -1043,10 +1047,10 @@ def scaled_maxima(self) -> Tuple: The (y,x) maximum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - if use_jax: + if isinstance(self, jnp.ndarray): return ( - np.amax(self.array[:, 0]).astype("float"), - np.amax(self.array[:, 1]).astype("float"), + jnp.amax(self.array[:, 0]).astype("float"), + jnp.amax(self.array[:, 1]).astype("float"), ) else: return ( @@ -1103,7 +1107,7 @@ def padded_grid_from(self, kernel_shape_native: Tuple[int, int]) -> "Grid2D": ) over_sample_size = np.pad( - self.over_sample_size.native, + self.over_sample_size.native._array, pad_width, mode="constant", constant_values=1, diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index d11457416..686721c70 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -481,9 +481,10 @@ def test__to_and_from_fits_methods(): def test__shape_native_scaled(): - mask = aa.Mask2D.circular(shape_native=(3, 3), radius=1.0, pixel_scales=(1.0, 1.0)) + mask = aa.Mask2D.circular(shape_native=(3, 3), radius=1.1, pixel_scales=(1.0, 1.0)) grid_2d = aa.Grid2D.from_mask(mask=mask) + assert grid_2d.shape_native_scaled_interior == (2.0, 2.0) mask = aa.Mask2D.elliptical( @@ -574,7 +575,8 @@ def test__grid_2d_radial_projected_shape_slim_from(): pixel_scales=grid_2d.pixel_scales, ) - assert (grid_radii == grid_radii_util).all() + + assert grid_radii == pytest.approx(grid_radii_util, 1.0e-4) assert grid_radial_shape_slim == grid_radii_util.shape[0] grid_radii = grid_2d.grid_2d_radial_projected_from(centre=(0.3, 0.1), angle=60.0) From 7751080c4d627c8cad7626868f1fa372a9110f22 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:52:35 +0100 Subject: [PATCH 31/33] fix grid test_uniform_1d --- autoarray/structures/grids/uniform_1d.py | 6 +----- test_autoarray/structures/grids/test_uniform_1d.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/autoarray/structures/grids/uniform_1d.py b/autoarray/structures/grids/uniform_1d.py index d5870aeca..53a9ec756 100644 --- a/autoarray/structures/grids/uniform_1d.py +++ b/autoarray/structures/grids/uniform_1d.py @@ -1,10 +1,6 @@ from __future__ import annotations import numpy as np -from typing import TYPE_CHECKING, List, Union, Tuple - -if TYPE_CHECKING: - from autoarray.structures.arrays.uniform_1d import Array1D - from autoarray.structures.grids.uniform_2d import Grid2D +from typing import List, Union, Tuple from autoarray.structures.abstract_structure import Structure from autoarray.structures.grids.irregular_2d import Grid2DIrregular diff --git a/test_autoarray/structures/grids/test_uniform_1d.py b/test_autoarray/structures/grids/test_uniform_1d.py index 147bd4b17..b99da3608 100644 --- a/test_autoarray/structures/grids/test_uniform_1d.py +++ b/test_autoarray/structures/grids/test_uniform_1d.py @@ -129,7 +129,7 @@ def test__grid_2d_radial_projected_from(): grid_2d = grid_1d.grid_2d_radial_projected_from(angle=90.0) assert grid_2d.slim == pytest.approx( - np.array([[-1.0, 0.0], [-2.0, 0.0], [-3.0, 0.0], [-4.0, 0.0]]), 1.0e-4 + np.array([[-1.0, 0.0], [-2.0, 0.0], [-3.0, 0.0], [-4.0, 0.0]]), abs=1.0e-4 ) grid_2d = grid_1d.grid_2d_radial_projected_from(angle=45.0) From f4c3269576995a8c2b042bb750a7a134ef062879 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 20:55:19 +0100 Subject: [PATCH 32/33] hammer hammer hammer --- autoarray/structures/grids/irregular_2d.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index 68ad0dce4..57d4c4a6f 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -1,15 +1,12 @@ import logging -from typing import List, Optional, Tuple, Union +import numpy as np +from typing import List, Tuple, Union -from autoarray.numpy_wrapper import np from autoarray.abstract_ndarray import AbstractNDArray from autoarray.geometry.geometry_2d_irregular import Geometry2DIrregular from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.irregular import ArrayIrregular -from autoarray.structures.grids import grid_2d_util -from autoarray.geometry import geometry_util - logger = logging.getLogger(__name__) From 8d2b338c18663a7908ff9d49843eb8aa7a8f087c Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Wed, 2 Apr 2025 15:57:22 +0100 Subject: [PATCH 33/33] fix over sampler test --- autoarray/operators/over_sampling/over_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6440a2ee7..0aafbe008 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -226,7 +226,7 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": # sub_size=np.array(self.sub_size).astype("int"), # ) - binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]).mean(axis=1) + binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]**2).mean(axis=1) return Array2D( values=binned_array_2d,