diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 2c2236cfb..789dac386 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -65,6 +65,7 @@ from .operators.contour import Grid2DContour from .layout.layout import Layout1D from .layout.layout import Layout2D +from .preloads import Preloads from .structures.arrays.uniform_1d import Array1D from .structures.arrays.uniform_2d import Array2D from .structures.arrays.rgb import Array2DRGB diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index 48032ef97..bc0daf0ad 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -13,6 +13,7 @@ from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.regularization.abstract import AbstractRegularization from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.visibilities import Visibilities @@ -27,6 +28,7 @@ def __init__( dataset: Union[Imaging, Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), + preloads: Preloads = None, ): """ An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions @@ -66,23 +68,14 @@ def __init__( Settings controlling how an inversion is fitted for example which linear algebra formalism is used. """ - try: - import numba - except ModuleNotFoundError: - raise exc.InversionException( - "Inversion functionality (linear light profiles, pixelized reconstructions) is " - "disabled if numba is not installed.\n\n" - "This is because the run-times without numba are too slow.\n\n" - "Please install numba, which is described at the following web page:\n\n" - "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" - ) - self.dataset = dataset self.linear_obj_list = linear_obj_list self.settings = settings + self.preloads = preloads or Preloads() + @property def data(self): return self.dataset.data @@ -156,17 +149,9 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]: ------- A list of the index range of the parameters of each linear object in the inversion of the input cls type. """ - index_list = [] - - pixel_count = 0 - - for linear_obj in self.linear_obj_list: - if isinstance(linear_obj, cls): - index_list.append([pixel_count, pixel_count + linear_obj.params]) - - pixel_count += linear_obj.params - - return index_list + return inversion_util.param_range_list_from( + cls=cls, linear_obj_list=self.linear_obj_list + ) def cls_list_from(self, cls: Type, cls_filtered: Optional[Type] = None) -> List: """ @@ -267,6 +252,22 @@ def no_regularization_index_list(self) -> List[int]: return no_regularization_index_list + @property + def mapper_indices(self) -> np.ndarray: + + if self.preloads.mapper_indices is not None: + return self.preloads.mapper_indices + + mapper_indices = [] + + param_range_list = self.param_range_list_from(cls=AbstractMapper) + + for param_range in param_range_list: + + mapper_indices += range(param_range[0], param_range[1]) + + return np.array(mapper_indices) + @property def mask(self) -> Array2D: return self.data.mask @@ -354,19 +355,14 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]: regularization it is bypassed. """ - regularization_matrix = self.regularization_matrix - if self.all_linear_obj_have_regularization: - return regularization_matrix + return self.regularization_matrix - regularization_matrix = np.delete( - regularization_matrix, self.no_regularization_index_list, 0 - ) - regularization_matrix = np.delete( - regularization_matrix, self.no_regularization_index_list, 1 - ) + # ids of values which are on edge so zero-d and not solved for. + ids_to_keep = self.mapper_indices - return regularization_matrix + # Zero rows and columns in the matrix we want to ignore + return self.regularization_matrix[ids_to_keep][:, ids_to_keep] @cached_property def curvature_reg_matrix(self) -> np.ndarray: @@ -381,55 +377,31 @@ def curvature_reg_matrix(self) -> np.ndarray: if not self.has(cls=AbstractRegularization): return self.curvature_matrix - if len(self.regularization_list) == 1: - curvature_matrix = self.curvature_matrix - curvature_matrix += self.regularization_matrix - - del self.__dict__["curvature_matrix"] - - return curvature_matrix - - return np.add(self.curvature_matrix, self.regularization_matrix) + return jnp.add(self.curvature_matrix, self.regularization_matrix) @cached_property - def curvature_reg_matrix_reduced(self) -> np.ndarray: + def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]: """ - The linear system of equations solves for F + regularization_coefficient*H, which is computed below. + The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the + linear algebra system we solve for using D and F above and is given by + equation (12) in https://arxiv.org/pdf/astro-ph/0302587.pdf. - This is the curvature reg matrix for only the mappers, which is necessary for computing the log det - term without the linear light profiles included. + A complete description of regularization is given in the `regularization.py` and `regularization_util.py` + modules. + + For multiple mappers, the regularization matrix is computed as the block diagonal of each individual mapper. + The scipy function `block_diag` has an overhead associated with it and if there is only one mapper and + regularization it is bypassed. """ + if self.all_linear_obj_have_regularization: return self.curvature_reg_matrix - curvature_reg_matrix = self.curvature_reg_matrix + # ids of values which are on edge so zero-d and not solved for. + ids_to_keep = self.mapper_indices - curvature_reg_matrix = np.delete( - curvature_reg_matrix, self.no_regularization_index_list, 0 - ) - curvature_reg_matrix = np.delete( - curvature_reg_matrix, self.no_regularization_index_list, 1 - ) - - return curvature_reg_matrix - - @property - def mapper_zero_pixel_list(self) -> np.ndarray: - mapper_zero_pixel_list = [] - param_range_list = self.param_range_list_from(cls=LinearObj) - for param_range, linear_obj in zip(param_range_list, self.linear_obj_list): - if isinstance(linear_obj, AbstractMapper): - mapping_matrix_for_image_pixels_source_zero = linear_obj.mapping_matrix[ - self.settings.image_pixels_source_zero - ] - source_pixels_zero = ( - np.sum(mapping_matrix_for_image_pixels_source_zero != 0, axis=0) - != 0 - ) - mapper_zero_pixel_list.append( - np.where(source_pixels_zero == True)[0] + param_range[0] - ) - return mapper_zero_pixel_list + # Zero rows and columns in the matrix we want to ignore + return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep] @cached_property def reconstruction(self) -> np.ndarray: @@ -448,51 +420,36 @@ def reconstruction(self) -> np.ndarray: ZTx := np.dot(Z.T, x) """ if self.settings.use_positive_only_solver: - """ - For the new implementation, we now need to take out the cols and rows of - the curvature_reg_matrix that corresponds to the parameters we force to be 0. - Similar for the data vector. - What we actually doing is that we have set the correspoding cols of the Z to be 0. - As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out. - And the data_vector = ZTx, so the corresponding row is also taken out. - """ + if self.preloads.source_pixel_zeroed_indices is not None: - if ( - self.has(cls=AbstractMapper) - and self.settings.force_edge_pixels_to_zeros - ): + # ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads. + ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep - ids_zeros = jnp.array(self.mapper_edge_pixel_list, dtype=int) + # Use advanced indexing to select rows/columns + data_vector = self.data_vector[ids_to_keep] + curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][ + :, ids_to_keep + ] - values_to_solve = jnp.ones( - self.curvature_reg_matrix.shape[0], dtype=bool + # Perform reconstruction via fnnls + reconstruction_partial = ( + inversion_util.reconstruction_positive_only_from( + data_vector=data_vector, + curvature_reg_matrix=curvature_reg_matrix, + settings=self.settings, + ) ) - values_to_solve = values_to_solve.at[ids_zeros].set(False) - - data_vector_input = self.data_vector[values_to_solve] - curvature_reg_matrix_input = self.curvature_reg_matrix[ - values_to_solve, : - ][:, values_to_solve] + # Allocate full solution array + reconstruction = jnp.zeros(self.data_vector.shape[0]) - # Get the values to assign (must be a JAX array) - reconstruction = inversion_util.reconstruction_positive_only_from( - data_vector=data_vector_input, - curvature_reg_matrix=curvature_reg_matrix_input, - settings=self.settings, + # Scatter the partial solution back to the full shape + reconstruction = reconstruction.at[ids_to_keep].set( + reconstruction_partial ) - # Allocate JAX array - solutions = jnp.zeros(self.curvature_reg_matrix.shape[0]) - - # Get indices where True - indices = jnp.where(values_to_solve)[0] - - # Set reconstruction values at those indices - solutions = solutions.at[indices].set(reconstruction) - - return solutions + return reconstruction else: @@ -522,7 +479,11 @@ def reconstruction_reduced(self) -> np.ndarray: if self.all_linear_obj_have_regularization: return self.reconstruction - return np.delete(self.reconstruction, self.no_regularization_index_list, axis=0) + # ids of values which are on edge so zero-d and not solved for. + ids_to_keep = self.mapper_indices + + # Zero rows and columns in the matrix we want to ignore + return self.reconstruction[ids_to_keep] @property def reconstruction_dict(self) -> Dict[LinearObj, np.ndarray]: @@ -665,9 +626,9 @@ def regularization_term(self) -> float: if not self.has(cls=AbstractRegularization): return 0.0 - return np.matmul( + return jnp.matmul( self.reconstruction_reduced.T, - np.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced), + jnp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced), ) @cached_property @@ -682,7 +643,9 @@ def log_det_curvature_reg_matrix_term(self) -> float: try: return 2.0 * np.sum( - np.log(np.diag(np.linalg.cholesky(self.curvature_reg_matrix_reduced))) + jnp.log( + jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced)) + ) ) except np.linalg.LinAlgError as e: raise exc.InversionException() from e diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index 1e14d1e10..b7c9016b1 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -14,6 +14,7 @@ from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D @@ -21,6 +22,7 @@ def inversion_from( dataset: Union[Imaging, Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), + preloads: Preloads = None, ): """ Factory which given an input dataset and list of linear objects, creates an `Inversion`. @@ -55,6 +57,7 @@ def inversion_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, + preloads=preloads, ) return inversion_interferometer_from( @@ -68,6 +71,7 @@ def inversion_imaging_from( dataset, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), + preloads: Preloads = None, ): """ Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`. @@ -126,6 +130,7 @@ def inversion_imaging_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, + preloads=preloads, ) diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 4d785abed..8d85a5db1 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Dict, List, Optional, Union, Type +from typing import Dict, List, Union, Type from autoconf import cached_property @@ -10,6 +10,7 @@ from autoarray.inversion.inversion.abstract import AbstractInversion from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.inversion.inversion.imaging import inversion_imaging_util @@ -20,6 +21,7 @@ def __init__( dataset: Union[Imaging, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), + preloads: Preloads = None, ): """ An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions @@ -66,6 +68,7 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, + preloads=preloads, ) @property diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 60fd54a44..698750a22 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -9,6 +9,7 @@ from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.inversion.inversion import inversion_util @@ -21,6 +22,7 @@ def __init__( dataset: Union[Imaging, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), + preloads: Preloads = None, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -46,6 +48,7 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, + preloads=preloads, ) @property diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index b1b39472c..b725adb6c 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -1,4 +1,3 @@ -import copy import numpy as np from typing import Dict, List, Optional, Union @@ -14,6 +13,7 @@ from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.structures.arrays.uniform_2d import Array2D +from autoarray import exc from autoarray.inversion.inversion import inversion_util from autoarray.inversion.inversion.imaging import inversion_imaging_util @@ -49,6 +49,17 @@ def __init__( the simultaneous linear equations are combined and solved simultaneously. """ + try: + import numba + except ModuleNotFoundError: + raise exc.InversionException( + "Inversion functionality (linear light profiles, pixelized reconstructions) is " + "disabled if numba is not installed.\n\n" + "This is because the run-times without numba are too slow.\n\n" + "Please install numba, which is described at the following web page:\n\n" + "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" + ) + super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, @@ -94,9 +105,11 @@ def _data_vector_mapper(self) -> np.ndarray: data_vector_mapper = ( inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( w_tilde_data=self.w_tilde_data, - data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique, - data_weights=mapper.unique_mappings.data_weights, - pix_lengths=mapper.unique_mappings.pix_lengths, + data_to_pix_unique=np.array( + mapper.unique_mappings.data_to_pix_unique + ), + data_weights=np.array(mapper.unique_mappings.data_weights), + pix_lengths=np.array(mapper.unique_mappings.pix_lengths), pix_pixels=mapper.params, ) ) @@ -276,9 +289,11 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: curvature_preload=self.w_tilde.curvature_preload, curvature_indexes=self.w_tilde.indexes, curvature_lengths=self.w_tilde.lengths, - data_to_pix_unique=mapper_i.unique_mappings.data_to_pix_unique, - data_weights=mapper_i.unique_mappings.data_weights, - pix_lengths=mapper_i.unique_mappings.pix_lengths, + data_to_pix_unique=np.array( + mapper_i.unique_mappings.data_to_pix_unique + ), + data_weights=np.array(mapper_i.unique_mappings.data_weights), + pix_lengths=np.array(mapper_i.unique_mappings.pix_lengths), pix_pixels=mapper_i.params, ) diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index 6096f71bc..120f1c31b 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -1833,3 +1833,42 @@ def curvature_matrix_via_w_tilde_curvature_preload_interferometer_para_from( ) print("finished 3rd loop.") return curvature_matrix + + +@numba_util.jit() +def sub_slim_indexes_for_pix_index( + pix_indexes_for_sub_slim_index: np.ndarray, + pix_weights_for_sub_slim_index: np.ndarray, + pix_pixels: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) + + for pix_indexes in pix_indexes_for_sub_slim_index: + for pix_index in pix_indexes: + sub_slim_sizes_for_pix_index[pix_index] += 1 + + max_pix_size = np.max(sub_slim_sizes_for_pix_index) + + sub_slim_indexes_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) + sub_slim_weights_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) + sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) + + for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index): + pix_weights = pix_weights_for_sub_slim_index[slim_index] + + for pix_index, pix_weight in zip(pix_indexes, pix_weights): + sub_slim_indexes_for_pix_index[ + pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) + ] = slim_index + + sub_slim_weights_for_pix_index[ + pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) + ] = pix_weight + + sub_slim_sizes_for_pix_index[pix_index] += 1 + + return ( + sub_slim_indexes_for_pix_index, + sub_slim_sizes_for_pix_index, + sub_slim_weights_for_pix_index, + ) diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/w_tilde.py index bac21b883..8a3656fa2 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/w_tilde.py @@ -130,7 +130,11 @@ def curvature_matrix_diag(self) -> np.ndarray: sub_slim_indexes_for_pix_index, sub_slim_sizes_for_pix_index, sub_slim_weights_for_pix_index, - ) = mapper.sub_slim_indexes_for_pix_index_arr + ) = inversion_interferometer_util.sub_slim_indexes_for_pix_index( + pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, + pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, + pix_pixels=mapper.pixels, + ) return inversion_interferometer_util.curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_2( curvature_preload=self.w_tilde.curvature_preload, diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index e6f0d766f..457723957 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -2,7 +2,7 @@ import jax.lax as lax import numpy as np -from typing import List, Optional +from typing import List, Optional, Type from autoconf import conf @@ -346,3 +346,48 @@ def preconditioner_matrix_via_mapping_matrix_from( return ( preconditioner_noise_normalization * curvature_matrix ) + regularization_matrix + + +def param_range_list_from(cls: Type, linear_obj_list) -> List[List[int]]: + """ + Each linear object in the `Inversion` has N parameters, and these parameters correspond to a certain range + of indexing values in the matrices used to perform the inversion. + + This function returns the `param_range_list` of an input type of linear object, which gives the indexing range + of each linear object of the input type. + + For example, if an `Inversion` has: + + - A `LinearFuncList` linear object with 3 `params`. + - A `Mapper` with 100 `params`. + - A `Mapper` with 200 `params`. + + The corresponding matrices of this inversion (e.g. the `curvature_matrix`) have `shape=(303, 303)` where: + + - The `LinearFuncList` values are in the entries `[0:3]`. + - The first `Mapper` values are in the entries `[3:103]`. + - The second `Mapper` values are in the entries `[103:303] + + For this example, `param_range_list_from(cls=AbstractMapper)` therefore returns the + list `[[3, 103], [103, 303]]`. + + Parameters + ---------- + cls + The type of class that the list of their parameter range index values are returned for. + + Returns + ------- + A list of the index range of the parameters of each linear object in the inversion of the input cls type. + """ + index_list = [] + + pixel_count = 0 + + for linear_obj in linear_obj_list: + if isinstance(linear_obj, cls): + index_list.append([pixel_count, pixel_count + linear_obj.params]) + + pixel_count += linear_obj.params + + return index_list diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 184e16977..3deab4a6e 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -10,12 +10,11 @@ class SettingsInversion: def __init__( self, - use_w_tilde: bool = True, + use_w_tilde: bool = False, use_positive_only_solver: Optional[bool] = None, positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, force_edge_pixels_to_zeros: bool = True, - image_pixels_source_zero=None, no_regularization_add_to_curvature_diag_value: float = None, use_w_tilde_numpy: bool = False, use_source_loop: bool = False, @@ -83,7 +82,6 @@ def __init__( self._use_border_relocator = use_border_relocator self.use_linear_operators = use_linear_operators self.force_edge_pixels_to_zeros = force_edge_pixels_to_zeros - self.image_pixels_source_zero = image_pixels_source_zero self._no_regularization_add_to_curvature_diag_value = ( no_regularization_add_to_curvature_diag_value ) diff --git a/autoarray/inversion/pixelization/border_relocator.py b/autoarray/inversion/pixelization/border_relocator.py index 88ec78091..e46cbef1e 100644 --- a/autoarray/inversion/pixelization/border_relocator.py +++ b/autoarray/inversion/pixelization/border_relocator.py @@ -1,8 +1,7 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np -from typing import Union - -from autoconf import cached_property +from typing import Tuple, Union from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D @@ -54,7 +53,7 @@ def sub_slim_indexes_for_slim_index_via_mask_2d_from( slim_index_for_sub_slim_indexes = ( over_sample_util.slim_index_for_sub_slim_index_via_mask_2d_from( - mask_2d=mask_2d, sub_size=np.array(sub_size) + mask_2d=mask_2d, sub_size=sub_size ).astype("int") ) @@ -64,6 +63,43 @@ def sub_slim_indexes_for_slim_index_via_mask_2d_from( return sub_slim_indexes_for_slim_index +def furthest_grid_2d_slim_index_from( + grid_2d_slim: np.ndarray, slim_indexes: np.ndarray, coordinate: Tuple[float, float] +) -> int: + """ + Returns the index in `slim_indexes` corresponding to the 2D point in `grid_2d_slim` + that is furthest from a given coordinate, measured by squared Euclidean distance. + + Parameters + ---------- + grid_2d_slim + A 2D array of shape (N, 2), where each row is a (y, x) coordinate. + slim_indexes + An array of indices into `grid_2d_slim` specifying which coordinates to consider. + coordinate + The (y, x) coordinate from which distances are calculated. + + Returns + ------- + int + The slim index of the point in `grid_2d_slim[slim_indexes]` that is furthest from `coordinate`. + """ + subgrid = grid_2d_slim[slim_indexes] + dy = subgrid[:, 0] - coordinate[0] + dx = subgrid[:, 1] - coordinate[1] + squared_distances = dx**2 + dy**2 + + max_dist = np.max(squared_distances) + + # Find all indices with max distance + max_positions = np.where(squared_distances == max_dist)[0] + + # Choose the last one (to match original loop behavior) + max_index = max_positions[-1] + + return slim_indexes[max_index] + + def sub_border_pixel_slim_indexes_from( mask_2d: np.ndarray, sub_size: Array2D ) -> np.ndarray: @@ -107,7 +143,7 @@ def sub_border_pixel_slim_indexes_from( sub_grid_2d_slim = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=mask_2d, pixel_scales=(1.0, 1.0), - sub_size=np.array(sub_size), + sub_size=sub_size, origin=(0.0, 0.0), ) mask_centre = grid_2d_util.grid_2d_centre_from(grid_2d_slim=sub_grid_2d_slim) @@ -117,129 +153,176 @@ def sub_border_pixel_slim_indexes_from( int(border_pixel) ] - sub_border_pixels[border_1d_index] = ( - grid_2d_util.furthest_grid_2d_slim_index_from( - grid_2d_slim=sub_grid_2d_slim, - slim_indexes=sub_border_pixels_of_border_pixel, - coordinate=mask_centre, - ) + sub_border_pixels[border_1d_index] = furthest_grid_2d_slim_index_from( + grid_2d_slim=sub_grid_2d_slim, + slim_indexes=sub_border_pixels_of_border_pixel, + coordinate=mask_centre, ) return sub_border_pixels -class BorderRelocator: - def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): - self.mask = mask +def sub_border_slim_from(mask, sub_size): + """ + Returns the subgridded 1D ``slim`` indexes of border pixels in the ``Mask2D``, representing all unmasked + sub-pixels (given by ``False``) which neighbor any masked value (give by ``True``) and which are on the + extreme exterior of the mask. - self.sub_size = over_sample_util.over_sample_size_convert_to_array_2d_from( - over_sample_size=sub_size, mask=mask + The indexes are the sub-gridded extension of the ``border_slim`` which is illustrated above. + + This quantity is too complicated to write-out in a docstring, and it is recommended you print it in + Python code to understand it if anything is unclear. + + Examples + -------- + + .. code-block:: python + + import autoarray as aa + + mask_2d = aa.Mask2D( + mask=[[True, True, True, True, True, True, True, True, True], + [True, False, False, False, False, False, False, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, False, True, False, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, False, False, False, False, False, False, True], + [True, True, True, True, True, True, True, True, True]] + pixel_scales=1.0, ) - @cached_property - def border_slim(self): - """ - Returns the 1D ``slim`` indexes of border pixels in the ``Mask2D``, representing all unmasked - sub-pixels (given by ``False``) which neighbor any masked value (give by ``True``) and which are on the - extreme exterior of the mask. + derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) - The indexes are the extended below to form the ``sub_border_slim`` which is illustrated above. + print(derive_indexes_2d.sub_border_slim) + """ + return sub_border_pixel_slim_indexes_from( + mask_2d=mask, sub_size=sub_size.astype("int") + ).astype("int") - This quantity is too complicated to write-out in a docstring, and it is recommended you print it in - Python code to understand it if anything is unclear. - Examples - -------- +def relocated_grid_from(grid, border_grid): + """ + Relocate the coordinates of a grid to its border if they are outside the border, where the border is + defined as all pixels at the edge of the grid's mask (see *mask._border_1d_indexes*). - .. code-block:: python + This is performed as follows: - import autoarray as aa + 1: Use the mean value of the grid's y and x coordinates to determine the origin of the grid. + 2: Compute the radial distance of every grid coordinate from the origin. + 3: For every coordinate, find its nearest pixel in the border. + 4: Determine if it is outside the border, by comparing its radial distance from the origin to its paired + border pixel's radial distance. + 5: If its radial distance is larger, use the ratio of radial distances to move the coordinate to the + border (if its inside the border, do nothing). - mask_2d = aa.Mask2D( - mask=[[True, True, True, True, True, True, True, True, True], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True]] - pixel_scales=1.0, - ) + The method can be used on uniform or irregular grids, however for irregular grids the border of the + 'image-plane' mask is used to define border pixels. - derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) + Parameters + ---------- + grid + The grid (uniform or irregular) whose pixels are to be relocated to the border edge if outside it. + border_grid : Grid2D + The grid of border (y,x) coordinates. + """ - print(derive_indexes_2d.border_slim) - """ - return self.mask.derive_indexes.border_slim + # Compute origin (center) of the border grid + border_origin = jnp.mean(border_grid, axis=0) - @cached_property - def sub_border_slim(self) -> np.ndarray: - """ - Returns the subgridded 1D ``slim`` indexes of border pixels in the ``Mask2D``, representing all unmasked - sub-pixels (given by ``False``) which neighbor any masked value (give by ``True``) and which are on the - extreme exterior of the mask. + # Radii from origin + grid_radii = jnp.linalg.norm(grid - border_origin, axis=1) # (N,) + border_radii = jnp.linalg.norm(border_grid - border_origin, axis=1) # (M,) + border_min_radius = jnp.min(border_radii) - The indexes are the sub-gridded extension of the ``border_slim`` which is illustrated above. + # Determine which points are outside + outside_mask = grid_radii > border_min_radius # (N,) - This quantity is too complicated to write-out in a docstring, and it is recommended you print it in - Python code to understand it if anything is unclear. + # To compute nearest border point for each grid point, we must do it for all and then mask later + # Compute all distances: (N, M) + diffs = grid[:, None, :] - border_grid[None, :, :] # (N, M, 2) + dists_squared = jnp.sum(diffs**2, axis=2) # (N, M) + closest_indices = jnp.argmin(dists_squared, axis=1) # (N,) - Examples - -------- + # Get border radius for closest border point to each grid point + matched_border_radii = border_radii[closest_indices] # (N,) - .. code-block:: python + # Ratio of border to grid radius + move_factors = matched_border_radii / grid_radii # (N,) - import autoarray as aa + # Only move if: + # - the point is outside the border + # - the matched border point is closer to the origin (i.e. move_factor < 1) + apply_move = jnp.logical_and(outside_mask, move_factors < 1.0) # (N,) - mask_2d = aa.Mask2D( - mask=[[True, True, True, True, True, True, True, True, True], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True]] - pixel_scales=1.0, - ) + # Compute moved positions (for all points, but will select with mask) + direction_vectors = grid - border_origin # (N, 2) + moved_grid = move_factors[:, None] * direction_vectors + border_origin # (N, 2) - derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) + # Select which grid points to move + relocated_grid = jnp.where(apply_move[:, None], moved_grid, grid) # (N, 2) - print(derive_indexes_2d.sub_border_slim) - """ - return sub_border_pixel_slim_indexes_from( - mask_2d=np.array(self.mask), sub_size=np.array(self.sub_size).astype("int") - ).astype("int") + return relocated_grid - @cached_property - def border_grid(self) -> np.ndarray: - """ - The (y,x) grid of all sub-pixels which are at the border of the mask. - This is NOT all sub-pixels which are in mask pixels at the mask's border, but specifically the sub-pixels - within these border pixels which are at the extreme edge of the border. +class BorderRelocator: + def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): """ - return self.mask.derive_grid.border + Relocates source plane coordinates that trace outside the mask’s border in the source-plane back onto the + border. - @cached_property - def sub_border_grid(self) -> np.ndarray: - """ - The (y,x) grid of all sub-pixels which are at the border of the mask. + Given an input mask and (optionally) a per‐pixel sub‐sampling size, this class computes: + + 1. `border_grid`: the (y,x) coordinates of every border pixel of the mask. + 2. `sub_border_grid`: an over‐sampled border grid if sub‐sampling is requested. + 3. `relocated_grid(grid)`: for any arbitrary grid of points (uniform or irregular), returns a new grid + where any point whose radius from the mask center exceeds the minimum radius of the border is + moved radially inward until it lies exactly on its nearest border pixel. + + In practice this ensures that “outlier” rays or source‐plane pixels don’t fall outside the allowed + mask region when performing pixelization–based inversions or lens‐plane mappings. + + See Figure 2 of https://arxiv.org/abs/1708.07377 for a description of why this functionality is required. - This is NOT all sub-pixels which are in mask pixels at the mask's border, but specifically the sub-pixels - within these border pixels which are at the extreme edge of the border. + Attributes + ---------- + mask : Mask2D + The input mask whose border defines the permissible region. + sub_size : Array2D + Per‐pixel sub‐sampling size (can be constant or spatially varying). + border_slim : np.ndarray + 1D indexes of the mask’s border pixels in the slimmed representation. + sub_border_slim : np.ndarray + 1D indexes of the over‐sampled (sub) border pixels. + border_grid : np.ndarray + Array of (y,x) coordinates for each border pixel. + sub_border_grid : np.ndarray + Array of (y,x) coordinates for each over‐sampled border pixel. """ + self.mask = mask + + self.sub_size = over_sample_util.over_sample_size_convert_to_array_2d_from( + over_sample_size=sub_size, mask=mask + ) + + self.border_slim = self.mask.derive_indexes.border_slim + self.sub_border_slim = sub_border_slim_from( + mask=self.mask, sub_size=self.sub_size + ) + try: + self.border_grid = self.mask.derive_grid.border + except TypeError: + self.border_grid = None + sub_grid = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( - mask_2d=np.array(self.mask), + mask_2d=self.mask, pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.sub_size).astype("int"), + sub_size=self.sub_size.astype("int"), origin=self.mask.origin, ) - return sub_grid[self.sub_border_slim] + self.sub_border_grid = sub_grid[self.sub_border_slim] def relocated_grid_from(self, grid: Grid2D) -> Grid2D: """ @@ -268,14 +351,14 @@ def relocated_grid_from(self, grid: Grid2D) -> Grid2D: if len(self.sub_border_grid) == 0: return grid - values = grid_2d_util.relocated_grid_via_jit_from( - grid=np.array(grid.array), - border_grid=np.array(grid.array[self.border_slim]), + values = relocated_grid_from( + grid=grid.array, + border_grid=grid.array[self.border_slim], ) - over_sampled = grid_2d_util.relocated_grid_via_jit_from( - grid=np.array(grid.over_sampled.array), - border_grid=np.array(grid.over_sampled.array[self.sub_border_slim]), + over_sampled = relocated_grid_from( + grid=grid.over_sampled.array, + border_grid=grid.over_sampled.array[self.sub_border_slim], ) return Grid2D( @@ -302,8 +385,8 @@ def relocated_mesh_grid_from( return mesh_grid return Grid2DIrregular( - values=grid_2d_util.relocated_grid_via_jit_from( - grid=np.array(mesh_grid.array), - border_grid=np.array(grid[self.sub_border_slim]), + values=relocated_grid_from( + grid=mesh_grid.array, + border_grid=grid[self.sub_border_slim], ), ) diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index 291f5bed6..f2f6b03b6 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -207,28 +207,6 @@ def sub_slim_indexes_for_pix_index(self) -> List[List]: return sub_slim_indexes_for_pix_index - @property - def sub_slim_indexes_for_pix_index_arr( - self, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Returns the index mappings between each of the pixelization's pixels and the masked data's sub-pixels. - - Given that even pixelization pixel maps to multiple data sub-pixels, index mappings are returned as a list of - lists where the first entries are the pixelization index and second entries store the data sub-pixel indexes. - - For example, if `sub_slim_indexes_for_pix_index[2][4] = 10`, the pixelization pixel with index 2 - (e.g. `mesh_grid[2,:]`) has a mapping to a data sub-pixel with index 10 (e.g. `grid_slim[10, :]). - - This is effectively a reversal of the array `pix_indexes_for_sub_slim_index`. - """ - - return mapper_util.sub_slim_indexes_for_pix_index( - pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=self.pix_weights_for_sub_slim_index, - pix_pixels=self.pixels, - ) - @cached_property def unique_mappings(self) -> UniqueMappings: """ @@ -249,9 +227,13 @@ def unique_mappings(self) -> UniqueMappings: pix_lengths, ) = mapper_util.data_slim_to_pixelization_unique_from( data_pixels=self.over_sampler.mask.pixels_in_mask, - pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=self.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + self.pix_indexes_for_sub_slim_index + ), + pix_sizes_for_sub_slim_index=np.array(self.pix_sizes_for_sub_slim_index), + pix_weights_for_sub_slim_index=np.array( + self.pix_weights_for_sub_slim_index + ), pix_pixels=self.params, sub_size=np.array(self.over_sampler.sub_size).astype("int"), ) @@ -275,6 +257,7 @@ def mapping_matrix(self) -> np.ndarray: It is described in the following paper as matrix `f` https://arxiv.org/pdf/astro-ph/0302587.pdf and in more detail in the function `mapper_util.mapping_matrix_from()`. """ + return mapper_util.mapping_matrix_from( pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, @@ -282,7 +265,7 @@ def mapping_matrix(self) -> np.ndarray: pixels=self.pixels, total_mask_pixels=self.over_sampler.mask.pixels_in_mask, slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index, - sub_fraction=np.array(self.over_sampler.sub_fraction), + sub_fraction=self.over_sampler.sub_fraction.array, ) def pixel_signals_from(self, signal_scale: float) -> np.ndarray: @@ -355,8 +338,12 @@ def data_weight_total_for_pix_from(self) -> np.ndarray: """ return mapper_util.data_weight_total_for_pix_from( - pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=self.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + self.pix_indexes_for_sub_slim_index + ), + pix_weights_for_sub_slim_index=np.array( + self.pix_weights_for_sub_slim_index + ), pixels=self.pixels, ) @@ -379,8 +366,8 @@ def mapped_to_source_from(self, array: Array2D) -> np.ndarray: source domain in order to compute their average values. """ return mapper_util.mapped_to_source_via_mapping_matrix_from( - mapping_matrix=self.mapping_matrix, - array_slim=np.array(array.slim), + mapping_matrix=np.array(self.mapping_matrix), + array_slim=array.slim, ) def extent_from( diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index 75b91c042..c3e8d470a 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -9,45 +9,6 @@ from autoarray.inversion.pixelization.mesh import mesh_util -@numba_util.jit() -def sub_slim_indexes_for_pix_index( - pix_indexes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index: np.ndarray, - pix_pixels: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) - - for pix_indexes in pix_indexes_for_sub_slim_index: - for pix_index in pix_indexes: - sub_slim_sizes_for_pix_index[pix_index] += 1 - - max_pix_size = np.max(sub_slim_sizes_for_pix_index) - - sub_slim_indexes_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) - sub_slim_weights_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) - sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) - - for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index): - pix_weights = pix_weights_for_sub_slim_index[slim_index] - - for pix_index, pix_weight in zip(pix_indexes, pix_weights): - sub_slim_indexes_for_pix_index[ - pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) - ] = slim_index - - sub_slim_weights_for_pix_index[ - pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) - ] = pix_weight - - sub_slim_sizes_for_pix_index[pix_index] += 1 - - return ( - sub_slim_indexes_for_pix_index, - sub_slim_sizes_for_pix_index, - sub_slim_weights_for_pix_index, - ) - - @numba_util.jit() def data_slim_to_pixelization_unique_from( data_pixels, @@ -498,7 +459,6 @@ def remove_bad_entries_voronoi_nn( return pix_weights_for_sub_slim_index, pix_indexes_for_sub_slim_index -@numba_util.jit() def adaptive_pixel_signals_from( pixels: int, pixel_weights: np.ndarray, @@ -536,33 +496,47 @@ def adaptive_pixel_signals_from( The image of the galaxy which is used to compute the weigghted pixel signals. """ - pixel_signals = np.zeros((pixels,)) - pixel_sizes = np.zeros((pixels,)) + M_sub, B = pix_indexes_for_sub_slim_index.shape - for sub_slim_index in range(len(pix_indexes_for_sub_slim_index)): - vertices_indexes = pix_indexes_for_sub_slim_index[sub_slim_index] + # 1) Flatten the per‐mapping tables: + flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,) + flat_weights = pixel_weights.reshape(-1) # (M_sub*B,) - mask_1d_index = slim_index_for_sub_slim_index[sub_slim_index] + # 2) Build a matching “parent‐slim” index for each flattened entry: + I_sub = jnp.repeat(jnp.arange(M_sub), B) # (M_sub*B,) - pix_size_tem = pix_size_for_sub_slim_index[sub_slim_index] + # 3) Mask out any k >= pix_size_for_sub_slim_index[i] + valid = I_sub < 0 # dummy to get shape + # better: + valid = (jnp.arange(B)[None, :] < pix_size_for_sub_slim_index[:, None]).reshape(-1) - if pix_size_tem > 1: - pixel_signals[vertices_indexes[:pix_size_tem]] += ( - adapt_data[mask_1d_index] * pixel_weights[sub_slim_index] - ) - pixel_sizes[vertices_indexes] += 1 - else: - pixel_signals[vertices_indexes[0]] += adapt_data[mask_1d_index] - pixel_sizes[vertices_indexes[0]] += 1 + flat_weights = jnp.where(valid, flat_weights, 0.0) + flat_pixidx = jnp.where( + valid, flat_pixidx, pixels + ) # send invalid indices to an out-of-bounds slot + + # 4) Look up data & multiply by mapping weights: + flat_data_vals = adapt_data[slim_index_for_sub_slim_index][I_sub] # (M_sub*B,) + flat_contrib = flat_data_vals * flat_weights # (M_sub*B,) - pixel_sizes[pixel_sizes == 0] = 1 - pixel_signals /= pixel_sizes - pixel_signals /= np.max(pixel_signals) + # 5) Scatter‐add into signal sums and counts: + pixel_signals = jnp.zeros((pixels + 1,)).at[flat_pixidx].add(flat_contrib) + pixel_counts = jnp.zeros((pixels + 1,)).at[flat_pixidx].add(valid.astype(float)) + # 6) Drop the extra “out-of-bounds” slot: + pixel_signals = pixel_signals[:pixels] + pixel_counts = pixel_counts[:pixels] + + # 7) Normalize + pixel_counts = jnp.where(pixel_counts > 0, pixel_counts, 1.0) + pixel_signals = pixel_signals / pixel_counts + max_sig = jnp.max(pixel_signals) + pixel_signals = jnp.where(max_sig > 0, pixel_signals / max_sig, pixel_signals) + + # 8) Exponentiate return pixel_signals**signal_scale -@numba_util.jit() def mapping_matrix_from( pix_indexes_for_sub_slim_index: np.ndarray, pix_size_for_sub_slim_index: np.ndarray, @@ -643,87 +617,110 @@ def mapping_matrix_from( sub_fraction The fractional area each sub-pixel takes up in an pixel. """ + M_sub, B = pix_indexes_for_sub_slim_index.shape + M = total_mask_pixels + S = pixels - mapping_matrix = np.zeros((total_mask_pixels, pixels)) + # 1) Flatten + flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,) + flat_w = pix_weights_for_sub_slim_index.reshape(-1) # (M_sub*B,) + flat_parent = jnp.repeat(slim_index_for_sub_slim_index, B) # (M_sub*B,) + flat_count = jnp.repeat(pix_size_for_sub_slim_index, B) # (M_sub*B,) - for sub_slim_index in range(slim_index_for_sub_slim_index.shape[0]): - slim_index = slim_index_for_sub_slim_index[sub_slim_index] + # 2) Build valid mask: k < pix_size[i] + k = jnp.tile(jnp.arange(B), M_sub) # (M_sub*B,) + valid = k < flat_count # (M_sub*B,) - for pix_count in range(pix_size_for_sub_slim_index[sub_slim_index]): - pix_index = pix_indexes_for_sub_slim_index[sub_slim_index, pix_count] - pix_weight = pix_weights_for_sub_slim_index[sub_slim_index, pix_count] + # 3) Zero out invalid weights + flat_w = flat_w * valid.astype(flat_w.dtype) - mapping_matrix[slim_index][pix_index] += ( - sub_fraction[slim_index] * pix_weight - ) + # 4) Redirect -1 indices to extra bin S + OUT = S + flat_pixidx = jnp.where(flat_pixidx < 0, OUT, flat_pixidx) - return mapping_matrix + # 5) Multiply by sub_fraction of the slim row + flat_frac = sub_fraction[flat_parent] # (M_sub*B,) + flat_contrib = flat_w * flat_frac # (M_sub*B,) + + # 6) Scatter into (M × (S+1)), summing duplicates + mat = jnp.zeros((M, S + 1), dtype=flat_contrib.dtype) + mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib) + + # 7) Drop the extra column and return + return mat[:, :S] -@numba_util.jit() def mapped_to_source_via_mapping_matrix_from( mapping_matrix: np.ndarray, array_slim: np.ndarray ) -> np.ndarray: """ - Map a masked 2d image in the image domain to the source domain and sum up all mappings on the source-pixels. - - For example, suppose we have an image and a mapper. We can map every image-pixel to its corresponding mapper's - source pixel and sum the values based on these mappings. + Map a masked 2D image (in slim form) into the source plane by summing and averaging + each image-pixel's contribution to its mapped source-pixels. - This will produce something similar to a `reconstruction`, albeit it bypasses the linear algebra / inversion. + Each row i of `mapping_matrix` describes how image-pixel i is distributed (with + weights) across the source-pixels j. `array_slim[i]` is then multiplied by those + weights and summed over i to give each source-pixel’s total mapped value; finally, + we divide by the number of nonzero contributions to form an average. Parameters ---------- - mapping_matrix - The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. - array_slim - The masked 2D array of values in its slim representation (e.g. the image data) which are mapped to the - source domain in order to compute their average values. - """ - - mapped_to_source = np.zeros(mapping_matrix.shape[1]) + mapping_matrix : ndarray of shape (M, N) + mapping_matrix[i, j] ≥ 0 is the weight by which image-pixel i contributes to + source-pixel j. Zero means “no contribution.” + array_slim : ndarray of shape (M,) + The slimmed image values for each image-pixel i. - source_pixel_count = np.zeros(mapping_matrix.shape[1]) + Returns + ------- + mapped_to_source : ndarray of shape (N,) + The averaged, mapped values on each of the N source-pixels. + """ + # weighted sums: sum over i of array_slim[i] * mapping_matrix[i, j] + # ==> vector‐matrix multiply: (1×M) dot (M×N) → (N,) + mapped_to_source = array_slim @ mapping_matrix - for i in range(mapping_matrix.shape[0]): - for j in range(mapping_matrix.shape[1]): - if mapping_matrix[i, j] > 0: - mapped_to_source[j] += array_slim[i] * mapping_matrix[i, j] - source_pixel_count[j] += 1 + # count how many nonzero contributions each source-pixel j received + counts = np.count_nonzero(mapping_matrix > 0.0, axis=0) - for j in range(mapping_matrix.shape[1]): - if source_pixel_count[j] > 0: - mapped_to_source[j] /= source_pixel_count[j] + # avoid division by zero: only divide where counts > 0 + nonzero = counts > 0 + mapped_to_source[nonzero] /= counts[nonzero] return mapped_to_source -@numba_util.jit() def data_weight_total_for_pix_from( - pix_indexes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index: np.ndarray, + pix_indexes_for_sub_slim_index: np.ndarray, # shape (M, B) + pix_weights_for_sub_slim_index: np.ndarray, # shape (M, B) pixels: int, ) -> np.ndarray: """ - Returns the total weight of every pixelization pixel, which is the sum of the weights of all data-points that - map to that pixel. + Returns the total weight of every pixelization pixel, which is the sum of + the weights of all data‐points (sub‐pixels) that map to that pixel. Parameters ---------- - pix_indexes_for_sub_slim_index - The mappings from a data sub-pixel index to a pixelization pixel index. - pix_weights_for_sub_slim_index - The weights of the mappings of every data sub-pixel and pixelization pixel. - pixels - The number of pixels in the pixelization. - """ + pix_indexes_for_sub_slim_index : np.ndarray, shape (M, B), int + For each of M sub‐slim indexes, the B pixelization‐pixel indices it maps to. + pix_weights_for_sub_slim_index : np.ndarray, shape (M, B), float + For each of those mappings, the corresponding interpolation weight. + pixels : int + The total number of pixelization pixels N. - pix_weight_total = np.zeros(pixels) + Returns + ------- + np.ndarray, shape (N,) + The per‐pixel total weight: for each j in [0..N-1], the sum of all + pix_weights_for_sub_slim_index[i,k] such that pix_indexes_for_sub_slim_index[i,k] == j. + """ + # Flatten arrays + flat_idxs = pix_indexes_for_sub_slim_index.ravel() + flat_weights = pix_weights_for_sub_slim_index.ravel() - for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index): - for pix_index, weight in zip( - pix_indexes, pix_weights_for_sub_slim_index[slim_index] - ): - pix_weight_total[int(pix_index)] += weight + # Filter out -1 (invalid mappings) + valid_mask = flat_idxs >= 0 + flat_idxs = flat_idxs[valid_mask] + flat_weights = flat_weights[valid_mask] - return pix_weight_total + # Sum weights by pixel index + return np.bincount(flat_idxs, weights=flat_weights, minlength=pixels) diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index 878ab8233..8ff2fa0f2 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -102,14 +102,12 @@ def pix_sub_weights(self) -> PixSubWeights: mapper_util.rectangular_mappings_weights_via_interpolation_from( shape_native=self.shape_native, source_plane_mesh_grid=self.source_plane_mesh_grid.array, - source_plane_data_grid=Grid2DIrregular( - self.source_plane_data_grid.over_sampled - ).array, + source_plane_data_grid=self.source_plane_data_grid.over_sampled, ) ) return PixSubWeights( - mappings=np.array(mappings), - sizes=4 * np.ones(len(mappings), dtype="int"), - weights=np.array(weights), + mappings=mappings, + sizes=4 * jnp.ones(len(mappings), dtype="int"), + weights=weights, ) diff --git a/autoarray/inversion/pixelization/mesh/abstract.py b/autoarray/inversion/pixelization/mesh/abstract.py index d23a11cd7..772e02c05 100644 --- a/autoarray/inversion/pixelization/mesh/abstract.py +++ b/autoarray/inversion/pixelization/mesh/abstract.py @@ -41,7 +41,13 @@ def relocated_grid_from( """ if border_relocator is not None: return border_relocator.relocated_grid_from(grid=source_plane_data_grid) - return source_plane_data_grid + + return Grid2D( + values=source_plane_data_grid.array, + mask=source_plane_data_grid.mask, + over_sample_size=source_plane_data_grid.over_sampler.sub_size, + over_sampled=source_plane_data_grid.over_sampled.array, + ) def relocated_mesh_grid_from( self, diff --git a/autoarray/inversion/regularization/adaptive_brightness_split.py b/autoarray/inversion/regularization/adaptive_brightness_split.py index 7b09993db..b781ef7a6 100644 --- a/autoarray/inversion/regularization/adaptive_brightness_split.py +++ b/autoarray/inversion/regularization/adaptive_brightness_split.py @@ -22,8 +22,7 @@ def __init__( adapted to the data being fitted to smooth an inversion's solution. An adaptive regularization scheme which splits every source pixel into a cross of four regularization points - and interpolates to these points in order - to smooth an inversion's solution. + and interpolates to these points in order to smooth an inversion's solution. The size of this cross is determined via the size of the source-pixel, for example if the source pixel is a Voronoi pixel the area of the pixel is computed and the distance of each point of the cross is given by diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index 291e91928..cf0c6dc71 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -203,7 +203,7 @@ def brightness_zeroth_regularization_weights_from( return coefficient * (1.0 - pixel_signals) -@numba_util.jit() +# @numba_util.jit() def weighted_regularization_matrix_from( regularization_weights: np.ndarray, neighbors: np.ndarray, @@ -237,61 +237,98 @@ def weighted_regularization_matrix_from( The regularization matrix computed using an adaptive regularization scheme where the effective regularization coefficient of every source pixel is different. """ - parameters = len(regularization_weights) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - + regularization_matrix = np.zeros((parameters, parameters)) regularization_weight = regularization_weights**2.0 + # Add small diagonal offset + np.fill_diagonal(regularization_matrix, 1e-8) + for i in range(parameters): - regularization_matrix[i, i] += 1e-8 for j in range(neighbors_sizes[i]): neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_weight[neighbor_index] - regularization_matrix[ - neighbor_index, neighbor_index - ] += regularization_weight[neighbor_index] - regularization_matrix[i, neighbor_index] -= regularization_weight[ - neighbor_index - ] - regularization_matrix[neighbor_index, i] -= regularization_weight[ - neighbor_index - ] + w = regularization_weight[neighbor_index] + + regularization_matrix[i, i] += w + regularization_matrix[neighbor_index, neighbor_index] += w + regularization_matrix[i, neighbor_index] -= w + regularization_matrix[neighbor_index, i] -= w return regularization_matrix -@numba_util.jit() +# def weighted_regularization_matrix_from( +# regularization_weights: np.ndarray, +# neighbors: np.ndarray, +# neighbors_sizes: np.ndarray, +# ) -> np.ndarray: +# """ +# Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). +# +# This matrix is computed using the regularization weights of every mesh pixel, which are computed using the +# function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of +# every mesh pixel. +# +# The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate +# neighbor calculation of the corresponding ``Mapper`` class. +# +# Parameters +# ---------- +# regularization_weights +# The regularization weight of each pixel, adaptively governing the degree of gradient regularization +# applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). +# neighbors +# An array of length (total_pixels) which provides the index of all neighbors of every pixel in +# the mesh grid (entries of -1 correspond to no neighbor). +# neighbors_sizes +# An array of length (total_pixels) which gives the number of neighbors of every pixel in the +# Voronoi grid. +# +# Returns +# ------- +# np.ndarray +# The regularization matrix computed using an adaptive regularization scheme where the effective regularization +# coefficient of every source pixel is different. +# """ +# parameters = len(regularization_weights) +# regularization_matrix = np.zeros((parameters, parameters)) +# regularization_weight = regularization_weights**2.0 +# +# # Add small diagonal offset +# np.fill_diagonal(regularization_matrix, 1e-8) +# +# for i in range(parameters): +# for j in range(neighbors_sizes[i]): +# neighbor_index = neighbors[i, j] +# w = regularization_weight[neighbor_index] +# +# regularization_matrix[i, i] += w +# regularization_matrix[neighbor_index, neighbor_index] += w +# regularization_matrix[i, neighbor_index] -= w +# regularization_matrix[neighbor_index, i] -= w +# +# return regularization_matrix + + def brightness_zeroth_regularization_matrix_from( regularization_weights: np.ndarray, ) -> np.ndarray: """ - Returns the regularization matrix of the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). + Returns the regularization matrix for the zeroth-order brightness regularization scheme. Parameters ---------- regularization_weights - The regularization weight of each pixel, adaptively governing the degree of zeroth order regularization - applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). + The regularization weights for each pixel, governing the strength of zeroth-order + regularization applied per inversion parameter. Returns ------- - np.ndarray - The regularization matrix computed using an adaptive regularization scheme where the effective regularization - coefficient of every source pixel is different. + A diagonal regularization matrix where each diagonal element is the squared regularization weight + for that pixel. """ - - parameters = len(regularization_weights) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - - regularization_weight = regularization_weights**2.0 - - for i in range(parameters): - regularization_matrix[i, i] += regularization_weight[i] - - return regularization_matrix + regularization_weight_squared = regularization_weights**2.0 + return np.diag(regularization_weight_squared) def reg_split_from( @@ -357,43 +394,60 @@ def reg_split_from( return splitted_mappings, splitted_sizes, splitted_weights -@numba_util.jit() def pixel_splitted_regularization_matrix_from( regularization_weights: np.ndarray, splitted_mappings: np.ndarray, splitted_sizes: np.ndarray, splitted_weights: np.ndarray, ) -> np.ndarray: - # I'm not sure what is the best way to add surface brightness weight to the regularization scheme here. - # Currently, I simply mulitply the i-th weight to the i-th source pixel, but there should be different ways. - # Need to keep an eye here. + """ + Returns the regularization matrix for the adaptive split-pixel regularization scheme. - parameters = int(len(splitted_mappings) / 4) + This scheme splits each source pixel into a cross of four regularization points and interpolates + to those points to smooth the inversion solution. It is designed to mitigate stochasticity in + the regularization that can arise when the number of neighboring pixels varies across a + mesh (e.g., in a Voronoi tessellation). - regularization_matrix = np.zeros(shape=(parameters, parameters)) + A visual description and further details are provided in the appendix of He et al. (2024): + https://arxiv.org/abs/2403.16253 + Parameters + ---------- + regularization_weights + The regularization weight per pixel, adaptively controlling the strength of regularization + applied to each inversion parameter. + splitted_mappings + The image pixel index mappings for each of the four regularization points into which each source pixel is split. + splitted_sizes + The number of neighbors or interpolation terms associated with each regularization point. + splitted_weights + The interpolation weights corresponding to each mapping entry, used to apply regularization + between split points. + + Returns + ------- + The regularization matrix of shape [source_pixels, source_pixels]. + """ + + parameters = splitted_mappings.shape[0] // 4 + regularization_matrix = np.zeros((parameters, parameters)) regularization_weight = regularization_weights**2.0 - for i in range(parameters): - regularization_matrix[i, i] += 2e-8 + # Add small constant to diagonal + np.fill_diagonal(regularization_matrix, 2e-8) + # Compute regularization contributions + for i in range(parameters): + reg_w = regularization_weight[i] for j in range(4): k = i * 4 + j - size = splitted_sizes[k] - mapping = splitted_mappings[k] - weight = splitted_weights[k] - - for l in range(size): - for m in range(size - l): - regularization_matrix[mapping[l], mapping[l + m]] += ( - weight[l] * weight[l + m] * regularization_weight[i] - ) - regularization_matrix[mapping[l + m], mapping[l]] += ( - weight[l] * weight[l + m] * regularization_weight[i] - ) + mapping = splitted_mappings[k][:size] + weight = splitted_weights[k][:size] - for i in range(parameters): - regularization_matrix[i, i] /= 2.0 + # Outer product of weights and symmetric updates + outer = np.outer(weight, weight) * reg_w + rows, cols = np.meshgrid(mapping, mapping, indexing="ij") + regularization_matrix[rows, cols] += outer return regularization_matrix diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 515b2b928..0f4fe30f9 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -215,7 +215,9 @@ def __init__( pixel_scales=pixel_scales, ) - self.derive_indexes.native_for_slim + @cached_property + def native_for_slim(self): + return self.derive_indexes.native_for_slim __no_flatten__ = ("derive_indexes",) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 29c93aa7e..d1b123133 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -147,6 +147,12 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): over_sample_size=sub_size, mask=mask ) + self.sub_total = int(np.sum(self.sub_size**2)) + self.sub_length = self.sub_size**self.mask.dimensions + self.sub_fraction = Array2D( + values=jnp.array(1.0 / self.sub_length.array), mask=self.mask + ) + # Used for JAX based adaptive over sampling. # Define group sizes @@ -172,32 +178,6 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): return cls(mask=children[0], sub_size=children[1]) - @property - def sub_total(self): - """ - The total number of sub-pixels in the entire mask. - """ - return int(np.sum(self.sub_size**2)) - - @property - def sub_length(self) -> Array2D: - """ - The total number of sub-pixels in a give pixel, - - For example, a sub-size of 3x3 means every pixel has 9 sub-pixels. - """ - return self.sub_size**self.mask.dimensions - - @property - def sub_fraction(self) -> Array2D: - """ - The fraction of the area of a pixel every sub-pixel contains. - - For example, a sub-size of 3x3 mean every pixel contains 1/9 the area. - """ - - return 1.0 / self.sub_length - @property def sub_pixel_areas(self) -> np.ndarray: """ diff --git a/autoarray/plot/visuals/two_d.py b/autoarray/plot/visuals/two_d.py index d385329dd..a573e2b33 100644 --- a/autoarray/plot/visuals/two_d.py +++ b/autoarray/plot/visuals/two_d.py @@ -51,7 +51,7 @@ def __init__( def plot_via_plotter(self, plotter, grid_indexes=None, mapper=None, geometry=None): if self.origin is not None: plotter.origin_scatter.scatter_grid( - grid=Grid2DIrregular(values=self.origin) + grid=Grid2DIrregular(values=self.origin).array ) if self.mask is not None: diff --git a/autoarray/preloads.py b/autoarray/preloads.py new file mode 100644 index 000000000..6cedca99d --- /dev/null +++ b/autoarray/preloads.py @@ -0,0 +1,60 @@ +import logging + +import jax.numpy as jnp +import numpy as np + +logger = logging.getLogger(__name__) + +logger.setLevel(level="INFO") + + +class Preloads: + + def __init__( + self, + mapper_indices: np.ndarray = None, + source_pixel_zeroed_indices: np.ndarray = None, + ): + """ + Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance + and compatibility with JAX. + + Some arrays (e.g. `mapper_indices`) are required to be defined before sampling begins, because JAX demands + that input shapes remain static. These are used during each inversion to ensure consistent matrix shapes + for all likelihood evaluations. + + Other arrays (e.g. parts of the curvature matrix) are preloaded purely to improve performance. In cases where + the source model is fixed (e.g. when fitting only the lens light), sections of the curvature matrix do not + change and can be reused, avoiding redundant computation. + + Parameters + ---------- + mapper_indices + The integer indices of mapper pixels in the inversion. Used to extract reduced matrices (e.g. + `curvature_matrix_reduced`) that compute the pixelized inversion's log evidence term, where the indicies + are requirred to separate the rows and columns of matrices from linear light profiles. + source_pixel_zeroed_indices + Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to + outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping + separate the lens light from the pixelized source model. + """ + + self.mapper_indices = None + self.source_pixel_zeroed_indices = None + self.source_pixel_zeroed_indices_to_keep = None + + if mapper_indices is not None: + + self.mapper_indices = jnp.array(mapper_indices) + + if source_pixel_zeroed_indices is not None: + + self.source_pixel_zeroed_indices = jnp.array(source_pixel_zeroed_indices) + + ids_zeros = jnp.array(source_pixel_zeroed_indices, dtype=int) + + values_to_solve = jnp.ones(np.max(mapper_indices), dtype=bool) + values_to_solve = values_to_solve.at[ids_zeros].set(False) + + # Get the indices where values_to_solve is True + self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0] diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 5239a193a..e631860c7 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -253,21 +253,22 @@ def grid_2d_slim_via_mask_from( centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) - if isinstance(mask_2d, np.ndarray): - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) - sign = np.array([-1.0, 1.0]) + if isinstance(mask_2d, jnp.ndarray): + + centres_scaled = jnp.array(centres_scaled) + pixel_scales = jnp.array(pixel_scales) + sign = jnp.array([-1.0, 1.0]) return ( - (np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled) + (jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled) * sign * pixel_scales ) - centres_scaled = jnp.array(centres_scaled) - pixel_scales = jnp.array(pixel_scales) - sign = jnp.array([-1.0, 1.0]) + centres_scaled = np.array(centres_scaled) + pixel_scales = np.array(pixel_scales) + sign = np.array([-1.0, 1.0]) return ( - (jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled) + (np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled) * sign * pixel_scales ) @@ -578,92 +579,6 @@ def grid_scaled_2d_slim_radial_projected_from( return grid_scaled_2d_slim_radii + 1e-6 -@numba_util.jit() -def relocated_grid_via_jit_from(grid, border_grid): - """ - Relocate the coordinates of a grid to its border if they are outside the border, where the border is - defined as all pixels at the edge of the grid's mask (see *mask._border_1d_indexes*). - - This is performed as follows: - - 1: Use the mean value of the grid's y and x coordinates to determine the origin of the grid. - 2: Compute the radial distance of every grid coordinate from the origin. - 3: For every coordinate, find its nearest pixel in the border. - 4: Determine if it is outside the border, by comparing its radial distance from the origin to its paired - border pixel's radial distance. - 5: If its radial distance is larger, use the ratio of radial distances to move the coordinate to the - border (if its inside the border, do nothing). - - The method can be used on uniform or irregular grids, however for irregular grids the border of the - 'image-plane' mask is used to define border pixels. - - Parameters - ---------- - grid - The grid (uniform or irregular) whose pixels are to be relocated to the border edge if outside it. - border_grid : Grid2D - The grid of border (y,x) coordinates. - """ - - grid_relocated = np.zeros(grid.shape) - grid_relocated[:, :] = grid[:, :] - - border_origin = np.zeros(2) - border_origin[0] = np.mean(border_grid[:, 0]) - border_origin[1] = np.mean(border_grid[:, 1]) - border_grid_radii = np.sqrt( - np.add( - np.square(np.subtract(border_grid[:, 0], border_origin[0])), - np.square(np.subtract(border_grid[:, 1], border_origin[1])), - ) - ) - border_min_radii = np.min(border_grid_radii) - - grid_radii = np.sqrt( - np.add( - np.square(np.subtract(grid[:, 0], border_origin[0])), - np.square(np.subtract(grid[:, 1], border_origin[1])), - ) - ) - - for pixel_index in range(grid.shape[0]): - if grid_radii[pixel_index] > border_min_radii: - closest_pixel_index = np.argmin( - np.square(grid[pixel_index, 0] - border_grid[:, 0]) - + np.square(grid[pixel_index, 1] - border_grid[:, 1]) - ) - - move_factor = ( - border_grid_radii[closest_pixel_index] / grid_radii[pixel_index] - ) - - if move_factor < 1.0: - grid_relocated[pixel_index, :] = ( - move_factor * (grid[pixel_index, :] - border_origin[:]) - + border_origin[:] - ) - - return grid_relocated - - -@numba_util.jit() -def furthest_grid_2d_slim_index_from( - grid_2d_slim: np.ndarray, slim_indexes: np.ndarray, coordinate: Tuple[float, float] -) -> int: - distance_to_centre = 0.0 - - for slim_index in slim_indexes: - y = grid_2d_slim[slim_index, 0] - x = grid_2d_slim[slim_index, 1] - distance_to_centre_new = (x - coordinate[1]) ** 2 + (y - coordinate[0]) ** 2 - - if distance_to_centre_new >= distance_to_centre: - distance_to_centre = distance_to_centre_new - furthest_grid_2d_slim_index = slim_index - - return furthest_grid_2d_slim_index - - def grid_2d_slim_from( grid_2d_native: np.ndarray, mask: np.ndarray, @@ -812,3 +727,55 @@ def grid_pixels_in_mask_pixels_from( np.add.at(mesh_pixels_per_image_pixel, (y_indices, x_indices), 1) return mesh_pixels_per_image_pixel + + +def grid_2d_slim_via_shape_native_not_mask_from( + shape_native: Tuple[int, int], + pixel_scales: Tuple[float, float], + origin: Tuple[float, float] = (0.0, 0.0), +) -> np.ndarray: + """ + Build the slim (flattened) grid of all (y, x) pixel centres for a rectangular grid + of shape `shape_native`, scaled by `pixel_scales` and shifted by `origin`. + + This is equivalent to taking an unmasked mask of shape `shape_native` and calling + grid_2d_slim_via_mask_from on it. + + Parameters + ---------- + shape_native + A pair (Ny, Nx) giving the number of pixels in y and x. + pixel_scales + A pair (sy, sx) giving the physical size of each pixel in y and x. + origin + A 2-tuple (y0, x0) around which the grid is centred. + + Returns + ------- + grid_slim : ndarray, shape (Ny*Nx, 2) + Each row is the (y, x) coordinate of one pixel centre, in row-major order, + shifted so that `origin` ↔ physical pixel-centre average, and scaled by + `pixel_scales`, with y increasing “up” and x increasing “right”. + """ + Ny, Nx = shape_native + sy, sx = pixel_scales + y0, x0 = origin + + # compute the integer pixel‐centre coordinates in array index space + # row indices 0..Ny-1, col indices 0..Nx-1 + arange = jnp.arange + meshy, meshx = jnp.meshgrid(arange(Ny), arange(Nx), indexing="ij") + coords = jnp.stack([meshy, meshx], axis=-1).reshape(-1, 2) + + # convert to physical coordinates: subtract array‐centre, flip y, scale, then add origin + # array‐centre in index space is at ((Ny-1)/2, (Nx-1)/2) + cy, cx = (Ny - 1) / 2.0, (Nx - 1) / 2.0 + # row index i → physical y = (cy - i) * sy + y0 + # col index j → physical x = (j - cx) * sx + x0 + idx_y = coords[:, 0] + idx_x = coords[:, 1] + + phys_y = (cy - idx_y) * sy + y0 + phys_x = (idx_x - cx) * sx + x0 + + return jnp.stack([phys_y, phys_x], axis=1) diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index f2f7a4a98..bf51c3d75 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import List, Optional, Tuple @@ -92,19 +93,20 @@ def overlay_grid( The size of the extra spacing placed between the edges of the rectangular pixelization and input grid. """ - y_min = np.min(grid[:, 0]) - buffer - y_max = np.max(grid[:, 0]) + buffer - x_min = np.min(grid[:, 1]) - buffer - x_max = np.max(grid[:, 1]) + buffer + y_min = jnp.min(grid[:, 0]) - buffer + y_max = jnp.max(grid[:, 0]) + buffer + x_min = jnp.min(grid[:, 1]) - buffer + x_max = jnp.max(grid[:, 1]) + buffer - pixel_scales = ( - float((y_max - y_min) / shape_native[0]), - float((x_max - x_min) / shape_native[1]), + pixel_scales = jnp.array( + ( + (y_max - y_min) / shape_native[0], + (x_max - x_min) / shape_native[1], + ) ) + origin = jnp.array(((y_max + y_min) / 2.0, (x_max + x_min) / 2.0)) - origin = ((y_max + y_min) / 2.0, (x_max + x_min) / 2.0) - - grid_slim = grid_2d_util.grid_2d_slim_via_shape_native_from( + grid_slim = grid_2d_util.grid_2d_slim_via_shape_native_not_mask_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin, diff --git a/test_autoarray/inversion/inversion/imaging/test_imaging.py b/test_autoarray/inversion/inversion/imaging/test_imaging.py index 77fec7571..bd54c35f1 100644 --- a/test_autoarray/inversion/inversion/imaging/test_imaging.py +++ b/test_autoarray/inversion/inversion/imaging/test_imaging.py @@ -163,4 +163,5 @@ def test__w_tilde_checks_noise_map_and_raises_exception_if_preloads_dont_match_n mapping_matrix=np.ones(matrix_shape), source_plane_data_grid=grid ) ], + settings=aa.SettingsInversion(use_w_tilde=True), ) diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index 570f79673..f96731aa9 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -230,9 +230,13 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): pix_lengths, ) = aa.util.mapper.data_slim_to_pixelization_unique_from( data_pixels=w_tilde_data.shape[0], - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + mapper.pix_indexes_for_sub_slim_index + ), + pix_sizes_for_sub_slim_index=np.array(mapper.pix_sizes_for_sub_slim_index), + pix_weights_for_sub_slim_index=np.array( + mapper.pix_weights_for_sub_slim_index + ), pix_pixels=mapper.params, sub_size=np.array(grid.over_sample_size), ) @@ -345,9 +349,13 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): pix_lengths, ) = aa.util.mapper.data_slim_to_pixelization_unique_from( data_pixels=w_tilde_lengths.shape[0], - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + mapper.pix_indexes_for_sub_slim_index + ), + pix_sizes_for_sub_slim_index=np.array(mapper.pix_sizes_for_sub_slim_index), + pix_weights_for_sub_slim_index=np.array( + mapper.pix_weights_for_sub_slim_index + ), pix_pixels=mapper.params, sub_size=np.array(grid.over_sample_size), ) diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index bf8f4a919..8880b544c 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -242,7 +242,7 @@ def test__curvature_reg_matrix_reduced(): curvature_reg_matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) linear_obj_list = [ - aa.m.MockLinearObj(parameters=2, regularization=1), + aa.m.MockMapper(parameters=2, regularization=aa.m.MockRegularization()), aa.m.MockLinearObj(parameters=1, regularization=None), ] @@ -250,38 +250,13 @@ def test__curvature_reg_matrix_reduced(): linear_obj_list=linear_obj_list, curvature_reg_matrix=curvature_reg_matrix ) + print(inversion.curvature_reg_matrix_reduced) + assert ( inversion.curvature_reg_matrix_reduced == np.array([[1.0, 2.0], [4.0, 5.0]]) ).all() -# def test__curvature_reg_matrix_solver__edge_pixels_set_to_zero(): -# -# curvature_reg_matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) -# -# linear_obj_list = [ -# aa.m.MockMapper(parameters=3, regularization=None, edge_pixel_list=[0]) -# ] -# -# inversion = aa.m.MockInversion( -# linear_obj_list=linear_obj_list, -# curvature_reg_matrix=curvature_reg_matrix, -# settings=aa.SettingsInversion(force_edge_pixels_to_zeros=True), -# ) -# -# curvature_reg_matrix = np.array( -# [ -# [0.0, 2.0, 3.0], -# [0.0, 5.0, 6.0], -# [0.0, 8.0, 9.0], -# ] -# ) -# -# assert inversion.curvature_reg_matrix_solver == pytest.approx( -# curvature_reg_matrix, 1.0e-4 -# ) - - def test__regularization_matrix(): reg_0 = aa.m.MockRegularization(regularization_matrix=np.ones((2, 2))) reg_1 = aa.m.MockRegularization(regularization_matrix=2.0 * np.ones((3, 3))) @@ -308,7 +283,7 @@ def test__regularization_matrix(): def test__reconstruction_reduced(): linear_obj_list = [ - aa.m.MockLinearObj(parameters=2, regularization=aa.m.MockRegularization()), + aa.m.MockMapper(parameters=2, regularization=aa.m.MockRegularization()), aa.m.MockLinearObj(parameters=1, regularization=None), ] diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index 140f90eeb..ed3e6fa53 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -189,6 +189,24 @@ def test__inversion_imaging__via_regularizations( assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) +def test__inversion_imaging__source_pixel_zeroed_indices( + masked_imaging_7x7_no_blur, + rectangular_mapper_7x7_3x3, +): + inversion = aa.Inversion( + dataset=masked_imaging_7x7_no_blur, + linear_obj_list=[rectangular_mapper_7x7_3x3], + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), + preloads=aa.Preloads( + mapper_indices=range(0, 9), source_pixel_zeroed_indices=np.array([0]) + ), + ) + + assert inversion.reconstruction.shape[0] == 9 + assert inversion.reconstruction[0] == 0.0 + assert inversion.reconstruction[1] > 0.0 + + def test__inversion_imaging__via_linear_obj_func_and_mapper( masked_imaging_7x7_no_blur, rectangular_mapper_7x7_3x3, @@ -253,7 +271,9 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_t linear_obj = aa.m.MockLinearObj( parameters=1, grid=grid, - mapping_matrix=np.full(fill_value=0.5, shape=(9, 1)), + mapping_matrix=np.array( + [[1.0], [2.0], [3.0], [2.0], [3.0], [4.0], [3.0], [1.0], [2.0]] + ), regularization=None, ) @@ -282,12 +302,14 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_t ), ) + mapper_edge_pixel_list = inversion.mapper_edge_pixel_list + assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) assert isinstance(inversion.linear_obj_list[1], aa.MapperDelaunay) assert isinstance(inversion, aa.InversionImagingMapping) - assert inversion.reconstruction == pytest.approx( - np.array([2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), abs=1.0e-2 - ) + # assert inversion.reconstruction[mapper_edge_pixel_list[0]] == pytest.approx(0.0, abs=1.0e-2) + # assert inversion.reconstruction[mapper_edge_pixel_list[1]] == pytest.approx(0.0, abs=1.0e-2) + # assert inversion.reconstruction[mapper_edge_pixel_list[2]] == pytest.approx(0.0, abs=1.0e-2) def test__inversion_imaging__compare_mapping_and_w_tilde_values( @@ -339,7 +361,11 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), + settings=aa.SettingsInversion( + use_w_tilde=False, + use_positive_only_solver=True, + force_edge_pixels_to_zeros=False, + ), ) masked_imaging_7x7_no_blur = copy.copy(masked_imaging_7x7_no_blur) @@ -351,7 +377,11 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( inversion_no_linear_func = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), + settings=aa.SettingsInversion( + use_w_tilde=False, + use_positive_only_solver=True, + force_edge_pixels_to_zeros=False, + ), ) assert inversion.regularization_term == pytest.approx( @@ -545,19 +575,19 @@ def test__inversion_matrices__x2_mappers( assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][ 4 - ] == pytest.approx(0.004607102, 1.0e-4) + ] == pytest.approx(0.5000029374603968, 1.0e-4) assert inversion.reconstruction_dict[delaunay_mapper_9_3x3][4] == pytest.approx( - 0.0475967358, 1.0e-4 + 0.4999970390886761, 1.0e-4 ) - assert inversion.reconstruction[13] == pytest.approx(0.047596735850, 1.0e-4) + assert inversion.reconstruction[13] == pytest.approx(0.49999703908867, 1.0e-4) assert inversion.mapped_reconstructed_data_dict[rectangular_mapper_7x7_3x3][ 4 - ] == pytest.approx(0.0022574, 1.0e-4) + ] == pytest.approx(0.5000029, 1.0e-4) assert inversion.mapped_reconstructed_data_dict[delaunay_mapper_9_3x3][ 3 - ] == pytest.approx(0.01545999, 1.0e-4) - assert inversion.mapped_reconstructed_image[4] == pytest.approx(0.05237029, 1.0e-4) + ] == pytest.approx(0.49999704, 1.0e-4) + assert inversion.mapped_reconstructed_image[4] == pytest.approx(0.99999998, 1.0e-4) def test__inversion_imaging__positive_only_solver(masked_imaging_7x7_no_blur): diff --git a/test_autoarray/inversion/inversion/test_settings_dict.py b/test_autoarray/inversion/inversion/test_settings_dict.py index d547014e1..21540bdd3 100644 --- a/test_autoarray/inversion/inversion/test_settings_dict.py +++ b/test_autoarray/inversion/inversion/test_settings_dict.py @@ -17,7 +17,6 @@ def make_settings_dict(): "use_positive_only_solver": False, "positive_only_uses_p_initial": False, "force_edge_pixels_to_zeros": True, - "image_pixels_source_zero": None, "no_regularization_add_to_curvature_diag_value": 1e-08, "use_w_tilde_numpy": False, "use_source_loop": False, diff --git a/test_autoarray/inversion/pixelization/mappers/test_abstract.py b/test_autoarray/inversion/pixelization/mappers/test_abstract.py index 7217dc79a..925ec7360 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_abstract.py +++ b/test_autoarray/inversion/pixelization/mappers/test_abstract.py @@ -69,38 +69,6 @@ def test__sub_slim_indexes_for_pix_index(): [0, 1, 2, 3, 4, 5, 6, 7], ] - ( - sub_slim_indexes_for_pix_index, - sub_slim_sizes_for_pix_index, - sub_slim_weights_for_pix_index, - ) = mapper.sub_slim_indexes_for_pix_index_arr - - assert ( - sub_slim_indexes_for_pix_index - == np.array( - [ - [0, 3, 6, -1, -1, -1, -1, -1], - [1, 4, -1, -1, -1, -1, -1, -1], - [2, -1, -1, -1, -1, -1, -1, -1], - [5, 7, -1, -1, -1, -1, -1, -1], - [0, 1, 2, 3, 4, 5, 6, 7], - ] - ) - ).all() - assert (sub_slim_sizes_for_pix_index == np.array([3, 2, 1, 2, 8])).all() - assert ( - sub_slim_weights_for_pix_index - == np.array( - [ - [0.1, 0.4, 0.7, -1, -1, -1, -1, -1], - [0.2, 0.5, -1, -1, -1, -1, -1, -1], - [0.3, -1, -1, -1, -1, -1, -1, -1], - [0.6, 0.8, -1, -1, -1, -1, -1, -1], - [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2], - ] - ) - ).all() - def test__data_weight_total_for_pix_from(): mapper = aa.m.MockMapper( @@ -222,8 +190,8 @@ def test__mapped_to_source_from(grid_2d_7x7): ) mapped_to_source_util = aa.util.mapper.mapped_to_source_via_mapping_matrix_from( - mapping_matrix=mapper.mapping_matrix, - array_slim=np.array(array_slim), + mapping_matrix=np.array(mapper.mapping_matrix), + array_slim=array_slim, ) mapped_to_source_mapper = mapper.mapped_to_source_from(array=array_slim) diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index edfa53722..ef8123b99 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -68,7 +68,7 @@ def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7): pix_size_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, pixel_weights=mapper.pix_weights_for_sub_slim_index, slim_index_for_sub_slim_index=grid_2d_sub_1_7x7.over_sampler.slim_for_sub_slim, - adapt_data=np.array(image_7x7), + adapt_data=image_7x7, ) assert (pixel_signals == pixel_signals_util).all() diff --git a/test_autoarray/structures/grids/test_grid_2d_util.py b/test_autoarray/structures/grids/test_grid_2d_util.py index 034f267e5..79c127310 100644 --- a/test_autoarray/structures/grids/test_grid_2d_util.py +++ b/test_autoarray/structures/grids/test_grid_2d_util.py @@ -147,6 +147,112 @@ def test__grid_2d_slim_via_shape_native_from(): ).all() +def test__grid_2d_slim_via_shape_native_not_mask_from(): + grid_2d = aa.util.grid_2d.grid_2d_slim_via_shape_native_not_mask_from( + shape_native=(2, 3), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [0.5, -1.0], + [0.5, 0.0], + [0.5, 1.0], + [-0.5, -1.0], + [-0.5, 0.0], + [-0.5, 1.0], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_slim_via_shape_native_not_mask_from( + shape_native=(3, 2), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [1.0, -0.5], + [1.0, 0.5], + [0.0, -0.5], + [0.0, 0.5], + [-1.0, -0.5], + [-1.0, 0.5], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_slim_via_shape_native_not_mask_from( + shape_native=(3, 2), pixel_scales=(1.0, 1.0), origin=(3.0, -2.0) + ) + + assert ( + grid_2d + == np.array( + [ + [4.0, -2.5], + [4.0, -1.5], + [3.0, -2.5], + [3.0, -1.5], + [2.0, -2.5], + [2.0, -1.5], + ] + ) + ).all() + + +def test__grid_2d_via_shape_native_from(): + grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( + shape_native=(2, 3), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [[0.5, -1.0], [0.5, 0.0], [0.5, 1.0]], + [[-0.5, -1.0], [-0.5, 0.0], [-0.5, 1.0]], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( + shape_native=(3, 2), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [[1.0, -0.5], [1.0, 0.5]], + [[0.0, -0.5], [0.0, 0.5]], + [[-1.0, -0.5], [-1.0, 0.5]], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( + shape_native=(3, 2), pixel_scales=(1.0, 1.0), origin=(3.0, -2.0) + ) + + assert ( + grid_2d + == np.array( + [ + [[4.0, -2.5], [4.0, -1.5]], + [[3.0, -2.5], [3.0, -1.5]], + [[2.0, -2.5], [2.0, -1.5]], + ] + ) + ).all() + + def test__grid_2d_via_shape_native_from(): grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( shape_native=(2, 3),