diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index bc0daf0ad..55c3b61de 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -1,5 +1,6 @@ import copy import jax.numpy as jnp +from jax.scipy.linalg import block_diag import numpy as np from typing import Dict, List, Optional, Type, Union @@ -334,8 +335,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]: If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion are regularized so high their value is forced to zero. """ - from scipy.linalg import block_diag - return block_diag( *[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list] ) @@ -664,30 +663,17 @@ def log_det_regularization_matrix_term(self) -> float: float The log determinant of the regularization matrix. """ - from scipy.sparse import csc_matrix - from scipy.sparse.linalg import splu - if not self.has(cls=AbstractRegularization): return 0.0 try: - lu = splu(csc_matrix(self.regularization_matrix_reduced)) - diagL = lu.L.diagonal() - diagU = lu.U.diagonal() - diagL = diagL.astype(np.complex128) - diagU = diagU.astype(np.complex128) - - return np.real(np.log(diagL).sum() + np.log(diagU).sum()) - - except RuntimeError: - try: - return 2.0 * np.sum( - np.log( - np.diag(np.linalg.cholesky(self.regularization_matrix_reduced)) - ) + return 2.0 * np.sum( + jnp.log( + jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced)) ) - except np.linalg.LinAlgError as e: - raise exc.InversionException() from e + ) + except np.linalg.LinAlgError as e: + raise exc.InversionException() from e @property def reconstruction_noise_map_with_covariance(self) -> np.ndarray: diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index 091713af1..8d0fad744 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -288,7 +288,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray: pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim, - adapt_data=np.array(self.adapt_data), + adapt_data=self.adapt_data.array, ) def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]: diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index ea1384349..420d9c2ab 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -5,7 +5,6 @@ from autoarray import numba_util -@numba_util.jit() def rectangular_neighbors_from( shape_native: Tuple[int, int], ) -> Tuple[np.ndarray, np.ndarray]: @@ -68,7 +67,6 @@ def rectangular_neighbors_from( return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_corner_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -113,7 +111,6 @@ def rectangular_corner_neighbors( return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_top_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -136,17 +133,20 @@ def rectangular_top_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[1] - 1): - pixel_index = pix - neighbors[pixel_index, 0:3] = np.array( - [pixel_index - 1, pixel_index + 1, pixel_index + shape_native[1]] - ) - neighbors_sizes[pixel_index] = 3 + """ + Vectorized version of the top edge neighbor update using NumPy arithmetic. + """ + # Pixels along the top edge, excluding corners + top_edge_pixels = np.arange(1, shape_native[1] - 1) + + neighbors[top_edge_pixels, 0] = top_edge_pixels - 1 + neighbors[top_edge_pixels, 1] = top_edge_pixels + 1 + neighbors[top_edge_pixels, 2] = top_edge_pixels + shape_native[1] + neighbors_sizes[top_edge_pixels] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_left_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -169,21 +169,20 @@ def rectangular_left_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[0] - 1): - pixel_index = pix * shape_native[1] - neighbors[pixel_index, 0:3] = np.array( - [ - pixel_index - shape_native[1], - pixel_index + 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 3 + # Row indices (excluding top and bottom corners) + rows = np.arange(1, shape_native[0] - 1) + + # Convert to flat pixel indices for the left edge (first column) + pixel_indices = rows * shape_native[1] + + neighbors[pixel_indices, 0] = pixel_indices - shape_native[1] + neighbors[pixel_indices, 1] = pixel_indices + 1 + neighbors[pixel_indices, 2] = pixel_indices + shape_native[1] + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_right_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -206,21 +205,20 @@ def rectangular_right_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[0] - 1): - pixel_index = pix * shape_native[1] + shape_native[1] - 1 - neighbors[pixel_index, 0:3] = np.array( - [ - pixel_index - shape_native[1], - pixel_index - 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 3 + # Rows excluding the top and bottom corners + rows = np.arange(1, shape_native[0] - 1) + + # Flat indices for the right edge pixels + pixel_indices = rows * shape_native[1] + shape_native[1] - 1 + + neighbors[pixel_indices, 0] = pixel_indices - shape_native[1] + neighbors[pixel_indices, 1] = pixel_indices - 1 + neighbors[pixel_indices, 2] = pixel_indices + shape_native[1] + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_bottom_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -243,19 +241,21 @@ def rectangular_bottom_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - pixels = int(shape_native[0] * shape_native[1]) + n_rows, n_cols = shape_native + pixels = n_rows * n_cols - for pix in range(1, shape_native[1] - 1): - pixel_index = pixels - pix - 1 - neighbors[pixel_index, 0:3] = np.array( - [pixel_index - shape_native[1], pixel_index - 1, pixel_index + 1] - ) - neighbors_sizes[pixel_index] = 3 + # Horizontal pixel positions along bottom row, excluding corners + cols = np.arange(1, n_cols - 1) + pixel_indices = pixels - cols - 1 # Reverse order from right to left + + neighbors[pixel_indices, 0] = pixel_indices - n_cols + neighbors[pixel_indices, 1] = pixel_indices - 1 + neighbors[pixel_indices, 2] = pixel_indices + 1 + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_central_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -279,46 +279,61 @@ def rectangular_central_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for x in range(1, shape_native[0] - 1): - for y in range(1, shape_native[1] - 1): - pixel_index = x * shape_native[1] + y - neighbors[pixel_index, 0:4] = np.array( - [ - pixel_index - shape_native[1], - pixel_index - 1, - pixel_index + 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 4 + n_rows, n_cols = shape_native + + # Grid coordinates excluding edges + xs = np.arange(1, n_rows - 1) + ys = np.arange(1, n_cols - 1) + + # 2D grid of central pixel indices + grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij") + pixel_indices = grid_x * n_cols + grid_y + pixel_indices = pixel_indices.ravel() + + # Compute neighbor indices + neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up + neighbors[pixel_indices, 1] = pixel_indices - 1 # Left + neighbors[pixel_indices, 2] = pixel_indices + 1 # Right + neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down + + neighbors_sizes[pixel_indices] = 4 return neighbors, neighbors_sizes -def rectangular_edge_pixel_list_from(neighbors: np.ndarray) -> List: +def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]: """ - Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization. - - This is computed by searching the `neighbors` array for pixels that have a neighbor with index -1, meaning there - is at least one neighbor from the 4 expected missing. + Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization, + based on its 2D shape. Parameters ---------- - neighbors - An array of dimensions [total_pixels, 4] which provides the index of all neighbors of every pixel in the - rectangular pixelization (entries of -1 correspond to no neighbor). + shape_native + The (rows, cols) shape of the rectangular 2D pixel grid. Returns ------- - A list of the 1D indices of all pixels on the edge of a rectangular pixelization. + A list of the 1D indices of all edge pixels. """ - edge_pixel_list = [] + rows, cols = shape_native + + # Top row + top = np.arange(0, cols) + + # Bottom row + bottom = np.arange((rows - 1) * cols, rows * cols) + + # Left column (excluding corners) + left = np.arange(1, rows - 1) * cols + + # Right column (excluding corners) + right = (np.arange(1, rows - 1) + 1) * cols - 1 - for i, neighbors in enumerate(neighbors): - if -1 in neighbors: - edge_pixel_list.append(i) + # Concatenate all edge indices + edge_pixel_indices = np.concatenate([top, left, right, bottom]) - return edge_pixel_list + # Sort and return + return np.sort(edge_pixel_indices).tolist() @numba_util.jit() diff --git a/autoarray/inversion/regularization/adaptive_brightness.py b/autoarray/inversion/regularization/adaptive_brightness.py index d7c322817..c0ba845d0 100644 --- a/autoarray/inversion/regularization/adaptive_brightness.py +++ b/autoarray/inversion/regularization/adaptive_brightness.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,114 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def adaptive_regularization_weights_from( + inner_coefficient: float, outer_coefficient: float, pixel_signals: jnp.ndarray +) -> jnp.ndarray: + """ + Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). + + The weights define the effective regularization coefficient of every mesh parameter (typically pixels + of a ``Mapper``). + + They are computed using an estimate of the expected signal in each pixel. + + Two regularization coefficients are used, corresponding to the: + + 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization). + 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization). + + Parameters + ---------- + inner_coefficient + The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction + in the inner regions of a mesh's reconstruction. + outer_coefficient + The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction + in the outer regions of a mesh's reconstruction. + pixel_signals + The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal + and low signal pixelizations. + + Returns + ------- + jnp.ndarray + The adaptive regularization weights which act as the effective regularization coefficients of + every source pixel. + """ + return ( + inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals) + ) ** 2.0 + + +def weighted_regularization_matrix_from( + regularization_weights: jnp.ndarray, + neighbors: jnp.ndarray, +) -> jnp.ndarray: + """ + Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). + + This matrix is computed using the regularization weights of every mesh pixel, which are computed using the + function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of + every mesh pixel. + + The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate + neighbor calculation of the corresponding ``Mapper`` class. + + Parameters + ---------- + regularization_weights + The regularization weight of each pixel, adaptively governing the degree of gradient regularization + applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). + neighbors + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the mesh grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + jnp.ndarray + The regularization matrix computed using an adaptive regularization scheme where the effective regularization + coefficient of every source pixel is different. + """ + S, P = neighbors.shape + reg_w = regularization_weights**2 + + # 1) Flatten the (i→j) neighbor pairs + I = jnp.repeat(jnp.arange(S), P) # (S*P,) + J = neighbors.reshape(-1) # (S*P,) + + # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0 + OUT = S + J = jnp.where(J < 0, OUT, J) + + # 3) Build an extended weight vector with a zero at index S + reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0) + w_ij = reg_w_ext[J] # (S*P,) + + # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely + mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype) + + # 5) Scatter into the diagonal: + # - the tiny 1e-8 floor on each i < S + # - sum_j reg_w[j] into diag[i] + # - sum contributions reg_w[j] into diag[j] + # (diagonal at OUT=S picks up zeros only) + diag_updates_i = jnp.concatenate( + [jnp.full((S,), 1e-8), jnp.zeros((1,))], axis=0 # out‐of‐bounds slot stays zero + ) + mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i) + mat = mat.at[I, I].add(w_ij) + mat = mat.at[J, J].add(w_ij) + + # 6) Scatter the off‐diagonal subtractions: + mat = mat.at[I, J].add(-w_ij) + mat = mat.at[J, I].add(-w_ij) + + # 7) Drop the extra row/column S and return the S×S result + return mat[:S, :S] class AdaptiveBrightness(AbstractRegularization): @@ -70,7 +177,7 @@ def __init__( self.outer_coefficient = outer_coefficient self.signal_scale = signal_scale - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -91,13 +198,13 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: """ pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale) - return regularization_util.adaptive_regularization_weights_from( + return adaptive_regularization_weights_from( inner_coefficient=self.inner_coefficient, outer_coefficient=self.outer_coefficient, pixel_signals=pixel_signals, ) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -112,8 +219,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ regularization_weights = self.regularization_weights_from(linear_obj=linear_obj) - return regularization_util.weighted_regularization_matrix_from( + return weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=linear_obj.source_plane_mesh_grid.neighbors, - neighbors_sizes=linear_obj.source_plane_mesh_grid.neighbors.sizes, ) diff --git a/autoarray/inversion/regularization/brightness_zeroth.py b/autoarray/inversion/regularization/brightness_zeroth.py index 4cab4e6d2..6cd765aec 100644 --- a/autoarray/inversion/regularization/brightness_zeroth.py +++ b/autoarray/inversion/regularization/brightness_zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -10,6 +10,60 @@ from autoarray.inversion.regularization import regularization_util +def brightness_zeroth_regularization_weights_from( + coefficient: float, pixel_signals: jnp.ndarray +) -> jnp.ndarray: + """ + Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). + + The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels + of a ``Mapper``). + + They are computed using an estimate of the expected signal in each pixel. + + The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are + the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in + the pixelization). + + Parameters + ---------- + coefficient + The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``), + with the degree applied varying based on the ``pixel_signals``. + pixel_signals + The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal + and low signal pixelizations. + + Returns + ------- + jnp.ndarray + The zeroth order regularization weights which act as the effective level of zeroth order regularization + applied to every mesh parameter. + """ + return coefficient * (1.0 - pixel_signals) + + +def brightness_zeroth_regularization_matrix_from( + regularization_weights: jnp.ndarray, +) -> jnp.ndarray: + """ + Returns the regularization matrix for the zeroth-order brightness regularization scheme. + + Parameters + ---------- + regularization_weights + The regularization weights for each pixel, governing the strength of zeroth-order + regularization applied per inversion parameter. + + Returns + ------- + A diagonal regularization matrix where each diagonal element is the squared regularization weight + for that pixel. + """ + regularization_weight_squared = regularization_weights**2.0 + return jnp.diag(regularization_weight_squared) + + class BrightnessZeroth(AbstractRegularization): def __init__( self, @@ -45,7 +99,7 @@ def __init__( self.coefficient = coefficient self.signal_scale = signal_scale - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of the ``BrightnessZeroth`` regularization scheme. @@ -65,11 +119,11 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: """ pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale) - return regularization_util.brightness_zeroth_regularization_weights_from( + return brightness_zeroth_regularization_weights_from( coefficient=self.coefficient, pixel_signals=pixel_signals ) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -84,6 +138,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ regularization_weights = self.regularization_weights_from(linear_obj=linear_obj) - return regularization_util.brightness_zeroth_regularization_matrix_from( + return brightness_zeroth_regularization_matrix_from( regularization_weights=regularization_weights ) diff --git a/autoarray/inversion/regularization/constant.py b/autoarray/inversion/regularization/constant.py index 690b248bd..d9737d075 100644 --- a/autoarray/inversion/regularization/constant.py +++ b/autoarray/inversion/regularization/constant.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,57 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def constant_regularization_matrix_from( + coefficient: float, + neighbors: jnp.ndarray[[int, int], jnp.int64], + neighbors_sizes: jnp.ndarray[[int], jnp.int64], +) -> jnp.ndarray[[int, int], jnp.float64]: + """ + From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. + + A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` + class in the module `autoarray.inversion.regularization`. + + Memory requirement: 2SP + S^2 + FLOPS: 1 + 2S + 2SP + + Parameters + ---------- + coefficient + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + neighbors : ndarray, shape (S, P), dtype=int64 + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes : ndarray, shape (S,), dtype=int64 + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + regularization_matrix : ndarray, shape (S, S), dtype=float64 + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + return ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[ + I_IDX, neighbors + ] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) class Constant(AbstractRegularization): @@ -38,7 +88,7 @@ def __init__(self, coefficient: float = 1.0): super().__init__() - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -57,9 +107,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -73,7 +123,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ - return regularization_util.constant_regularization_matrix_from( + return constant_regularization_matrix_from( coefficient=self.coefficient, neighbors=linear_obj.neighbors, neighbors_sizes=linear_obj.neighbors.sizes, diff --git a/autoarray/inversion/regularization/constant_zeroth.py b/autoarray/inversion/regularization/constant_zeroth.py index 5e3d8acb3..11d7b9808 100644 --- a/autoarray/inversion/regularization/constant_zeroth.py +++ b/autoarray/inversion/regularization/constant_zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,61 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def constant_zeroth_regularization_matrix_from( + coefficient: float, + coefficient_zeroth: float, + neighbors: jnp.ndarray, + neighbors_sizes: jnp.ndarray[[int], jnp.int64], +) -> jnp.ndarray: + """ + From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. + + A complete description of regularizatin and the ``regularization_matrix`` can be found in the ``Regularization`` + class in the module ``autoarray.inversion.regularization``. + + Parameters + ---------- + coefficients + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + neighbors + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + jnp.ndarray + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + const = ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[ + I_IDX, neighbors + ] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) + + reg_coeff = coefficient_zeroth**2.0 + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + zeroth = jnp.eye(P) * reg_coeff + + return const + zeroth class ConstantZeroth(AbstractRegularization): @@ -17,7 +71,7 @@ def __init__(self, coefficient_neighbor=1.0, coefficient_zeroth=1.0): self.coefficient_neighbor = coefficient_neighbor self.coefficient_zeroth = coefficient_zeroth - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -36,9 +90,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient_neighbor * np.ones(linear_obj.params) + return self.coefficient_neighbor * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -51,9 +105,8 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization matrix. """ - return regularization_util.constant_zeroth_regularization_matrix_from( + return constant_zeroth_regularization_matrix_from( coefficient=self.coefficient_neighbor, coefficient_zeroth=self.coefficient_zeroth, neighbors=linear_obj.neighbors, - neighbors_sizes=linear_obj.neighbors.sizes, ) diff --git a/autoarray/inversion/regularization/exponential_kernel.py b/autoarray/inversion/regularization/exponential_kernel.py index 73ead006d..cfb03186b 100644 --- a/autoarray/inversion/regularization/exponential_kernel.py +++ b/autoarray/inversion/regularization/exponential_kernel.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,52 +7,44 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray import numba_util - -@numba_util.jit() def exp_cov_matrix_from( scale: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points: jnp.ndarray, # shape (N, 2) +) -> jnp.ndarray: # shape (N, N) """ - Consutruct the source brightness covariance matrix, which is used to determined the regularization - pattern (i.e, how the different source pixels are smoothed). + Construct the source brightness covariance matrix using an exponential kernel: + + cov[i,j] = exp(- d_{ij} / scale) - The covariance matrix includes one non-linear parameters, the scale coefficient, which is used to determine - the typical scale of the regularization pattern. + with a tiny jitter 1e-8 added on the diagonal for numerical stability. Parameters ---------- scale - The typical scale of the regularization pattern . + The length‐scale of the exponential kernel. pixel_points - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). - Something like [[y1,x1], [y2,x2], ...] + Array of shape (N, 2) giving the (y,x) coordinates of each source‐plane pixel. Returns ------- - np.ndarray - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. + jnp.ndarray, shape (N, N) + The exponential covariance matrix. """ + # pairwise differences: shape (N, N, 2) + diff = pixel_points[:, None, :] - pixel_points[None, :, :] - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) + # Euclidean distances: shape (N, N) + d = jnp.linalg.norm(diff, axis=-1) - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j + # exponential kernel + cov = jnp.exp(-d / scale) - covariance_matrix[i, j] += np.exp(-1.0 * d_ij / scale) + # add a small jitter on the diagonal + N = pixel_points.shape[0] + cov = cov + jnp.eye(N) * 1e-8 - return covariance_matrix + return cov class ExponentialKernel(AbstractRegularization): @@ -83,7 +75,7 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0): super().__init__() - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -102,9 +94,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -119,7 +111,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ covariance_matrix = exp_cov_matrix_from( scale=self.scale, - pixel_points=np.array(linear_obj.source_plane_mesh_grid), + pixel_points=linear_obj.source_plane_mesh_grid.array, ) - return self.coefficient * np.linalg.inv(covariance_matrix) + return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/gaussian_kernel.py b/autoarray/inversion/regularization/gaussian_kernel.py index e133a22a2..4b600fba5 100644 --- a/autoarray/inversion/regularization/gaussian_kernel.py +++ b/autoarray/inversion/regularization/gaussian_kernel.py @@ -1,4 +1,5 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np from typing import TYPE_CHECKING @@ -7,52 +8,46 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray import numba_util - -@numba_util.jit() def gauss_cov_matrix_from( scale: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points: jnp.ndarray, # shape (N, 2) +) -> jnp.ndarray: """ - Consutruct the source brightness covariance matrix, which is used to determined the regularization - pattern (i.e, how the different source pixels are smoothed). + Construct the source‐pixel Gaussian covariance matrix for regularization. + + For N source‐pixels at coordinates (y_i, x_i), we define - the covariance matrix includes one non-linear parameters, the scale coefficient, which is used to - determine the typical scale of the regularization pattern. + C_ij = exp( -||p_i - p_j||^2 / (2 scale^2) ) + + plus a tiny diagonal “jitter” (1e-8) to ensure numerical stability. Parameters ---------- scale - the typical scale of the regularization pattern . + The characteristic length scale of the Gaussian kernel. pixel_points - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). - Something like [[y1,x1], [y2,x2], ...] + Array of shape (N, 2), giving the (y, x) coordinates of each source pixel. Returns ------- - np.ndarray - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. + cov : jnp.ndarray, shape (N, N) + The Gaussian covariance matrix. """ + # Ensure array: + pts = jnp.asarray(pixel_points) # (N, 2) + # Compute squared distances: ||p_i - p_j||^2 + diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2) + d2 = jnp.sum(diffs**2, axis=-1) # (N, N) - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) - - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j + # Gaussian kernel + cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N) - covariance_matrix[i, j] += np.exp(-1.0 * d_ij**2 / (2 * scale**2)) + # Add tiny jitter on the diagonal + N = pts.shape[0] + cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8 - return covariance_matrix + return cov class GaussianKernel(AbstractRegularization): @@ -117,7 +112,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ covariance_matrix = gauss_cov_matrix_from( - scale=self.scale, pixel_points=np.array(linear_obj.source_plane_mesh_grid) + scale=self.scale, pixel_points=linear_obj.source_plane_mesh_grid.array ) - return self.coefficient * np.linalg.inv(covariance_matrix) + return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index cf0c6dc71..8cedca034 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -2,333 +2,29 @@ from typing import Tuple from autoarray import exc -from autoarray import numba_util - -@numba_util.jit() -def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> np.ndarray: - """ - Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms - to the regularization matrix. - - A complete description of regularization and the `regularization_matrix` can be found in the `Regularization` - class in the module `autoarray.inversion.regularization`. - - Parameters - ---------- - pixels - The number of pixels in the linear object which is to be regularized, being used to in the inversion. - coefficient - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - - regularization_matrix = np.zeros(shape=(pixels, pixels)) - - regularization_coefficient = coefficient**2.0 - - for i in range(pixels): - regularization_matrix[i, i] += regularization_coefficient - - return regularization_matrix - - -@numba_util.jit() -def constant_regularization_matrix_from( - coefficient: float, neighbors: np.ndarray, neighbors_sizes: np.ndarray -) -> np.ndarray: - """ - From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. - - A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` - class in the module `autoarray.inversion.regularization`. - - Parameters - ---------- - coefficient - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - - parameters = len(neighbors) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - - regularization_coefficient = coefficient**2.0 - - for i in range(parameters): - regularization_matrix[i, i] += 1e-8 - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_coefficient - regularization_matrix[i, neighbor_index] -= regularization_coefficient - - return regularization_matrix - - -@numba_util.jit() -def constant_zeroth_regularization_matrix_from( - coefficient: float, - coefficient_zeroth: float, - neighbors: np.ndarray, - neighbors_sizes: np.ndarray, -) -> np.ndarray: - """ - From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. - - A complete description of regularizatin and the ``regularization_matrix`` can be found in the ``Regularization`` - class in the module ``autoarray.inversion.regularization``. - - Parameters - ---------- - coefficients - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - - pixels = len(neighbors) - - regularization_matrix = np.zeros(shape=(pixels, pixels)) - - regularization_coefficient = coefficient**2.0 - regularization_coefficient_zeroth = coefficient_zeroth**2.0 - - for i in range(pixels): - regularization_matrix[i, i] += 1e-8 - regularization_matrix[i, i] += regularization_coefficient_zeroth - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_coefficient - regularization_matrix[i, neighbor_index] -= regularization_coefficient - - return regularization_matrix - - -def adaptive_regularization_weights_from( - inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray -) -> np.ndarray: - """ - Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). - - The weights define the effective regularization coefficient of every mesh parameter (typically pixels - of a ``Mapper``). - - They are computed using an estimate of the expected signal in each pixel. - - Two regularization coefficients are used, corresponding to the: - - 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization). - 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization). - - Parameters - ---------- - inner_coefficient - The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction - in the inner regions of a mesh's reconstruction. - outer_coefficient - The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction - in the outer regions of a mesh's reconstruction. - pixel_signals - The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal - and low signal pixelizations. - - Returns - ------- - np.ndarray - The adaptive regularization weights which act as the effective regularization coefficients of - every source pixel. - """ - return ( - inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals) - ) ** 2.0 - - -def brightness_zeroth_regularization_weights_from( - coefficient: float, pixel_signals: np.ndarray -) -> np.ndarray: - """ - Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). - - The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels - of a ``Mapper``). - - They are computed using an estimate of the expected signal in each pixel. - - The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are - the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in - the pixelization). - - Parameters - ---------- - coefficient - The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``), - with the degree applied varying based on the ``pixel_signals``. - pixel_signals - The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal - and low signal pixelizations. - - Returns - ------- - np.ndarray - The zeroth order regularization weights which act as the effective level of zeroth order regularization - applied to every mesh parameter. - """ - return coefficient * (1.0 - pixel_signals) - - -# @numba_util.jit() -def weighted_regularization_matrix_from( - regularization_weights: np.ndarray, - neighbors: np.ndarray, - neighbors_sizes: np.ndarray, -) -> np.ndarray: - """ - Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). - - This matrix is computed using the regularization weights of every mesh pixel, which are computed using the - function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of - every mesh pixel. - - The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate - neighbor calculation of the corresponding ``Mapper`` class. - - Parameters - ---------- - regularization_weights - The regularization weight of each pixel, adaptively governing the degree of gradient regularization - applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the mesh grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using an adaptive regularization scheme where the effective regularization - coefficient of every source pixel is different. - """ - parameters = len(regularization_weights) - regularization_matrix = np.zeros((parameters, parameters)) - regularization_weight = regularization_weights**2.0 - - # Add small diagonal offset - np.fill_diagonal(regularization_matrix, 1e-8) - - for i in range(parameters): - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - w = regularization_weight[neighbor_index] - - regularization_matrix[i, i] += w - regularization_matrix[neighbor_index, neighbor_index] += w - regularization_matrix[i, neighbor_index] -= w - regularization_matrix[neighbor_index, i] -= w - - return regularization_matrix - - -# def weighted_regularization_matrix_from( -# regularization_weights: np.ndarray, -# neighbors: np.ndarray, -# neighbors_sizes: np.ndarray, -# ) -> np.ndarray: -# """ -# Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). -# -# This matrix is computed using the regularization weights of every mesh pixel, which are computed using the -# function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of -# every mesh pixel. -# -# The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate -# neighbor calculation of the corresponding ``Mapper`` class. -# -# Parameters -# ---------- -# regularization_weights -# The regularization weight of each pixel, adaptively governing the degree of gradient regularization -# applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). -# neighbors -# An array of length (total_pixels) which provides the index of all neighbors of every pixel in -# the mesh grid (entries of -1 correspond to no neighbor). -# neighbors_sizes -# An array of length (total_pixels) which gives the number of neighbors of every pixel in the -# Voronoi grid. -# -# Returns -# ------- -# np.ndarray -# The regularization matrix computed using an adaptive regularization scheme where the effective regularization -# coefficient of every source pixel is different. -# """ -# parameters = len(regularization_weights) -# regularization_matrix = np.zeros((parameters, parameters)) -# regularization_weight = regularization_weights**2.0 -# -# # Add small diagonal offset -# np.fill_diagonal(regularization_matrix, 1e-8) -# -# for i in range(parameters): -# for j in range(neighbors_sizes[i]): -# neighbor_index = neighbors[i, j] -# w = regularization_weight[neighbor_index] -# -# regularization_matrix[i, i] += w -# regularization_matrix[neighbor_index, neighbor_index] += w -# regularization_matrix[i, neighbor_index] -= w -# regularization_matrix[neighbor_index, i] -= w -# -# return regularization_matrix - - -def brightness_zeroth_regularization_matrix_from( - regularization_weights: np.ndarray, -) -> np.ndarray: - """ - Returns the regularization matrix for the zeroth-order brightness regularization scheme. - - Parameters - ---------- - regularization_weights - The regularization weights for each pixel, governing the strength of zeroth-order - regularization applied per inversion parameter. - - Returns - ------- - A diagonal regularization matrix where each diagonal element is the squared regularization weight - for that pixel. - """ - regularization_weight_squared = regularization_weights**2.0 - return np.diag(regularization_weight_squared) +from autoarray.inversion.regularization.adaptive_brightness import ( + adaptive_regularization_weights_from, +) +from autoarray.inversion.regularization.adaptive_brightness import ( + weighted_regularization_matrix_from, +) +from autoarray.inversion.regularization.brightness_zeroth import ( + brightness_zeroth_regularization_matrix_from, +) +from autoarray.inversion.regularization.brightness_zeroth import ( + brightness_zeroth_regularization_weights_from, +) +from autoarray.inversion.regularization.constant import ( + constant_regularization_matrix_from, +) +from autoarray.inversion.regularization.constant_zeroth import ( + constant_zeroth_regularization_matrix_from, +) +from autoarray.inversion.regularization.exponential_kernel import exp_cov_matrix_from +from autoarray.inversion.regularization.gaussian_kernel import gauss_cov_matrix_from +from autoarray.inversion.regularization.matern_kernel import matern_kernel +from autoarray.inversion.regularization.zeroth import zeroth_regularization_matrix_from def reg_split_from( diff --git a/autoarray/inversion/regularization/zeroth.py b/autoarray/inversion/regularization/zeroth.py index e30b1222e..04f61ad0e 100644 --- a/autoarray/inversion/regularization/zeroth.py +++ b/autoarray/inversion/regularization/zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,34 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> jnp.ndarray: + """ + Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms + to the regularization matrix. + + A complete description of regularization and the `regularization_matrix` can be found in the `Regularization` + class in the module `autoarray.inversion.regularization`. + + Parameters + ---------- + pixels + The number of pixels in the linear object which is to be regularized, being used to in the inversion. + coefficient + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + + Returns + ------- + np.ndarray + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + + reg_coeff = coefficient**2.0 + + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + + return jnp.eye(pixels) * reg_coeff class Zeroth(AbstractRegularization): @@ -60,9 +87,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -75,6 +102,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization matrix. """ - return regularization_util.zeroth_regularization_matrix_from( + return zeroth_regularization_matrix_from( coefficient=self.coefficient, pixels=linear_obj.params ) diff --git a/autoarray/plot/visuals/two_d.py b/autoarray/plot/visuals/two_d.py index 863e54eb9..35242db57 100644 --- a/autoarray/plot/visuals/two_d.py +++ b/autoarray/plot/visuals/two_d.py @@ -71,7 +71,10 @@ def plot_via_plotter(self, plotter, grid_indexes=None): plotter.mesh_grid_scatter.scatter_grid(grid=self.mesh_grid.array) if self.positions is not None: - plotter.positions_scatter.scatter_grid(grid=self.positions) + try: + plotter.positions_scatter.scatter_grid(grid=self.positions.array) + except (AttributeError, ValueError): + plotter.positions_scatter.scatter_grid(grid=self.positions) if self.vectors is not None: plotter.vector_yx_quiver.quiver_vectors(vectors=self.vectors) diff --git a/autoarray/plot/wrap/two_d/grid_scatter.py b/autoarray/plot/wrap/two_d/grid_scatter.py index 1d2acae8d..e9b9879d0 100644 --- a/autoarray/plot/wrap/two_d/grid_scatter.py +++ b/autoarray/plot/wrap/two_d/grid_scatter.py @@ -79,7 +79,17 @@ def scatter_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular] try: for grid in grid_list: - plt.scatter(y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict) + try: + plt.scatter( + y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict + ) + except ValueError: + plt.scatter( + y=grid.array[:, 0], + x=grid.array[:, 1], + c=next(color), + **config_dict, + ) except IndexError: return None diff --git a/autoarray/preloads.py b/autoarray/preloads.py index 6cedca99d..340d85bdd 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -14,6 +14,7 @@ def __init__( self, mapper_indices: np.ndarray = None, source_pixel_zeroed_indices: np.ndarray = None, + linear_light_profile_blurred_mapping_matrix=None, ): """ Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance @@ -37,11 +38,17 @@ def __init__( Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping separate the lens light from the pixelized source model. + linear_light_profile_blurred_mapping_matrix + The evaluated images of the linear light profiles that make up the blurred mapping matrix component of the + inversion, with the other component being the pixelization's pixels. These are fixed when the lens light + is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but + the intensity values will still be solved for during the inversion. """ self.mapper_indices = None self.source_pixel_zeroed_indices = None self.source_pixel_zeroed_indices_to_keep = None + self.linear_light_profile_blurred_mapping_matrix = None if mapper_indices is not None: @@ -58,3 +65,9 @@ def __init__( # Get the indices where values_to_solve is True self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0] + + if linear_light_profile_blurred_mapping_matrix is not None: + + self.linear_light_profile_blurred_mapping_matrix = jnp.array( + linear_light_profile_blurred_mapping_matrix + ) diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index bf51c3d75..e8c7a8a82 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -136,7 +136,9 @@ def neighbors(self) -> Neighbors: @cached_property def edge_pixel_list(self) -> List: - return mesh_util.rectangular_edge_pixel_list_from(neighbors=self.neighbors) + return mesh_util.rectangular_edge_pixel_list_from( + shape_native=self.shape_native + ) @property def pixels(self) -> int: diff --git a/autoarray/structures/mesh/triangulation_2d.py b/autoarray/structures/mesh/triangulation_2d.py index 95ff46633..3ddb95c01 100644 --- a/autoarray/structures/mesh/triangulation_2d.py +++ b/autoarray/structures/mesh/triangulation_2d.py @@ -95,6 +95,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay": to compute the Voronoi mesh are ill posed. These exceptions are caught and combined into a single `MeshException`, which helps exception handling in the `inversion` package. """ + import scipy.spatial try: diff --git a/pyproject.toml b/pyproject.toml index 850141794..039c8d671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "astropy>=5.0,<=6.1.2", "decorator>=4.0.0", "dill>=0.3.1.1", + "jaxnnls==1.0.1", "matplotlib>=3.7.0", "scipy<=1.14.0", "scikit-image<=0.24.0", diff --git a/test_autoarray/inversion/regularizations/test_adaptive_brightness.py b/test_autoarray/inversion/regularizations/test_adaptive_brightness.py index b808682b7..b3cf4f132 100644 --- a/test_autoarray/inversion/regularizations/test_adaptive_brightness.py +++ b/test_autoarray/inversion/regularizations/test_adaptive_brightness.py @@ -55,7 +55,6 @@ def test__regularization_matrix__matches_util(): aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) ) diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index e9af395d4..05a4bd0d4 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -227,8 +227,6 @@ def test__brightness_zeroth_regularization_weights_from(): def test__weighted_regularization_matrix_from(): neighbors = np.array([[2], [3], [0], [1]]) - neighbors_sizes = np.array([1, 1, 1, 1]) - b_matrix = np.array([[-1, 0, 1, 0], [0, -1, 0, 1], [1, 0, -1, 0], [0, 1, 0, -1]]) test_regularization_matrix = np.matmul(b_matrix.T, b_matrix) @@ -238,7 +236,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -251,8 +248,6 @@ def test__weighted_regularization_matrix_from(): neighbors = np.array([[1, 2], [0, -1], [0, -1]]) - neighbors_sizes = np.array([2, 1, 1]) - b_matrix_1 = np.array( [[-1, 1, 0], [-1, 0, 1], [1, -1, 0]] # Pair 1 # Pair 2 ) # Pair 1 flip @@ -272,7 +267,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -291,14 +285,11 @@ def test__weighted_regularization_matrix_from(): neighbors = np.array([[1, 3], [0, 2], [1, 3], [0, 2]]) - neighbors_sizes = np.array([2, 2, 2, 2]) - regularization_weights = np.ones((4,)) regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -383,7 +374,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -415,12 +405,9 @@ def test__weighted_regularization_matrix_from(): [[1, 2, -1, -1], [0, 2, 3, -1], [0, 1, -1, -1], [1, -1, -1, -1]] ) - neighbors_sizes = np.array([2, 3, 2, 1]) - regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -436,7 +423,6 @@ def test__weighted_regularization_matrix_from(): ] ) - neighbors_sizes = np.array([2, 3, 4, 2, 4, 3]) regularization_weights = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) # I'm inputting the regularization weight_list directly thiss time, as it'd be a pain to multiply with a @@ -503,7 +489,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4)