diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index 9398775dd..ebbb9927f 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -1,4 +1,5 @@ import copy +import jax import jax.numpy as jnp import numpy as np from scipy.linalg import block_diag @@ -73,17 +74,6 @@ def __init__( A dictionary which contains timing of certain functions calls which is used for profiling. """ - # try: - # import numba - # except ModuleNotFoundError: - # raise exc.InversionException( - # "Inversion functionality (linear light profiles, pixelized reconstructions) is " - # "disabled if numba is not installed.\n\n" - # "This is because the run-times without numba are too slow.\n\n" - # "Please install numba, which is described at the following web page:\n\n" - # "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" - # ) - self.dataset = dataset self.linear_obj_list = linear_obj_list @@ -317,7 +307,7 @@ def operated_mapping_matrix(self) -> np.ndarray: If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous linear equations are solved simultaneously. """ - return np.hstack(self.operated_mapping_matrix_list) + return jnp.hstack(self.operated_mapping_matrix_list) @cached_property @profile_func @@ -474,20 +464,17 @@ def reconstruction(self) -> np.ndarray: And the data_vector = ZTx, so the corresponding row is also taken out. """ - if self.settings.force_edge_pixels_to_zeros: - if self.settings.force_edge_image_pixels_to_zeros: - ids_zeros = np.unique( - np.append( - self.mapper_edge_pixel_list, self.mapper_zero_pixel_list - ) - ) - else: - ids_zeros = self.mapper_edge_pixel_list + if ( + self.has(cls=AbstractMapper) + and self.settings.force_edge_pixels_to_zeros + ): - values_to_solve = np.ones( - np.shape(self.curvature_reg_matrix)[0], dtype=bool + ids_zeros = jnp.array(self.mapper_edge_pixel_list, dtype=int) + + values_to_solve = jnp.ones( + self.curvature_reg_matrix.shape[0], dtype=bool ) - values_to_solve[ids_zeros] = False + values_to_solve = values_to_solve.at[ids_zeros].set(False) data_vector_input = self.data_vector[values_to_solve] @@ -495,25 +482,32 @@ def reconstruction(self) -> np.ndarray: values_to_solve, : ][:, values_to_solve] - solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0]) - - solutions[values_to_solve] = ( - inversion_util.reconstruction_positive_only_from( - data_vector=data_vector_input, - curvature_reg_matrix=curvature_reg_matrix_input, - settings=self.settings, - ) + # Get the values to assign (must be a JAX array) + reconstruction = inversion_util.reconstruction_positive_only_from( + data_vector=data_vector_input, + curvature_reg_matrix=curvature_reg_matrix_input, + settings=self.settings, ) + + # Allocate JAX array + solutions = jnp.zeros(self.curvature_reg_matrix.shape[0]) + + # Get indices where True + indices = jnp.where(values_to_solve)[0] + + # Set reconstruction values at those indices + solutions = solutions.at[indices].set(reconstruction) + return solutions + else: - solutions = inversion_util.reconstruction_positive_only_from( + + return inversion_util.reconstruction_positive_only_from( data_vector=self.data_vector, curvature_reg_matrix=self.curvature_reg_matrix, settings=self.settings, ) - return solutions - mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) return inversion_util.reconstruction_positive_negative_from( @@ -522,81 +516,6 @@ def reconstruction(self) -> np.ndarray: mapper_param_range_list=mapper_param_range_list, ) - # @cached_property - # @profile_func - # def reconstruction(self) -> np.ndarray: - # """ - # Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12) - # of https://arxiv.org/pdf/astro-ph/0302587.pdf (Positive-Negative solution) - # - # ============================================================================================ - # - # Solve the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf (Non-negative solution) - # Find non-negative solution that minimizes |Z * S - x|^2. - # - # We use fnnls (https://github.com/jvendrow/fnnls) to optimize the quadratic value. Two commonly used - # variables in the code are defined as follows: - # ZTZ := np.dot(Z.T, Z) - # ZTx := np.dot(Z.T, x) - # """ - # if self.settings.use_positive_only_solver: - # """ - # For the new implementation, we now need to take out the cols and rows of - # the curvature_reg_matrix that corresponds to the parameters we force to be 0. - # Similar for the data vector. - # - # What we actually doing is that we have set the correspoding cols of the Z to be 0. - # As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out. - # And the data_vector = ZTx, so the corresponding row is also taken out. - # """ - # - # if self.settings.force_edge_pixels_to_zeros: - # if self.settings.force_edge_image_pixels_to_zeros: - # ids_zeros = np.unique( - # np.append( - # self.mapper_edge_pixel_list, self.mapper_zero_pixel_list - # ) - # ) - # else: - # ids_zeros = self.mapper_edge_pixel_list - # - # values_to_solve = np.ones( - # np.shape(self.curvature_reg_matrix)[0], dtype=bool - # ) - # values_to_solve[ids_zeros] = False - # - # data_vector_input = self.data_vector[values_to_solve] - # - # curvature_reg_matrix_input = self.curvature_reg_matrix[ - # values_to_solve, : - # ][:, values_to_solve] - # - # solutions = inversion_util.reconstruction_positive_only_from( - # data_vector=data_vector_input, - # curvature_reg_matrix=curvature_reg_matrix_input, - # settings=self.settings, - # ) - # - # mask = values_to_solve.astype(bool) - # - # return solutions[mask] - # else: - # solutions = inversion_util.reconstruction_positive_only_from( - # data_vector=self.data_vector, - # curvature_reg_matrix=self.curvature_reg_matrix, - # settings=self.settings, - # ) - # - # return solutions - # - # mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) - # - # return inversion_util.reconstruction_positive_negative_from( - # data_vector=self.data_vector, - # curvature_reg_matrix=self.curvature_reg_matrix, - # mapper_param_range_list=mapper_param_range_list, - # ) - @cached_property @profile_func def reconstruction_reduced(self) -> np.ndarray: diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 12881a944..2ec0df160 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -109,7 +109,6 @@ def data_vector(self) -> np.ndarray: The calculation is described in more detail in `inversion_util.data_vector_via_blurred_mapping_matrix_from`. """ - return inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix=self.operated_mapping_matrix, image=self.data.array, diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 1f6d08ac1..97fc9eb8f 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -440,7 +440,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: data_weights=mapper.unique_mappings.data_weights, pix_lengths=mapper.unique_mappings.pix_lengths, pix_pixels=mapper.params, - curvature_weights=curvature_weights, + curvature_weights=np.array(curvature_weights), image_frame_1d_lengths=self.convolver.image_frame_1d_lengths, image_frame_1d_indexes=self.convolver.image_frame_1d_indexes, image_frame_1d_kernels=self.convolver.image_frame_1d_kernels, diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index 9cde492e9..2b3219a2e 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -76,7 +76,7 @@ def data_vector(self) -> np.ndarray: """ return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from( - transformed_mapping_matrix=self.operated_mapping_matrix, + transformed_mapping_matrix=np.array(self.operated_mapping_matrix), visibilities=np.array(self.data), noise_map=np.array(self.noise_map), ) @@ -152,8 +152,10 @@ def mapped_reconstructed_data_dict( visibilities = ( inversion_interferometer_util.mapped_reconstructed_visibilities_from( - transformed_mapping_matrix=operated_mapping_matrix_list[index], - reconstruction=reconstruction, + transformed_mapping_matrix=np.array( + operated_mapping_matrix_list[index] + ), + reconstruction=np.array(reconstruction), ) ) diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 246e2cf56..29f1eaa07 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -1,5 +1,7 @@ import jax.numpy as jnp import jaxnnls +import jax +import jax.lax as lax import numpy as np from typing import List, Optional, Tuple @@ -10,7 +12,6 @@ from autoarray import numba_util from autoarray import exc -from autoarray.util.fnnls import fnnls_cholesky def curvature_matrix_via_w_tilde_from( @@ -41,7 +42,6 @@ def curvature_matrix_via_w_tilde_from( return np.dot(mapping_matrix.T, np.dot(w_tilde, mapping_matrix)) -@numba_util.jit() def curvature_matrix_with_added_to_diag_from( curvature_matrix: np.ndarray, value: float, @@ -56,6 +56,38 @@ def curvature_matrix_with_added_to_diag_from( This function adds this numerical value to the diagonal of the curvature matrix. + Parameters + ---------- + curvature_matrix + The curvature matrix which is being constructed in order to solve a linear system of equations. + """ + try: + return curvature_matrix.at[ + no_regularization_index_list, no_regularization_index_list + ].add(value) + except AttributeError: + return curvature_matrix_with_added_to_diag_from_numba( + curvature_matrix=curvature_matrix, + value=value, + no_regularization_index_list=no_regularization_index_list, + ) + + +@numba_util.jit() +def curvature_matrix_with_added_to_diag_from_numba( + curvature_matrix: np.ndarray, + value: float, + no_regularization_index_list: Optional[List] = None, +) -> np.ndarray: + """ + It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion + via `np.linalg.solve` to fail and raise a `LinAlgError`. + + In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix` + makes it positive definite, such that the inversion is performed without raising an error. + + This function adds this numerical value to the diagonal of the curvature matrix. + Parameters ---------- curvature_matrix @@ -68,48 +100,17 @@ def curvature_matrix_with_added_to_diag_from( return curvature_matrix -# def curvature_matrix_with_added_to_diag_from( -# curvature_matrix: np.ndarray, -# value: float, -# no_regularization_index_list: Optional[List] = None, -# ) -> np.ndarray: -# """ -# It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion -# via `np.linalg.solve` to fail and raise a `LinAlgError`. -# -# In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix` -# makes it positive definite, such that the inversion is performed without raising an error. -# -# This function adds this numerical value to the diagonal of the curvature matrix. -# -# Parameters -# ---------- -# curvature_matrix -# The curvature matrix which is being constructed in order to solve a linear system of equations. -# """ -# return curvature_matrix.at[ -# no_regularization_index_list, no_regularization_index_list -# ].add(value) - - -@numba_util.jit() def curvature_matrix_mirrored_from( curvature_matrix: np.ndarray, ) -> np.ndarray: - curvature_matrix_mirrored = np.zeros( - (curvature_matrix.shape[0], curvature_matrix.shape[1]) - ) + # Copy the original matrix and its transpose + m1 = curvature_matrix + m2 = curvature_matrix.T - for i in range(curvature_matrix.shape[0]): - for j in range(curvature_matrix.shape[1]): - if curvature_matrix[i, j] != 0: - curvature_matrix_mirrored[i, j] = curvature_matrix[i, j] - curvature_matrix_mirrored[j, i] = curvature_matrix[i, j] - if curvature_matrix[j, i] != 0: - curvature_matrix_mirrored[i, j] = curvature_matrix[j, i] - curvature_matrix_mirrored[j, i] = curvature_matrix[j, i] + # For each entry, prefer the non-zero value from either the matrix or its transpose + mirrored = jnp.where(m1 != 0, m1, m2) - return curvature_matrix_mirrored + return mirrored def curvature_matrix_via_mapping_matrix_from( @@ -132,7 +133,7 @@ def curvature_matrix_via_mapping_matrix_from( Flattened 1D array of the noise-map used by the inversion during the fit. """ array = mapping_matrix / noise_map[:, None] - curvature_matrix = np.dot(array.T, array) + curvature_matrix = jnp.dot(array.T, array) if add_to_curvature_diag and len(no_regularization_index_list) > 0: curvature_matrix = curvature_matrix_with_added_to_diag_from( @@ -188,7 +189,7 @@ def mapped_reconstructed_data_via_mapping_matrix_from( The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. """ - return np.dot(mapping_matrix, reconstruction) + return jnp.dot(mapping_matrix, reconstruction) def reconstruction_positive_negative_from( @@ -233,10 +234,13 @@ def reconstruction_positive_negative_from( The curvature_matrix plus regularization matrix, overwriting the curvature_matrix in memory. """ try: - reconstruction = np.linalg.solve(curvature_reg_matrix, data_vector) + reconstruction = jnp.linalg.solve(curvature_reg_matrix, data_vector) except np.linalg.LinAlgError as e: raise exc.InversionException() from e + if jnp.isnan(reconstruction).any(): + raise exc.InversionException + if ( conf.instance["general"]["inversion"]["check_reconstruction"] or force_check_reconstruction @@ -299,29 +303,19 @@ def reconstruction_positive_only_from( Non-negative S that minimizes the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf. """ - # try: - # return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector) - # except (RuntimeError, np.linalg.LinAlgError, ValueError) as e: - # raise exc.InversionException() from e - - if len(data_vector): - try: - if settings.positive_only_uses_p_initial: - P_initial = np.linalg.solve(curvature_reg_matrix, data_vector) > 0 - else: - P_initial = np.zeros(0, dtype=int) - - reconstruction = fnnls_cholesky( - curvature_reg_matrix, - (data_vector).T, - P_initial=P_initial, - ) + try: + reconstruction = jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector) + except (RuntimeError, np.linalg.LinAlgError, ValueError) as e: + raise exc.InversionException() from e + + def handle_nan(reconstruction): + return jnp.zeros_like(reconstruction) - except (RuntimeError, np.linalg.LinAlgError, ValueError) as e: - raise exc.InversionException() from e + def handle_valid(reconstruction): + return reconstruction - else: - raise exc.InversionException() + has_nan = jnp.isnan(reconstruction).any() + reconstruction = lax.cond(has_nan, handle_nan, handle_valid, reconstruction) return reconstruction diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 2c0eba077..184e16977 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -15,7 +15,6 @@ def __init__( positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, force_edge_pixels_to_zeros: bool = True, - force_edge_image_pixels_to_zeros: bool = False, image_pixels_source_zero=None, no_regularization_add_to_curvature_diag_value: float = None, use_w_tilde_numpy: bool = False, @@ -84,7 +83,6 @@ def __init__( self._use_border_relocator = use_border_relocator self.use_linear_operators = use_linear_operators self.force_edge_pixels_to_zeros = force_edge_pixels_to_zeros - self.force_edge_image_pixels_to_zeros = force_edge_image_pixels_to_zeros self.image_pixels_source_zero = image_pixels_source_zero self._no_regularization_add_to_curvature_diag_value = ( no_regularization_add_to_curvature_diag_value diff --git a/autoarray/inversion/pixelization/mappers/voronoi.py b/autoarray/inversion/pixelization/mappers/voronoi.py index c8e54cbf3..1ebae8ad9 100644 --- a/autoarray/inversion/pixelization/mappers/voronoi.py +++ b/autoarray/inversion/pixelization/mappers/voronoi.py @@ -172,5 +172,8 @@ def interpolated_array_from( is input. """ return self.source_plane_mesh_grid.interpolated_array_from( - values=values, shape_native=shape_native, extent=extent, use_nn=True + values=np.array(values), + shape_native=shape_native, + extent=extent, + use_nn=True, ) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index b78bb0829..5d9a99871 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,6 +147,17 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): over_sample_size=sub_size, mask=mask ) + # Used for JAX based adaptive over sampling. + + # Define group sizes + group_sizes = np.array(self.sub_size.array**2) + + # Compute the cumulative sum of group sizes to get split points + self.split_indices = np.cumsum(group_sizes) + + # Ensure correct concatenation by making 0 a JAX array + self.start_indices = np.concatenate((np.array([0]), self.split_indices[:-1])) + @property def sub_is_uniform(self) -> bool: """ @@ -234,20 +245,11 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": ).mean(axis=1) else: - # Define group sizes - group_sizes = jnp.array(self.sub_size.array**2) - - # Compute the cumulative sum of group sizes to get split points - split_indices = jnp.cumsum(group_sizes) - - # Ensure correct concatenation by making 0 a JAX array - start_indices = jnp.concatenate((jnp.array([0]), split_indices[:-1])) - # Compute the group means binned_array_2d = jnp.array( [ array[start:end].mean() - for start, end in zip(start_indices, split_indices) + for start, end in zip(self.start_indices, self.split_indices) ] ) diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index 28c45a9bd..c7dc03221 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -248,13 +248,13 @@ def test__identical_inversion_values_for_two_methods(): inversion_w_tilde = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=True), + settings=aa.SettingsInversion(use_w_tilde=True, use_positive_only_solver=True), ) inversion_mapping_matrices = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), ) assert (inversion_w_tilde.data == inversion_mapping_matrices.data).all() @@ -285,8 +285,8 @@ def test__identical_inversion_values_for_two_methods(): assert inversion_w_tilde.reconstruction == pytest.approx( inversion_mapping_matrices.reconstruction, abs=1.0e-1 ) - assert inversion_w_tilde.mapped_reconstructed_image == pytest.approx( - inversion_mapping_matrices.mapped_reconstructed_image, abs=1.0e-1 + assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( + inversion_mapping_matrices.mapped_reconstructed_image.array, abs=1.0e-1 ) assert inversion_w_tilde.mapped_reconstructed_data == pytest.approx( inversion_mapping_matrices.mapped_reconstructed_data, abs=1.0e-1 @@ -347,13 +347,17 @@ def test__identical_inversion_source_and_image_loops(): inversion_image_loop = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=True, use_source_loop=False), + settings=aa.SettingsInversion( + use_w_tilde=True, use_source_loop=False, use_positive_only_solver=True + ), ) inversion_source_loop = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=True, use_source_loop=True), + settings=aa.SettingsInversion( + use_w_tilde=True, use_source_loop=True, use_positive_only_solver=True + ), ) assert (inversion_image_loop.data == inversion_source_loop.data).all() @@ -380,8 +384,8 @@ def test__identical_inversion_source_and_image_loops(): assert inversion_image_loop.reconstruction == pytest.approx( inversion_source_loop.reconstruction, 1.0e-2 ) - assert inversion_image_loop.mapped_reconstructed_image == pytest.approx( - inversion_source_loop.mapped_reconstructed_image, 1.0e-2 + assert inversion_image_loop.mapped_reconstructed_image.array == pytest.approx( + inversion_source_loop.mapped_reconstructed_image.array, 1.0e-2 ) assert inversion_image_loop.mapped_reconstructed_data == pytest.approx( inversion_source_loop.mapped_reconstructed_data, 1.0e-2 diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 5ce2ad473..bf8f4a919 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -6,7 +6,6 @@ from autoarray import exc - directory = path.dirname(path.realpath(__file__)) diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index 743833b6a..984aab946 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -1,6 +1,7 @@ import copy import numpy as np import pytest +from dill import settings import autoarray as aa @@ -43,6 +44,7 @@ def test__inversion_imaging__via_linear_obj_func_list(masked_imaging_7x7_no_blur linear_obj = aa.m.MockLinearObjFuncList( parameters=2, grid=grid, mapping_matrix=np.full(fill_value=0.5, shape=(9, 2)) ) + linear_obj.mapping_matrix[0, 0] = 1.0 inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, @@ -52,8 +54,8 @@ def test__inversion_imaging__via_linear_obj_func_list(masked_imaging_7x7_no_blur assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObjFuncList) assert isinstance(inversion, aa.InversionImagingMapping) + assert inversion.reconstruction == pytest.approx(np.array([0.0, 2.0]), abs=1.0e-4) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) - assert inversion.reconstruction == pytest.approx(np.array([1.0, 1.0]), 1.0e-4) def test__inversion_imaging__via_mapper( @@ -264,7 +266,7 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_t assert isinstance(inversion.linear_obj_list[1], aa.MapperDelaunay) assert isinstance(inversion, aa.InversionImagingMapping) assert inversion.reconstruction == pytest.approx( - np.array([2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), 1.0e-4 + np.array([2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), abs=1.0e-2 ) @@ -289,8 +291,8 @@ def test__inversion_imaging__compare_mapping_and_w_tilde_values( assert inversion_w_tilde.reconstruction == pytest.approx( inversion_mapping.reconstruction, 1.0e-4 ) - assert inversion_w_tilde.mapped_reconstructed_image == pytest.approx( - inversion_mapping.mapped_reconstructed_image, 1.0e-4 + assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( + inversion_mapping.mapped_reconstructed_image.array, 1.0e-4 ) assert inversion_w_tilde.log_det_curvature_reg_matrix_term == pytest.approx( inversion_mapping.log_det_curvature_reg_matrix_term @@ -309,13 +311,15 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( grid = aa.Grid2D.from_mask(mask=mask) linear_obj = aa.m.MockLinearObj( - parameters=2, grid=grid, mapping_matrix=np.full(fill_value=0.5, shape=(9, 2)) + parameters=2, + grid=grid, + mapping_matrix=np.full(fill_value=0.5, shape=(9, 2)), ) inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), ) masked_imaging_7x7_no_blur = copy.copy(masked_imaging_7x7_no_blur) @@ -327,7 +331,7 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( inversion_no_linear_func = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), ) assert inversion.regularization_term == pytest.approx( @@ -367,13 +371,13 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( inversion_mapping = aa.Inversion( dataset=masked_imaging_7x7, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), ) inversion_w_tilde = aa.Inversion( dataset=masked_imaging_7x7, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=True), + settings=aa.SettingsInversion(use_w_tilde=True, use_positive_only_solver=True), ) assert inversion_mapping.data_vector == pytest.approx( @@ -388,8 +392,8 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( assert inversion_mapping.reconstruction == pytest.approx( inversion_w_tilde.reconstruction, 1.0e-4 ) - assert inversion_mapping.mapped_reconstructed_image == pytest.approx( - inversion_w_tilde.mapped_reconstructed_image, 1.0e-4 + assert inversion_mapping.mapped_reconstructed_image.array == pytest.approx( + inversion_w_tilde.mapped_reconstructed_image.array, 1.0e-4 ) linear_obj_1 = aa.m.MockLinearObjFuncList( @@ -476,9 +480,11 @@ def test__inversion_matrices__x2_mappers( delaunay_mapper_9_3x3, regularization_constant, ): + inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[rectangular_mapper_7x7_3x3, delaunay_mapper_9_3x3], + settings=aa.SettingsInversion(use_positive_only_solver=True), ) assert ( @@ -524,26 +530,23 @@ def test__inversion_matrices__x2_mappers( assert (inversion.regularization_matrix[0:9, 9:18] == np.zeros((9, 9))).all() assert (inversion.regularization_matrix[9:18, 0:9] == np.zeros((9, 9))).all() - reconstruction_0 = 0.5 * np.ones(9) - reconstruction_1 = 0.5 * np.ones(9) - - assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3] == pytest.approx( - reconstruction_0, 1.0e-4 - ) - assert inversion.reconstruction_dict[delaunay_mapper_9_3x3] == pytest.approx( - reconstruction_1, 1.0e-4 - ) - assert inversion.reconstruction == pytest.approx( - np.concatenate([reconstruction_0, reconstruction_1]), 1.0e-4 + assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][ + 4 + ] == pytest.approx(0.05594123, 1.0e-4) + assert inversion.reconstruction_dict[delaunay_mapper_9_3x3][4] == pytest.approx( + 0.04686388, 1.0e-4 ) + assert inversion.reconstruction[13] == pytest.approx(0.04686388, 1.0e-4) - assert inversion.mapped_reconstructed_data_dict[ - rectangular_mapper_7x7_3x3 - ] == pytest.approx(0.5 * np.ones(9), 1.0e-4) - assert inversion.mapped_reconstructed_data_dict[ - delaunay_mapper_9_3x3 - ] == pytest.approx(0.5 * np.ones(9), 1.0e-4) - assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) + assert inversion.mapped_reconstructed_data_dict[rectangular_mapper_7x7_3x3][ + 4 + ] == pytest.approx(0.05594123, 1.0e-4) + assert inversion.mapped_reconstructed_data_dict[delaunay_mapper_9_3x3][ + 3 + ] == pytest.approx(0.01521323, 1.0e-4) + assert inversion.mapped_reconstructed_image[4] == pytest.approx( + 0.10494037076075, 1.0e-4 + ) def test__inversion_imaging__positive_only_solver(masked_imaging_7x7_no_blur): diff --git a/test_autoarray/inversion/inversion/test_settings_dict.py b/test_autoarray/inversion/inversion/test_settings_dict.py index 6a6f8ca9f..d547014e1 100644 --- a/test_autoarray/inversion/inversion/test_settings_dict.py +++ b/test_autoarray/inversion/inversion/test_settings_dict.py @@ -17,7 +17,6 @@ def make_settings_dict(): "use_positive_only_solver": False, "positive_only_uses_p_initial": False, "force_edge_pixels_to_zeros": True, - "force_edge_image_pixels_to_zeros": False, "image_pixels_source_zero": None, "no_regularization_add_to_curvature_diag_value": 1e-08, "use_w_tilde_numpy": False,