diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index bc529ef96..20bb8b0be 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -162,7 +162,7 @@ def __init__( if psf is not None and use_normalized_psf: psf = Kernel2D.no_mask( - values=psf.native, pixel_scales=psf.pixel_scales, normalize=True + values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True ) self.psf = psf @@ -193,7 +193,7 @@ def convolver(self): The convolver given the masked imaging data's mask and PSF. """ - return Convolver(mask=self.mask, kernel=self.psf) + return Convolver(mask=self.mask, kernel=Kernel2D(values=self.psf._array, mask=self.psf.mask, header=self.psf.header)) @cached_property def w_tilde(self): diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index e5d34675b..d40f55d1c 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -1,6 +1,6 @@ from functools import wraps +import jax.numpy as np -from autoarray.numpy_wrapper import np from autoarray.mask.abstract_mask import Mask from autoarray import type as ty @@ -83,7 +83,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return np.sum(chi_squared_map) + return np.sum(chi_squared_map._array) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: @@ -97,7 +97,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return np.sum(np.log(2 * np.pi * noise_map**2.0)) + return np.sum(np.log(2 * np.pi * noise_map._array**2.0)) def normalized_residual_map_complex_from( diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index 29c604405..e78f0f75a 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -184,8 +184,9 @@ def scaled_coordinates_2d_from( ------- A 2D (y,x) pixel-value coordinate. """ + return geometry_util.scaled_coordinates_2d_from( - pixel_coordinates_2d=pixel_coordinates_2d, + pixel_coordinates_2d=np.array(pixel_coordinates_2d), shape_native=self.shape_native, pixel_scales=self.pixel_scales, origins=self.origin, diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index a795d42ee..b646c7d08 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -1,5 +1,7 @@ +import jax.numpy as jnp +import numpy as np from typing import Tuple, Union -from autoarray.numpy_wrapper import np, use_jax + from autoarray import numba_util from autoarray import type as ty @@ -179,8 +181,69 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales +@numba_util.jit() +def central_pixel_coordinates_2d_numba_from( + shape_native: Tuple[int, int], +) -> Tuple[float, float]: + """ + Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) + from the shape of that data structure. + + Examples of the central pixels are as follows: + + - For a 3x3 image, the central pixel is pixel [1, 1]. + - For a 4x4 image, the central pixel is [1.5, 1.5]. + + Parameters + ---------- + shape_native + The dimensions of the data structure, which can be in 1D, 2D or higher dimensions. + + Returns + ------- + The central pixel coordinates of the data structure. + """ + return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) @numba_util.jit() +def central_scaled_coordinate_2d_numba_from( + shape_native: Tuple[int, int], + pixel_scales: ty.PixelScales, + origin: Tuple[float, float] = (0.0, 0.0), +) -> Tuple[float, float]: + """ + Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) + from the shape of that data structure. + + This is computed by using the data structure's shape and converting it to scaled units using an input + pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`. + + The origin of the scaled grid can also be input and moved from (0.0, 0.0). + + Parameters + ---------- + shape_native + The 2D shape of the data structure whose central scaled coordinates are computed. + pixel_scales + The (y,x) scaled units to pixel units conversion factor of the 2D data structure. + origin + The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on. + + Returns + ------- + The central coordinates of the 2D data structure in scaled units. + """ + + central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( + shape_native=shape_native + ) + + y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0]) + x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1]) + + return (y_pixel, x_pixel) + + def central_pixel_coordinates_2d_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: @@ -205,7 +268,6 @@ def central_pixel_coordinates_2d_from( return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) -@numba_util.jit() def central_scaled_coordinate_2d_from( shape_native: Tuple[int, int], pixel_scales: ty.PixelScales, @@ -234,7 +296,7 @@ def central_scaled_coordinate_2d_from( The central coordinates of the 2D data structure in scaled units. """ - central_pixel_coordinates = central_pixel_coordinates_2d_from( + central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( shape_native=shape_native ) @@ -243,8 +305,6 @@ def central_scaled_coordinate_2d_from( return (y_pixel, x_pixel) - -@numba_util.jit() def pixel_coordinates_2d_from( scaled_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], @@ -352,7 +412,7 @@ def scaled_coordinates_2d_from( origin=(0.0, 0.0) ) """ - central_scaled_coordinates = central_scaled_coordinate_2d_from( + central_scaled_coordinates = central_scaled_coordinate_2d_numba_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origins ) @@ -382,18 +442,16 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - if use_jax: - shifted_grid_2d = grid_2d.array - np.array(centre) - else: - shifted_grid_2d = grid_2d - np.array(centre) - radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1)) - theta_coordinate_to_profile = np.arctan2( + shifted_grid_2d = np.array(grid_2d) - jnp.array(centre) + + radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) + theta_coordinate_to_profile = jnp.arctan2( shifted_grid_2d[:, 0], shifted_grid_2d[:, 1] - ) - np.radians(angle) - return np.vstack( + ) - jnp.radians(angle) + return jnp.vstack( [ - radius * np.sin(theta_coordinate_to_profile), - radius * np.cos(theta_coordinate_to_profile), + radius * jnp.sin(theta_coordinate_to_profile), + radius * jnp.cos(theta_coordinate_to_profile), ] ).T @@ -435,7 +493,6 @@ def transform_grid_2d_from_reference_frame( return np.vstack((y, x)).T -@numba_util.jit() def grid_pixels_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -476,33 +533,15 @@ def grid_pixels_2d_slim_from( grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2), pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1, 1]) - return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - else: - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = ( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = ( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d_slim + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 -@numba_util.jit() def grid_pixel_centres_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -547,32 +586,14 @@ def grid_pixel_centres_2d_slim_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_pixels_2d_slim = ( - (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 - ).astype(int) - else: - grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_pixels_2d_slim[slim_index, 0] = int( - (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d_slim[slim_index, 1] = int( - (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d_slim + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + return ( + (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 + ).astype(int) -@numba_util.jit() def grid_pixel_indexes_2d_slim_from( grid_scaled_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -625,25 +646,13 @@ def grid_pixel_indexes_2d_slim_from( origin=origin, ) - if use_jax: - grid_pixel_indexes_2d_slim = ( - (grid_pixels_2d_slim * np.array([shape_native[1], 1])) - .sum(axis=1) - .astype(int) - ) - else: - grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) - - for slim_index in range(grid_pixels_2d_slim.shape[0]): - grid_pixel_indexes_2d_slim[slim_index] = int( - grid_pixels_2d_slim[slim_index, 0] * shape_native[1] - + grid_pixels_2d_slim[slim_index, 1] - ) - - return grid_pixel_indexes_2d_slim + return ( + (grid_pixels_2d_slim * np.array([shape_native[1], 1])) + .sum(axis=1) + .astype(int) + ) -@numba_util.jit() def grid_scaled_2d_slim_from( grid_pixels_2d_slim: np.ndarray, shape_native: Tuple[int, int], @@ -682,33 +691,18 @@ def grid_scaled_2d_slim_from( grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2), pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1, 1]) - grid_scaled_2d_slim = ( - (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign - ) - else: - grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) - - for slim_index in range(grid_scaled_2d_slim.shape[0]): - grid_scaled_2d_slim[slim_index, 0] = ( - -(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5) - * pixel_scales[0] - ) - grid_scaled_2d_slim[slim_index, 1] = ( - grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5 - ) * pixel_scales[1] - - return grid_scaled_2d_slim + + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1, 1]) + return ( + (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign + ) -@numba_util.jit() def grid_pixel_centres_2d_from( grid_scaled_2d: np.ndarray, shape_native: Tuple[int, int], @@ -753,30 +747,12 @@ def grid_pixel_centres_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_pixels_2d = ( - (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 - ).astype(int) - else: - grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) - - for y in range(grid_scaled_2d.shape[0]): - for x in range(grid_scaled_2d.shape[1]): - grid_pixels_2d[y, x, 0] = int( - (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) - + centres_scaled[0] - + 0.5 - ) - grid_pixels_2d[y, x, 1] = int( - (grid_scaled_2d[y, x, 1] / pixel_scales[1]) - + centres_scaled[1] - + 0.5 - ) - - return grid_pixels_2d + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) + return ( + (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 + ).astype(int) def extent_symmetric_from( diff --git a/autoarray/mask/abstract_mask.py b/autoarray/mask/abstract_mask.py index 8bccc586e..d6ed67c2c 100644 --- a/autoarray/mask/abstract_mask.py +++ b/autoarray/mask/abstract_mask.py @@ -2,12 +2,6 @@ from abc import ABC import logging - -from autoarray.numpy_wrapper import np, use_jax - -if use_jax: - import jax - import numpy as np from typing import Dict diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index d8cf61a6e..4807799d9 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -198,7 +198,7 @@ def edge_slim(self) -> np.ndarray: print(derive_indexes_2d.edge_slim) """ - return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask)).astype( + return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask).astype("bool")).astype( "int" ) @@ -301,7 +301,7 @@ def border_slim(self) -> np.ndarray: print(derive_indexes_2d.border_slim) """ return mask_2d_util.border_slim_indexes_from( - mask_2d=np.array(self.mask) + mask_2d=np.array(self.mask).astype("bool") ).astype("int") @property diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 57f530f9d..05bd77852 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -802,7 +802,7 @@ def resized_from(self, new_shape, pad_value: int = 0.0) -> Mask2D: """ resized_mask = array_2d_util.resized_array_2d_from( - array_2d=np.array(self), + array_2d=np.array(self._array), resized_shape=new_shape, pad_value=pad_value, ).astype("bool") diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index 5a52329b0..73d3767d5 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -221,7 +221,7 @@ def __init__(self, mask, kernel): coordinates=(x, y), mask=np.array(mask), mask_index_array=self.mask_index_array, - kernel_2d=np.array(self.kernel.native[:, :]), + kernel_2d=self.kernel.native, ) self.image_frame_1d_indexes[mask_1d_index, :] = ( image_frame_1d_indexes diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index a98276896..8966e785e 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -1,8 +1,6 @@ from __future__ import annotations import numpy as np from typing import TYPE_CHECKING, Union, List, Tuple -from autoarray.numpy_wrapper import np, register_pytree_node_class, use_jax, jit - from typing import List, Tuple from autoarray.structures.arrays.uniform_2d import Array2D @@ -364,11 +362,11 @@ def sub_size_radial_bins_from( for i in range(radial_grid.shape[0]): for j in range(len(radial_list)): if radial_grid[i] < radial_list[j]: - if use_jax: - # while this makes it run, it is very, very slow - sub_size = sub_size.at[i].set(sub_size_list[j]) - else: - sub_size[i] = sub_size_list[j] + # if use_jax: + # # while this makes it run, it is very, very slow + # sub_size = sub_size.at[i].set(sub_size_list[j]) + # else: + sub_size[i] = sub_size_list[j] break return sub_size @@ -424,7 +422,7 @@ def grid_2d_slim_over_sampled_via_mask_from( grid_slim = np.zeros(shape=(total_sub_pixels, 2)) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( + centres_scaled = geometry_util.central_scaled_coordinate_2d_numba_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) @@ -447,35 +445,35 @@ def grid_2d_slim_over_sampled_via_mask_from( for y1 in range(sub): for x1 in range(sub): - if use_jax: - # while this makes it run, it is very, very slow - grid_slim = grid_slim.at[sub_index, 0].set( - -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) - ) - ) - grid_slim = grid_slim.at[sub_index, 1].set( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) - ) - else: - grid_slim[sub_index, 0] = -( - y_scaled - - y_sub_half - + y1 * y_sub_step - + (y_sub_step / 2.0) - ) - grid_slim[sub_index, 1] = ( - x_scaled - - x_sub_half - + x1 * x_sub_step - + (x_sub_step / 2.0) - ) + # if use_jax: + # # while this makes it run, it is very, very slow + # grid_slim = grid_slim.at[sub_index, 0].set( + # -( + # y_scaled + # - y_sub_half + # + y1 * y_sub_step + # + (y_sub_step / 2.0) + # ) + # ) + # grid_slim = grid_slim.at[sub_index, 1].set( + # x_scaled + # - x_sub_half + # + x1 * x_sub_step + # + (x_sub_step / 2.0) + # ) + # else: + grid_slim[sub_index, 0] = -( + y_scaled + - y_sub_half + + y1 * y_sub_step + + (y_sub_step / 2.0) + ) + grid_slim[sub_index, 1] = ( + x_scaled + - x_sub_half + + x1 * x_sub_step + + (x_sub_step / 2.0) + ) sub_index += 1 index += 1 @@ -544,14 +542,14 @@ def binned_array_2d_from( for y1 in range(sub): for x1 in range(sub): - if use_jax: - binned_array_2d_slim = binned_array_2d_slim.at[index].add( - array_2d[sub_index] * sub_fraction[index] - ) - else: - binned_array_2d_slim[index] += ( - array_2d[sub_index] * sub_fraction[index] - ) + # if use_jax: + # binned_array_2d_slim = binned_array_2d_slim.at[index].add( + # array_2d[sub_index] * sub_fraction[index] + # ) + # else: + binned_array_2d_slim[index] += ( + array_2d[sub_index] * sub_fraction[index] + ) sub_index += 1 index += 1 diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 6492c00b7..0aafbe008 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,11 +147,11 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): ) def tree_flatten(self): - return (self.mask,), () + return (self.mask, self.sub_size), () @classmethod def tree_unflatten(cls, aux_data, children): - return cls(mask=children[0]) + return cls(mask=children[0], sub_size=children[1]) @property def sub_total(self): @@ -220,11 +220,13 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": except AttributeError: pass - binned_array_2d = over_sample_util.binned_array_2d_from( - array_2d=np.array(array), - mask_2d=np.array(self.mask), - sub_size=np.array(self.sub_size).astype("int"), - ) + # binned_array_2d = over_sample_util.binned_array_2d_from( + # array_2d=np.array(array), + # mask_2d=np.array(self.mask), + # sub_size=np.array(self.sub_size).astype("int"), + # ) + + binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]**2).mean(axis=1) return Array2D( values=binned_array_2d, diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index c75cc9750..0e694846d 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -1,5 +1,7 @@ from __future__ import annotations - +import jax +import jax.numpy as jnp +import numpy as np from typing import TYPE_CHECKING, List, Tuple, Union if TYPE_CHECKING: @@ -9,12 +11,8 @@ from autoarray.mask import mask_2d_util from autoarray import exc -from autoarray.numpy_wrapper import use_jax, np, jit from functools import partial -if use_jax: - import jax - def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: """ @@ -25,31 +23,24 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - if use_jax: + if isinstance(array, np.ndarray) or isinstance(array, list): + array = np.asarray(array) + elif isinstance(array, jnp.ndarray): array = jax.lax.cond( - type(array) is list, lambda _: np.asarray(array), lambda _: array, None + type(array) is list, + lambda _: jnp.asarray(array), + lambda _: array, + None ) - elif type(array) is list: - array = np.asarray(array) - return array def check_array_2d(array_2d: np.ndarray): - def exception_message(): + if len(array_2d.shape) != 1: raise exc.ArrayException( "An array input into the Array2D.__new__ method is not of shape 1." ) - cond = len(array_2d.shape) != 1 - if use_jax: - jax.lax.cond( - cond, lambda _: jax.debug.callback(exception_message), lambda _: None, None - ) - elif cond: - exception_message() - - def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): """ The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an @@ -66,64 +57,38 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D): mask_2d The mask of the output Array2D. """ + if len(array_2d.shape) == 1: + if array_2d.shape[0] != mask_2d.pixels_in_mask: + raise exc.ArrayException( + f""" + The input array is a slim 1D array, but it does not have the same number of entries as pixels in + the mask. - def exception_message_1(): - raise exc.ArrayException( - f""" - The input array is a slim 1D array, but it does not have the same number of entries as pixels in - the mask. - - This indicates that the number of unmaksed pixels in the mask is different to the input slim array - shape. - - The shapes of the two arrays (which this exception is raised because they are different) are as follows: - - Input array_2d_slim.shape = {array_2d.shape[0]} - Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} - Input mask_2d.shape_native = {mask_2d.shape_native} - """ - ) - - cond_1 = (len(array_2d.shape) == 1) and ( - array_2d.shape[0] != mask_2d.pixels_in_mask - ) - - if use_jax: - jax.lax.cond( - cond_1, - lambda _: jax.debug.callback(exception_message_1), - lambda _: None, - None, - ) - elif cond_1: - exception_message_1() - - def exception_message_2(): - raise exc.ArrayException( - f""" - The input array is 2D but not the same dimensions as the mask. + This indicates that the number of unmaksed pixels in the mask is different to the input slim array + shape. - This indicates the mask's shape is different to the input array shape. + The shapes of the two arrays (which this exception is raised because they are different) are as follows: - The shapes of the two arrays (which this exception is raised because they are different) are as follows: + Input array_2d_slim.shape = {array_2d.shape[0]} + Input mask_2d.pixels_in_mask = {mask_2d.pixels_in_mask} + Input mask_2d.shape_native = {mask_2d.shape_native} + """ + ) - Input array_2d shape = {array_2d.shape} - Input mask_2d shape_native = {mask_2d.shape_native} - """ - ) + if len(array_2d.shape) == 2: + if array_2d.shape != mask_2d.shape_native: + raise exc.ArrayException( + f""" + The input array is 2D but not the same dimensions as the mask. - cond_2 = (len(array_2d.shape) == 2) and (array_2d.shape != mask_2d.shape_native) + This indicates the mask's shape is different to the input array shape. - if use_jax: - jax.lax.cond( - cond_2, - lambda _: jax.debug.callback(exception_message_2), - lambda _: None, - None, - ) - elif cond_2: - exception_message_2() + The shapes of the two arrays (which this exception is raised because they are different) are as follows: + Input array_2d shape = {array_2d.shape} + Input mask_2d shape_native = {mask_2d.shape_native} + """ + ) def convert_array_2d( array_2d: Union[np.ndarray, List], @@ -159,8 +124,6 @@ def convert_array_2d( check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 - if use_jax: - mask_2d = mask_2d.array if is_native and not skip_mask: array_2d *= np.invert(mask_2d) @@ -526,8 +489,6 @@ def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.nda return index_slim_for_index_native_2d - -@numba_util.jit() def array_2d_slim_from( array_2d_native: np.ndarray, mask_2d: np.ndarray, @@ -571,23 +532,7 @@ def array_2d_slim_from( array_2d_slim = array_2d_slim_from(mask=mask, array_2d=array_2d) """ - - if use_jax: - array_2d_slim = array_2d_native[~mask_2d.astype(bool)] - else: - total_pixels = np.sum(~mask_2d) - - array_2d_slim = np.zeros(shape=total_pixels) - index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - array_2d_slim[index] = array_2d_native[y, x] - index += 1 - - return array_2d_slim - + return array_2d_native[~mask_2d.astype(bool)] def array_2d_native_from( array_2d_slim: np.ndarray, @@ -641,8 +586,7 @@ def array_2d_native_from( ) -@partial(jit, static_argnums=(1,)) -@numba_util.jit() +@partial(jax.jit, static_argnums=(1,)) def array_2d_via_indexes_from( array_2d_slim: np.ndarray, shape: Tuple[int, int], @@ -675,22 +619,11 @@ def array_2d_via_indexes_from( ndarray The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ - if use_jax: - array_native_2d = ( - np.zeros(shape) - .at[tuple(native_index_for_slim_index_2d.T)] - .set(array_2d_slim) - ) - else: - array_native_2d = np.zeros(shape) - - for slim_index in range(len(native_index_for_slim_index_2d)): - array_native_2d[ - native_index_for_slim_index_2d[slim_index, 0], - native_index_for_slim_index_2d[slim_index, 1], - ] = array_2d_slim[slim_index] - - return array_native_2d + return ( + jnp.zeros(shape) + .at[tuple(native_index_for_slim_index_2d.T)] + .set(array_2d_slim) + ) @numba_util.jit() @@ -725,7 +658,7 @@ def array_2d_slim_complex_from( A 1D array of values mapped from the 2D array with dimensions (total_unmasked_pixels). """ - total_pixels = np.sum(~mask_2d) + total_pixels = np.sum(~mask) array_1d = 0 + 0j * np.zeros(shape=total_pixels) index = 0 diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index f21cf5a80..f80e3f6e3 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -53,7 +53,7 @@ def __init__( ) if normalize: - self._array[:] = np.divide(self._array, np.sum(self._array)) + self._array = np.divide(self._array, np.sum(self._array)) @classmethod def no_mask( @@ -383,7 +383,7 @@ def rescaled_with_odd_dimensions_from( try: kernel_rescaled = rescale( - self.native, + np.array(self.native._array), rescale_factor, anti_aliasing=False, mode="constant", @@ -391,7 +391,7 @@ def rescaled_with_odd_dimensions_from( ) except TypeError: kernel_rescaled = rescale( - self.native, + np.array(self.native._array), rescale_factor, anti_aliasing=False, mode="constant", @@ -472,7 +472,7 @@ def convolved_array_from(self, array: Array2D) -> Array2D: array_2d = array.native - convolved_array_2d = scipy.signal.convolve2d(array_2d, self.native, mode="same") + convolved_array_2d = scipy.signal.convolve2d(array_2d._array, np.array(self.native._array), mode="same") convolved_array_1d = array_2d_util.array_2d_slim_from( mask_2d=np.array(array_2d.mask), diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 4406db105..11c478ad5 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -229,13 +229,13 @@ def __init__( print(array_2d.native) # masked 2D data representation. """ + if conf.instance["general"]["structures"]["native_binned_only"]: + store_native = True + try: values = values._array except AttributeError: - pass - - if conf.instance["general"]["structures"]["native_binned_only"]: - store_native = True + values = values values = array_2d_util.convert_array_2d( array_2d=values, @@ -464,7 +464,7 @@ def zoomed_around_mask(self, buffer: int = 1) -> "Array2D": """ extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native), + array_2d=np.array(self.native._array), y0=self.mask.zoom_region[0] - buffer, y1=self.mask.zoom_region[1] + buffer, x0=self.mask.zoom_region[2] - buffer, @@ -498,7 +498,7 @@ def extent_of_zoomed_array(self, buffer: int = 1) -> np.ndarray: The number pixels around the extracted array used as a buffer. """ extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native), + array_2d=np.array(self.native._array), y0=self.mask.zoom_region[0] - buffer, y1=self.mask.zoom_region[1] + buffer, x0=self.mask.zoom_region[2] - buffer, @@ -532,7 +532,7 @@ def resized_from( """ resized_array_2d = array_2d_util.resized_array_2d_from( - array_2d=np.array(self.native), resized_shape=new_shape + array_2d=np.array(self.native._array), resized_shape=new_shape ) resized_mask = self.mask.resized_from( @@ -599,7 +599,7 @@ def trimmed_after_convolution_from( resized_mask = self.mask.resized_from(new_shape=trimmed_array_2d.shape) array = array_2d_util.convert_array_2d( - array_2d=trimmed_array_2d, mask_2d=resized_mask + array_2d=trimmed_array_2d._array, mask_2d=resized_mask ) return Array2D( diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 44c75c7c5..6c72c00ef 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,8 +1,7 @@ from __future__ import annotations -from autoarray.numpy_wrapper import np, use_jax - -if use_jax: - import jax +import numpy as np +import jax.numpy as jnp +import jax from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -13,7 +12,6 @@ from autoarray.structures.arrays import array_2d_util from autoarray.geometry import geometry_util from autoarray import numba_util -from autoarray.mask import mask_2d_util from autoarray import type as ty @@ -57,32 +55,20 @@ def check_grid_2d(grid_2d: np.ndarray): def check_grid_2d_and_mask_2d(grid_2d: np.ndarray, mask_2d: Mask2D): if len(grid_2d.shape) == 2: - - def exception_message(): + if grid_2d.shape[0] != mask_2d.pixels_in_mask: raise exc.GridException( f""" The input 2D grid does not have the same number of values as pixels in the mask. - + The shape of the input grid_2d is {grid_2d.shape}. The mask shape_native is {mask_2d.shape_native}. The mask number of pixels is {mask_2d.pixels_in_mask}. """ ) - if use_jax: - jax.lax.cond( - grid_2d.shape[0] != mask_2d.pixels_in_mask, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif grid_2d.shape[0] != mask_2d.pixels_in_mask: - exception_message() - elif len(grid_2d.shape) == 3: - - def exception_message(): + if (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: raise exc.GridException( f""" The input 2D grid is not the same dimensions as the mask @@ -93,16 +79,6 @@ def exception_message(): """ ) - if use_jax: - jax.lax.cond( - (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native, - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None, - ) - elif (grid_2d.shape[0], grid_2d.shape[1]) != mask_2d.shape_native: - exception_message() - def convert_grid_2d( grid_2d: Union[np.ndarray, List], mask_2d: Mask2D, store_native: bool = False @@ -137,12 +113,8 @@ def convert_grid_2d( is_native = len(grid_2d.shape) == 3 if is_native: - if use_jax: - grid_2d = grid_2d.at[:, :, 0].multiply(np.invert(mask_2d.array)) - grid_2d = grid_2d.at[:, :, 1].multiply(np.invert(mask_2d.array)) - else: - grid_2d[:, :, 0] *= np.invert(mask_2d) - grid_2d[:, :, 1] *= np.invert(mask_2d) + grid_2d[:, :, 0] *= np.invert(mask_2d) + grid_2d[:, :, 1] *= np.invert(mask_2d) if is_native == store_native: return grid_2d @@ -151,16 +123,10 @@ def convert_grid_2d( grid_2d_native=np.array(grid_2d), mask=np.array(mask_2d), ) - if use_jax: - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d.array), - mask_2d=np.array(mask_2d), - ) - else: - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d), - mask_2d=np.array(mask_2d), - ) + return grid_2d_native_from( + grid_2d_slim=np.array(grid_2d), + mask_2d=np.array(mask_2d), + ) def convert_grid_2d_to_slim( @@ -213,7 +179,6 @@ def convert_grid_2d_to_native( ) -@numba_util.jit() def grid_2d_centre_from(grid_2d_slim: np.ndarray) -> Tuple[float, float]: """ Returns the centre of a grid from a 1D grid. @@ -233,7 +198,6 @@ def grid_2d_centre_from(grid_2d_slim: np.ndarray) -> Tuple[float, float]: return centre_y, centre_x -@numba_util.jit() def grid_2d_slim_via_mask_from( mask_2d: np.ndarray, pixel_scales: ty.PixelScales, @@ -273,33 +237,18 @@ def grid_2d_slim_via_mask_from( grid_slim = grid_2d_slim_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - total_pixels = np.sum(~mask_2d) - centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) - if use_jax: - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) - grid_slim = ( - (np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled) - * sign - * pixel_scales - ) - else: - index = 0 - grid_slim = np.zeros(shape=(total_pixels, 2)) - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - grid_slim[index, 0] = -(y - centres_scaled[0]) * pixel_scales[0] - grid_slim[index, 1] = (x - centres_scaled[1]) * pixel_scales[1] - index += 1 - - return grid_slim + centres_scaled = jnp.array(centres_scaled) + pixel_scales = jnp.array(pixel_scales) + sign = jnp.array([-1.0, 1.0]) + return ( + (jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled) + * sign + * pixel_scales + ) def grid_2d_via_mask_from( diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index 68ad0dce4..57d4c4a6f 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -1,15 +1,12 @@ import logging -from typing import List, Optional, Tuple, Union +import numpy as np +from typing import List, Tuple, Union -from autoarray.numpy_wrapper import np from autoarray.abstract_ndarray import AbstractNDArray from autoarray.geometry.geometry_2d_irregular import Geometry2DIrregular from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.irregular import ArrayIrregular -from autoarray.structures.grids import grid_2d_util -from autoarray.geometry import geometry_util - logger = logging.getLogger(__name__) diff --git a/autoarray/structures/grids/uniform_1d.py b/autoarray/structures/grids/uniform_1d.py index d5870aeca..53a9ec756 100644 --- a/autoarray/structures/grids/uniform_1d.py +++ b/autoarray/structures/grids/uniform_1d.py @@ -1,10 +1,6 @@ from __future__ import annotations import numpy as np -from typing import TYPE_CHECKING, List, Union, Tuple - -if TYPE_CHECKING: - from autoarray.structures.arrays.uniform_1d import Array1D - from autoarray.structures.grids.uniform_2d import Grid2D +from typing import List, Union, Tuple from autoarray.structures.abstract_structure import Structure from autoarray.structures.grids.irregular_2d import Grid2DIrregular diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 41559e083..4d565dd27 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -1,5 +1,6 @@ from __future__ import annotations -from autoarray.numpy_wrapper import np, use_jax +import jax.numpy as jnp +import numpy as np from pathlib import Path from typing import List, Optional, Tuple, Union @@ -12,7 +13,6 @@ from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.structures.arrays import array_2d_util from autoarray.structures.grids import grid_2d_util from autoarray.geometry import geometry_util from autoarray.operators.over_sampling import over_sample_util @@ -160,7 +160,7 @@ def __init__( over sampled grid is not passed in it is computed assuming uniformity. """ values = grid_2d_util.convert_grid_2d( - grid_2d=values, + grid_2d=np.array(values), mask_2d=mask, store_native=store_native, ) @@ -184,7 +184,7 @@ def __init__( over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=np.array(self.mask), pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.over_sampler.sub_size).astype("int"), + sub_size=np.array(self.over_sampler.sub_size._array).astype("int"), origin=self.mask.origin, ) ) @@ -545,7 +545,7 @@ def from_mask( """ grid_1d = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=np.array(mask), + mask_2d=mask._array, pixel_scales=mask.pixel_scales, origin=mask.origin, ) @@ -818,7 +818,7 @@ def grid_with_coordinates_within_distance_removed_from( distance_mask += distances.native < distance mask = Mask2D( - mask=distance_mask, + mask=np.array(distance_mask), pixel_scales=self.pixel_scales, origin=self.origin, ) @@ -839,8 +839,8 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ - if use_jax: - squared_distances = np.square(self.array[:, 0] - coordinate[0]) + np.square( + if isinstance(self, jnp.ndarray): + squared_distances = jnp.square(self.array[:, 0] - coordinate[0]) + jnp.square( self.array[:, 1] - coordinate[1] ) else: @@ -1013,10 +1013,16 @@ def shape_native_scaled_interior(self) -> Tuple[float, float]: of the grid's (y,x) values, whereas the `shape_native_scaled` uses the uniform geometry of the grid and its ``pixel_scales``, which means it has a buffer at each edge of half a ``pixel_scale``. """ - return ( - np.amax(self[:, 0]) - np.amin(self[:, 0]), - np.amax(self[:, 1]) - np.amin(self[:, 1]), - ) + if isinstance(self, jnp.ndarray): + return ( + np.amax(self.array[:, 0]) - np.amin(self.array[:, 0]), + np.amax(self.array[:, 1]) - np.amin(self.array[:, 1]), + ) + else: + return ( + np.amax(self[:, 0]) - np.amin(self[:, 0]), + np.amax(self[:, 1]) - np.amin(self[:, 1]), + ) @property def scaled_minima(self) -> Tuple: @@ -1024,10 +1030,10 @@ def scaled_minima(self) -> Tuple: The (y,x) minimum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - if use_jax: + if isinstance(self, jnp.ndarray): return ( - np.amin(self.array[:, 0]).astype("float"), - np.amin(self.array[:, 1]).astype("float"), + jnp.amin(self.array[:, 0]).astype("float"), + jnp.amin(self.array[:, 1]).astype("float"), ) else: return ( @@ -1041,10 +1047,10 @@ def scaled_maxima(self) -> Tuple: The (y,x) maximum values of the grid in scaled units, buffed such that their extent is further than the grid's extent. """ - if use_jax: + if isinstance(self, jnp.ndarray): return ( - np.amax(self.array[:, 0]).astype("float"), - np.amax(self.array[:, 1]).astype("float"), + jnp.amax(self.array[:, 0]).astype("float"), + jnp.amax(self.array[:, 1]).astype("float"), ) else: return ( @@ -1101,7 +1107,7 @@ def padded_grid_from(self, kernel_shape_native: Tuple[int, int]) -> "Grid2D": ) over_sample_size = np.pad( - self.over_sample_size.native, + self.over_sample_size.native._array, pad_width, mode="constant", constant_values=1, diff --git a/test_autoarray/geometry/test_geometry_util.py b/test_autoarray/geometry/test_geometry_util.py index b6b8b38d6..4bc2706dc 100644 --- a/test_autoarray/geometry/test_geometry_util.py +++ b/test_autoarray/geometry/test_geometry_util.py @@ -980,7 +980,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]) + np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -994,7 +994,7 @@ def test__transform_2d_grid_to_reference_frame(): [0.0, np.sqrt(2)], [np.sqrt(2) / 2.0, np.sqrt(2) / 2.0], ] - ) + ), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1002,7 +1002,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[-1.0, 0.0], [-1.0, 1.0], [0.0, 1.0]]) + np.array([[-1.0, 0.0], [-1.0, 1.0], [0.0, 1.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1010,7 +1010,7 @@ def test__transform_2d_grid_to_reference_frame(): ) assert transformed_grid_2d == pytest.approx( - np.array([[0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]) + np.array([[0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]), abs=1.0e-4 ) transformed_grid_2d = aa.util.geometry.transform_grid_2d_to_reference_frame( @@ -1071,7 +1071,8 @@ def test__transform_2d_grid_from_reference_frame(): grid_2d=transformed_grid_2d, centre=(8.0, 5.0), angle=137.0 ) - assert grid_2d == pytest.approx(original_grid_2d, 1.0e-4) + + assert grid_2d == pytest.approx(original_grid_2d, abs=1.0e-4) def test__grid_pixels_2d_slim_from(): diff --git a/test_autoarray/mask/test_mask_1d.py b/test_autoarray/mask/test_mask_1d.py index 41f02d3b7..cb5fe2c56 100644 --- a/test_autoarray/mask/test_mask_1d.py +++ b/test_autoarray/mask/test_mask_1d.py @@ -1,6 +1,4 @@ -from astropy.io import fits import numpy as np -import os from os import path import pytest @@ -53,34 +51,34 @@ def test__constructor__input_is_2d_mask__raises_exception(): def test__is_all_true(): mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[False, True, False, False], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask1D(mask=[True, True, True, True], pixel_scales=1.0) - assert mask.is_all_true is True + assert mask.is_all_true == True def test__is_all_false(): mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask1D(mask=[False, False], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask1D(mask=[False, True, False, False], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False mask = aa.Mask1D(mask=[True, True, False, False], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index 7e12e5b82..399510a0f 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -366,37 +366,37 @@ def test__mask__input_is_1d_mask__no_shape_native__raises_exception(): def test__is_all_true(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[False, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_true is False + assert mask.is_all_true == False mask = aa.Mask2D(mask=[[True, True], [True, True]], pixel_scales=1.0) - assert mask.is_all_true is True + assert mask.is_all_true == True def test__is_all_false(): mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask2D(mask=[[False, False]], pixel_scales=1.0) - assert mask.is_all_false is True + assert mask.is_all_false == True mask = aa.Mask2D(mask=[[False, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False mask = aa.Mask2D(mask=[[True, True], [False, False]], pixel_scales=1.0) - assert mask.is_all_false is False + assert mask.is_all_false == False def test__shape_native_masked_pixels(): diff --git a/test_autoarray/structures/arrays/files/array/output_test/array.fits b/test_autoarray/structures/arrays/files/array/output_test/array.fits index 0cff0a8f7..dab9da2fb 100644 Binary files a/test_autoarray/structures/arrays/files/array/output_test/array.fits and b/test_autoarray/structures/arrays/files/array/output_test/array.fits differ diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 292d628f7..9106b43a1 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -457,4 +457,4 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy normalize=True, ) - assert kernel_astropy == pytest.approx(kernel_2d.native, 1e-4) + assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) diff --git a/test_autoarray/structures/arrays/test_repr.py b/test_autoarray/structures/arrays/test_repr.py index cf0bfa964..8ea8c6b9d 100644 --- a/test_autoarray/structures/arrays/test_repr.py +++ b/test_autoarray/structures/arrays/test_repr.py @@ -3,4 +3,4 @@ def test_repr(): array = aa.Array2D.no_mask([[1, 2], [3, 4]], pixel_scales=1) - assert repr(array) == "Array2D([1., 2., 3., 4.])" + assert repr(array) == "Array2D([1, 2, 3, 4])" diff --git a/test_autoarray/structures/grids/test_uniform_1d.py b/test_autoarray/structures/grids/test_uniform_1d.py index 147bd4b17..b99da3608 100644 --- a/test_autoarray/structures/grids/test_uniform_1d.py +++ b/test_autoarray/structures/grids/test_uniform_1d.py @@ -129,7 +129,7 @@ def test__grid_2d_radial_projected_from(): grid_2d = grid_1d.grid_2d_radial_projected_from(angle=90.0) assert grid_2d.slim == pytest.approx( - np.array([[-1.0, 0.0], [-2.0, 0.0], [-3.0, 0.0], [-4.0, 0.0]]), 1.0e-4 + np.array([[-1.0, 0.0], [-2.0, 0.0], [-3.0, 0.0], [-4.0, 0.0]]), abs=1.0e-4 ) grid_2d = grid_1d.grid_2d_radial_projected_from(angle=45.0) diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index 4084908ba..686721c70 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -481,9 +481,10 @@ def test__to_and_from_fits_methods(): def test__shape_native_scaled(): - mask = aa.Mask2D.circular(shape_native=(3, 3), radius=1.0, pixel_scales=(1.0, 1.0)) + mask = aa.Mask2D.circular(shape_native=(3, 3), radius=1.1, pixel_scales=(1.0, 1.0)) grid_2d = aa.Grid2D.from_mask(mask=mask) + assert grid_2d.shape_native_scaled_interior == (2.0, 2.0) mask = aa.Mask2D.elliptical( @@ -574,7 +575,8 @@ def test__grid_2d_radial_projected_shape_slim_from(): pixel_scales=grid_2d.pixel_scales, ) - assert (grid_radii == grid_radii_util).all() + + assert grid_radii == pytest.approx(grid_radii_util, 1.0e-4) assert grid_radial_shape_slim == grid_radii_util.shape[0] grid_radii = grid_2d.grid_2d_radial_projected_from(centre=(0.3, 0.1), angle=60.0) @@ -780,37 +782,6 @@ def test__grid_with_coordinates_within_distance_removed_from(): ).all() -def test__grid_radial_minimum(): - grid_2d = np.array([[2.5, 0.0], [4.0, 0.0], [6.0, 0.0]]) - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - assert (deflections == grid_2d).all() - - grid_2d = np.array([[2.0, 0.0], [1.0, 0.0], [6.0, 0.0]]) - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - - assert (deflections == np.array([[2.5, 0.0], [2.5, 0.0], [6.0, 0.0]])).all() - - grid_2d = np.array( - [ - [np.sqrt(2.0), np.sqrt(2.0)], - [1.0, np.sqrt(8.0)], - [np.sqrt(8.0), np.sqrt(8.0)], - ] - ) - - mock_profile = aa.m.MockGridRadialMinimum() - - deflections = mock_profile.deflections_yx_2d_from(grid=grid_2d) - - assert deflections == pytest.approx( - np.array([[1.7677, 1.7677], [1.0, np.sqrt(8.0)], [np.sqrt(8), np.sqrt(8.0)]]), - 1.0e-4, - ) - def test__recursive_shape_storage(): grid_2d = aa.Grid2D.no_mask(