From 1d6043ba2fb74cd5b1c6af343207ba29dc3d436c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 23 Jul 2025 16:07:43 +0100 Subject: [PATCH 1/6] regularization util JAX conversons, seem to work --- autoarray/inversion/inversion/abstract.py | 31 +--- .../pixelization/mappers/abstract.py | 2 +- .../inversion/pixelization/mesh/mesh_util.py | 152 ++++++++-------- .../regularization/adaptive_brightness.py | 1 - .../regularization/regularization_util.py | 172 +++++++----------- autoarray/plot/visuals/two_d.py | 5 +- autoarray/plot/wrap/two_d/grid_scatter.py | 5 +- autoarray/structures/mesh/rectangular_2d.py | 2 +- .../test_adaptive_brightness.py | 1 - .../test_regularization_util.py | 14 -- 10 files changed, 166 insertions(+), 219 deletions(-) diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index bc0daf0ad..fa490222b 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -1,5 +1,6 @@ import copy import jax.numpy as jnp +from jax.scipy.linalg import block_diag import numpy as np from typing import Dict, List, Optional, Type, Union @@ -334,8 +335,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]: If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion are regularized so high their value is forced to zero. """ - from scipy.linalg import block_diag - return block_diag( *[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list] ) @@ -664,30 +663,14 @@ def log_det_regularization_matrix_term(self) -> float: float The log determinant of the regularization matrix. """ - from scipy.sparse import csc_matrix - from scipy.sparse.linalg import splu - - if not self.has(cls=AbstractRegularization): - return 0.0 - try: - lu = splu(csc_matrix(self.regularization_matrix_reduced)) - diagL = lu.L.diagonal() - diagU = lu.U.diagonal() - diagL = diagL.astype(np.complex128) - diagU = diagU.astype(np.complex128) - - return np.real(np.log(diagL).sum() + np.log(diagU).sum()) - - except RuntimeError: - try: - return 2.0 * np.sum( - np.log( - np.diag(np.linalg.cholesky(self.regularization_matrix_reduced)) - ) + return 2.0 * np.sum( + jnp.log( + jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced)) ) - except np.linalg.LinAlgError as e: - raise exc.InversionException() from e + ) + except np.linalg.LinAlgError as e: + raise exc.InversionException() from e @property def reconstruction_noise_map_with_covariance(self) -> np.ndarray: diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index 091713af1..8d0fad744 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -288,7 +288,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray: pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim, - adapt_data=np.array(self.adapt_data), + adapt_data=self.adapt_data.array, ) def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]: diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index ea1384349..2160b8629 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -5,7 +5,6 @@ from autoarray import numba_util -@numba_util.jit() def rectangular_neighbors_from( shape_native: Tuple[int, int], ) -> Tuple[np.ndarray, np.ndarray]: @@ -68,7 +67,6 @@ def rectangular_neighbors_from( return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_corner_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -112,8 +110,6 @@ def rectangular_corner_neighbors( return neighbors, neighbors_sizes - -@numba_util.jit() def rectangular_top_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -136,17 +132,19 @@ def rectangular_top_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[1] - 1): - pixel_index = pix - neighbors[pixel_index, 0:3] = np.array( - [pixel_index - 1, pixel_index + 1, pixel_index + shape_native[1]] - ) - neighbors_sizes[pixel_index] = 3 + """ + Vectorized version of the top edge neighbor update using NumPy arithmetic. + """ + # Pixels along the top edge, excluding corners + top_edge_pixels = np.arange(1, shape_native[1] - 1) - return neighbors, neighbors_sizes + neighbors[top_edge_pixels, 0] = top_edge_pixels - 1 + neighbors[top_edge_pixels, 1] = top_edge_pixels + 1 + neighbors[top_edge_pixels, 2] = top_edge_pixels + shape_native[1] + neighbors_sizes[top_edge_pixels] = 3 + return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_left_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -169,21 +167,19 @@ def rectangular_left_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[0] - 1): - pixel_index = pix * shape_native[1] - neighbors[pixel_index, 0:3] = np.array( - [ - pixel_index - shape_native[1], - pixel_index + 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 3 + # Row indices (excluding top and bottom corners) + rows = np.arange(1, shape_native[0] - 1) - return neighbors, neighbors_sizes + # Convert to flat pixel indices for the left edge (first column) + pixel_indices = rows * shape_native[1] + neighbors[pixel_indices, 0] = pixel_indices - shape_native[1] + neighbors[pixel_indices, 1] = pixel_indices + 1 + neighbors[pixel_indices, 2] = pixel_indices + shape_native[1] + neighbors_sizes[pixel_indices] = 3 + + return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_right_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -206,21 +202,19 @@ def rectangular_right_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[0] - 1): - pixel_index = pix * shape_native[1] + shape_native[1] - 1 - neighbors[pixel_index, 0:3] = np.array( - [ - pixel_index - shape_native[1], - pixel_index - 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 3 + # Rows excluding the top and bottom corners + rows = np.arange(1, shape_native[0] - 1) - return neighbors, neighbors_sizes + # Flat indices for the right edge pixels + pixel_indices = rows * shape_native[1] + shape_native[1] - 1 + neighbors[pixel_indices, 0] = pixel_indices - shape_native[1] + neighbors[pixel_indices, 1] = pixel_indices - 1 + neighbors[pixel_indices, 2] = pixel_indices + shape_native[1] + neighbors_sizes[pixel_indices] = 3 + + return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_bottom_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -243,19 +237,21 @@ def rectangular_bottom_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - pixels = int(shape_native[0] * shape_native[1]) + n_rows, n_cols = shape_native + pixels = n_rows * n_cols + + # Horizontal pixel positions along bottom row, excluding corners + cols = np.arange(1, n_cols - 1) + pixel_indices = pixels - cols - 1 # Reverse order from right to left - for pix in range(1, shape_native[1] - 1): - pixel_index = pixels - pix - 1 - neighbors[pixel_index, 0:3] = np.array( - [pixel_index - shape_native[1], pixel_index - 1, pixel_index + 1] - ) - neighbors_sizes[pixel_index] = 3 + neighbors[pixel_indices, 0] = pixel_indices - n_cols + neighbors[pixel_indices, 1] = pixel_indices - 1 + neighbors[pixel_indices, 2] = pixel_indices + 1 + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_central_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -279,46 +275,60 @@ def rectangular_central_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for x in range(1, shape_native[0] - 1): - for y in range(1, shape_native[1] - 1): - pixel_index = x * shape_native[1] + y - neighbors[pixel_index, 0:4] = np.array( - [ - pixel_index - shape_native[1], - pixel_index - 1, - pixel_index + 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 4 + n_rows, n_cols = shape_native - return neighbors, neighbors_sizes + # Grid coordinates excluding edges + xs = np.arange(1, n_rows - 1) + ys = np.arange(1, n_cols - 1) + # 2D grid of central pixel indices + grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij") + pixel_indices = grid_x * n_cols + grid_y + pixel_indices = pixel_indices.ravel() -def rectangular_edge_pixel_list_from(neighbors: np.ndarray) -> List: - """ - Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization. + # Compute neighbor indices + neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up + neighbors[pixel_indices, 1] = pixel_indices - 1 # Left + neighbors[pixel_indices, 2] = pixel_indices + 1 # Right + neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down + + neighbors_sizes[pixel_indices] = 4 + + return neighbors, neighbors_sizes - This is computed by searching the `neighbors` array for pixels that have a neighbor with index -1, meaning there - is at least one neighbor from the 4 expected missing. +def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]: + """ + Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization, + based on its 2D shape. Parameters ---------- - neighbors - An array of dimensions [total_pixels, 4] which provides the index of all neighbors of every pixel in the - rectangular pixelization (entries of -1 correspond to no neighbor). + shape_native + The (rows, cols) shape of the rectangular 2D pixel grid. Returns ------- - A list of the 1D indices of all pixels on the edge of a rectangular pixelization. + A list of the 1D indices of all edge pixels. """ - edge_pixel_list = [] + rows, cols = shape_native + + # Top row + top = np.arange(0, cols) + + # Bottom row + bottom = np.arange((rows - 1) * cols, rows * cols) + + # Left column (excluding corners) + left = np.arange(1, rows - 1) * cols + + # Right column (excluding corners) + right = (np.arange(1, rows - 1) + 1) * cols - 1 - for i, neighbors in enumerate(neighbors): - if -1 in neighbors: - edge_pixel_list.append(i) + # Concatenate all edge indices + edge_pixel_indices = np.concatenate([top, left, right, bottom]) - return edge_pixel_list + # Sort and return + return np.sort(edge_pixel_indices).tolist() @numba_util.jit() diff --git a/autoarray/inversion/regularization/adaptive_brightness.py b/autoarray/inversion/regularization/adaptive_brightness.py index d7c322817..27f399fd0 100644 --- a/autoarray/inversion/regularization/adaptive_brightness.py +++ b/autoarray/inversion/regularization/adaptive_brightness.py @@ -115,5 +115,4 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: return regularization_util.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=linear_obj.source_plane_mesh_grid.neighbors, - neighbors_sizes=linear_obj.source_plane_mesh_grid.neighbors.sizes, ) diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index cf0c6dc71..e8d798a55 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import Tuple @@ -5,7 +6,6 @@ from autoarray import numba_util -@numba_util.jit() def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> np.ndarray: """ Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms @@ -27,59 +27,60 @@ class in the module `autoarray.inversion.regularization`. The regularization matrix computed using Regularization where the effective regularization coefficient of every source pixel is the same. """ + reg_coeff = coefficient ** 2.0 + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + return jnp.eye(pixels) * reg_coeff - regularization_matrix = np.zeros(shape=(pixels, pixels)) - - regularization_coefficient = coefficient**2.0 - - for i in range(pixels): - regularization_matrix[i, i] += regularization_coefficient - return regularization_matrix - - -@numba_util.jit() def constant_regularization_matrix_from( - coefficient: float, neighbors: np.ndarray, neighbors_sizes: np.ndarray -) -> np.ndarray: + coefficient: float, + neighbors: np.ndarray[[int, int], np.int64], + neighbors_sizes: np.ndarray[[int], np.int64], +) -> np.ndarray[[int, int], np.float64]: """ From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` class in the module `autoarray.inversion.regularization`. + Memory requirement: 2SP + S^2 + FLOPS: 1 + 2S + 2SP + Parameters ---------- coefficient The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors + neighbors : ndarray, shape (S, P), dtype=int64 An array of length (total_pixels) which provides the index of all neighbors of every pixel in the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes + neighbors_sizes : ndarray, shape (S,), dtype=int64 An array of length (total_pixels) which gives the number of neighbors of every pixel in the Voronoi grid. Returns ------- - np.ndarray + regularization_matrix : ndarray, shape (S, S), dtype=float64 The regularization matrix computed using Regularization where the effective regularization coefficient of every source pixel is the same. """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + return ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) - parameters = len(neighbors) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - - regularization_coefficient = coefficient**2.0 - - for i in range(parameters): - regularization_matrix[i, i] += 1e-8 - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_coefficient - regularization_matrix[i, neighbor_index] -= regularization_coefficient - - return regularization_matrix @numba_util.jit() @@ -203,11 +204,9 @@ def brightness_zeroth_regularization_weights_from( return coefficient * (1.0 - pixel_signals) -# @numba_util.jit() def weighted_regularization_matrix_from( regularization_weights: np.ndarray, neighbors: np.ndarray, - neighbors_sizes: np.ndarray, ) -> np.ndarray: """ Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). @@ -237,78 +236,43 @@ def weighted_regularization_matrix_from( The regularization matrix computed using an adaptive regularization scheme where the effective regularization coefficient of every source pixel is different. """ - parameters = len(regularization_weights) - regularization_matrix = np.zeros((parameters, parameters)) - regularization_weight = regularization_weights**2.0 - - # Add small diagonal offset - np.fill_diagonal(regularization_matrix, 1e-8) - - for i in range(parameters): - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - w = regularization_weight[neighbor_index] - - regularization_matrix[i, i] += w - regularization_matrix[neighbor_index, neighbor_index] += w - regularization_matrix[i, neighbor_index] -= w - regularization_matrix[neighbor_index, i] -= w - - return regularization_matrix - - -# def weighted_regularization_matrix_from( -# regularization_weights: np.ndarray, -# neighbors: np.ndarray, -# neighbors_sizes: np.ndarray, -# ) -> np.ndarray: -# """ -# Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). -# -# This matrix is computed using the regularization weights of every mesh pixel, which are computed using the -# function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of -# every mesh pixel. -# -# The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate -# neighbor calculation of the corresponding ``Mapper`` class. -# -# Parameters -# ---------- -# regularization_weights -# The regularization weight of each pixel, adaptively governing the degree of gradient regularization -# applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). -# neighbors -# An array of length (total_pixels) which provides the index of all neighbors of every pixel in -# the mesh grid (entries of -1 correspond to no neighbor). -# neighbors_sizes -# An array of length (total_pixels) which gives the number of neighbors of every pixel in the -# Voronoi grid. -# -# Returns -# ------- -# np.ndarray -# The regularization matrix computed using an adaptive regularization scheme where the effective regularization -# coefficient of every source pixel is different. -# """ -# parameters = len(regularization_weights) -# regularization_matrix = np.zeros((parameters, parameters)) -# regularization_weight = regularization_weights**2.0 -# -# # Add small diagonal offset -# np.fill_diagonal(regularization_matrix, 1e-8) -# -# for i in range(parameters): -# for j in range(neighbors_sizes[i]): -# neighbor_index = neighbors[i, j] -# w = regularization_weight[neighbor_index] -# -# regularization_matrix[i, i] += w -# regularization_matrix[neighbor_index, neighbor_index] += w -# regularization_matrix[i, neighbor_index] -= w -# regularization_matrix[neighbor_index, i] -= w -# -# return regularization_matrix - + S, P = neighbors.shape + reg_w = regularization_weights ** 2 + + # 1) Flatten the (i→j) neighbor pairs + I = jnp.repeat(jnp.arange(S), P) # (S*P,) + J = neighbors.reshape(-1) # (S*P,) + + # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0 + OUT = S + J = jnp.where(J < 0, OUT, J) + + # 3) Build an extended weight vector with a zero at index S + reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0) + w_ij = reg_w_ext[J] # (S*P,) + + # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely + mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype) + + # 5) Scatter into the diagonal: + # - the tiny 1e-8 floor on each i < S + # - sum_j reg_w[j] into diag[i] + # - sum contributions reg_w[j] into diag[j] + # (diagonal at OUT=S picks up zeros only) + diag_updates_i = jnp.concatenate([ + jnp.full((S,), 1e-8), + jnp.zeros((1,)) # out‐of‐bounds slot stays zero + ], axis=0) + mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i) + mat = mat.at[I, I].add(w_ij) + mat = mat.at[J, J].add(w_ij) + + # 6) Scatter the off‐diagonal subtractions: + mat = mat.at[I, J].add(-w_ij) + mat = mat.at[J, I].add(-w_ij) + + # 7) Drop the extra row/column S and return the S×S result + return mat[:S, :S] def brightness_zeroth_regularization_matrix_from( regularization_weights: np.ndarray, diff --git a/autoarray/plot/visuals/two_d.py b/autoarray/plot/visuals/two_d.py index 863e54eb9..35242db57 100644 --- a/autoarray/plot/visuals/two_d.py +++ b/autoarray/plot/visuals/two_d.py @@ -71,7 +71,10 @@ def plot_via_plotter(self, plotter, grid_indexes=None): plotter.mesh_grid_scatter.scatter_grid(grid=self.mesh_grid.array) if self.positions is not None: - plotter.positions_scatter.scatter_grid(grid=self.positions) + try: + plotter.positions_scatter.scatter_grid(grid=self.positions.array) + except (AttributeError, ValueError): + plotter.positions_scatter.scatter_grid(grid=self.positions) if self.vectors is not None: plotter.vector_yx_quiver.quiver_vectors(vectors=self.vectors) diff --git a/autoarray/plot/wrap/two_d/grid_scatter.py b/autoarray/plot/wrap/two_d/grid_scatter.py index 1d2acae8d..ad934f4c0 100644 --- a/autoarray/plot/wrap/two_d/grid_scatter.py +++ b/autoarray/plot/wrap/two_d/grid_scatter.py @@ -79,7 +79,10 @@ def scatter_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular] try: for grid in grid_list: - plt.scatter(y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict) + try: + plt.scatter(y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict) + except ValueError: + plt.scatter(y=grid.array[:, 0], x=grid.array[:, 1], c=next(color), **config_dict) except IndexError: return None diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index bf51c3d75..28a9e2372 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -136,7 +136,7 @@ def neighbors(self) -> Neighbors: @cached_property def edge_pixel_list(self) -> List: - return mesh_util.rectangular_edge_pixel_list_from(neighbors=self.neighbors) + return mesh_util.rectangular_edge_pixel_list_from(shape_native=self.shape_native) @property def pixels(self) -> int: diff --git a/test_autoarray/inversion/regularizations/test_adaptive_brightness.py b/test_autoarray/inversion/regularizations/test_adaptive_brightness.py index b808682b7..b3cf4f132 100644 --- a/test_autoarray/inversion/regularizations/test_adaptive_brightness.py +++ b/test_autoarray/inversion/regularizations/test_adaptive_brightness.py @@ -55,7 +55,6 @@ def test__regularization_matrix__matches_util(): aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) ) diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index e9af395d4..1b26f7095 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -227,8 +227,6 @@ def test__brightness_zeroth_regularization_weights_from(): def test__weighted_regularization_matrix_from(): neighbors = np.array([[2], [3], [0], [1]]) - neighbors_sizes = np.array([1, 1, 1, 1]) - b_matrix = np.array([[-1, 0, 1, 0], [0, -1, 0, 1], [1, 0, -1, 0], [0, 1, 0, -1]]) test_regularization_matrix = np.matmul(b_matrix.T, b_matrix) @@ -238,7 +236,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -251,8 +248,6 @@ def test__weighted_regularization_matrix_from(): neighbors = np.array([[1, 2], [0, -1], [0, -1]]) - neighbors_sizes = np.array([2, 1, 1]) - b_matrix_1 = np.array( [[-1, 1, 0], [-1, 0, 1], [1, -1, 0]] # Pair 1 # Pair 2 ) # Pair 1 flip @@ -272,7 +267,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -291,14 +285,11 @@ def test__weighted_regularization_matrix_from(): neighbors = np.array([[1, 3], [0, 2], [1, 3], [0, 2]]) - neighbors_sizes = np.array([2, 2, 2, 2]) - regularization_weights = np.ones((4,)) regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -383,7 +374,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -415,12 +405,10 @@ def test__weighted_regularization_matrix_from(): [[1, 2, -1, -1], [0, 2, 3, -1], [0, 1, -1, -1], [1, -1, -1, -1]] ) - neighbors_sizes = np.array([2, 3, 2, 1]) regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -436,7 +424,6 @@ def test__weighted_regularization_matrix_from(): ] ) - neighbors_sizes = np.array([2, 3, 4, 2, 4, 3]) regularization_weights = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) # I'm inputting the regularization weight_list directly thiss time, as it'd be a pain to multiply with a @@ -503,7 +490,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) From 8a6951acba9e17e5b93bc6a459b323ee2e151961 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 23 Jul 2025 16:49:01 +0100 Subject: [PATCH 2/6] gaussian kernel converted successfully --- autoarray/inversion/inversion/abstract.py | 3 + .../regularization/constant_zeroth.py | 1 - .../regularization/gaussian_kernel.py | 59 +++++++++---------- .../regularization/regularization_util.py | 37 +++++++----- autoarray/preloads.py | 13 ++++ .../test_regularization_util.py | 3 - 6 files changed, 65 insertions(+), 51 deletions(-) diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index fa490222b..55c3b61de 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -663,6 +663,9 @@ def log_det_regularization_matrix_term(self) -> float: float The log determinant of the regularization matrix. """ + if not self.has(cls=AbstractRegularization): + return 0.0 + try: return 2.0 * np.sum( jnp.log( diff --git a/autoarray/inversion/regularization/constant_zeroth.py b/autoarray/inversion/regularization/constant_zeroth.py index 5e3d8acb3..45e1729cf 100644 --- a/autoarray/inversion/regularization/constant_zeroth.py +++ b/autoarray/inversion/regularization/constant_zeroth.py @@ -55,5 +55,4 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: coefficient=self.coefficient_neighbor, coefficient_zeroth=self.coefficient_zeroth, neighbors=linear_obj.neighbors, - neighbors_sizes=linear_obj.neighbors.sizes, ) diff --git a/autoarray/inversion/regularization/gaussian_kernel.py b/autoarray/inversion/regularization/gaussian_kernel.py index e133a22a2..27e248dd3 100644 --- a/autoarray/inversion/regularization/gaussian_kernel.py +++ b/autoarray/inversion/regularization/gaussian_kernel.py @@ -1,4 +1,5 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np from typing import TYPE_CHECKING @@ -7,52 +8,45 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray import numba_util - - -@numba_util.jit() def gauss_cov_matrix_from( scale: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points: jnp.ndarray, # shape (N, 2) +) -> jnp.ndarray: """ - Consutruct the source brightness covariance matrix, which is used to determined the regularization - pattern (i.e, how the different source pixels are smoothed). + Construct the source‐pixel Gaussian covariance matrix for regularization. + + For N source‐pixels at coordinates (y_i, x_i), we define - the covariance matrix includes one non-linear parameters, the scale coefficient, which is used to - determine the typical scale of the regularization pattern. + C_ij = exp( -||p_i - p_j||^2 / (2 scale^2) ) + + plus a tiny diagonal “jitter” (1e-8) to ensure numerical stability. Parameters ---------- scale - the typical scale of the regularization pattern . + The characteristic length scale of the Gaussian kernel. pixel_points - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). - Something like [[y1,x1], [y2,x2], ...] + Array of shape (N, 2), giving the (y, x) coordinates of each source pixel. Returns ------- - np.ndarray - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. + cov : jnp.ndarray, shape (N, N) + The Gaussian covariance matrix. """ + # Ensure array: + pts = jnp.asarray(pixel_points) # (N, 2) + # Compute squared distances: ||p_i - p_j||^2 + diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2) + d2 = jnp.sum(diffs**2, axis=-1) # (N, N) - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) - - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j + # Gaussian kernel + cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N) - covariance_matrix[i, j] += np.exp(-1.0 * d_ij**2 / (2 * scale**2)) + # Add tiny jitter on the diagonal + N = pts.shape[0] + cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8 - return covariance_matrix + return cov class GaussianKernel(AbstractRegularization): @@ -117,7 +111,8 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ covariance_matrix = gauss_cov_matrix_from( - scale=self.scale, pixel_points=np.array(linear_obj.source_plane_mesh_grid) + scale=self.scale, + pixel_points=linear_obj.source_plane_mesh_grid.array ) - return self.coefficient * np.linalg.inv(covariance_matrix) + return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index e8d798a55..2073ebdf7 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -82,13 +82,10 @@ class in the module `autoarray.inversion.regularization`. ) - -@numba_util.jit() def constant_zeroth_regularization_matrix_from( coefficient: float, coefficient_zeroth: float, neighbors: np.ndarray, - neighbors_sizes: np.ndarray, ) -> np.ndarray: """ From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. @@ -113,24 +110,34 @@ class in the module ``autoarray.inversion.regularization``. The regularization matrix computed using Regularization where the effective regularization coefficient of every source pixel is the same. """ + S, P = neighbors.shape + reg1 = coefficient**2 + reg0 = coefficient_zeroth**2 + + # 1) Flatten (i,j) neighbor‐pairs + I = jnp.repeat(jnp.arange(S), P) # (S*P,) + J = neighbors.reshape(-1) # (S*P,) - pixels = len(neighbors) + # 2) Remap “no neighbor” = -1 → OUT = S + OUT = S + J = jnp.where(J < 0, OUT, J) - regularization_matrix = np.zeros(shape=(pixels, pixels)) + # 3) Start on an (S+1)x(S+1) zero canvas + M = jnp.zeros((S+1, S+1), dtype=jnp.float32) - regularization_coefficient = coefficient**2.0 - regularization_coefficient_zeroth = coefficient_zeroth**2.0 + # 4) Diagonal baseline: 1e-8 + reg0 for i in [0..S-1] + diag_base = jnp.concatenate([jnp.full((S,), 1e-8 + reg0), jnp.zeros((1,))]) + M = M.at[jnp.diag_indices(S+1)].add(diag_base) - for i in range(pixels): - regularization_matrix[i, i] += 1e-8 - regularization_matrix[i, i] += regularization_coefficient_zeroth - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_coefficient - regularization_matrix[i, neighbor_index] -= regularization_coefficient + # 5) Scatter the first-order reg1 into diag[i] for each neighbor (i→j): + # M[i,i] += reg1 + M = M.at[I, I].add(reg1) - return regularization_matrix + # 6) Scatter the off-diagonals: M[i,j] -= reg1 + M = M.at[I, J].add(-reg1) + # 7) Return only the top‐left S×S block + return M[:S, :S] def adaptive_regularization_weights_from( inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray diff --git a/autoarray/preloads.py b/autoarray/preloads.py index 6cedca99d..b3f0ca669 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -14,6 +14,7 @@ def __init__( self, mapper_indices: np.ndarray = None, source_pixel_zeroed_indices: np.ndarray = None, + linear_light_profile_blurred_mapping_matrix = None, ): """ Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance @@ -37,11 +38,17 @@ def __init__( Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping separate the lens light from the pixelized source model. + linear_light_profile_blurred_mapping_matrix + The evaluated images of the linear light profiles that make up the blurred mapping matrix component of the + inversion, with the other component being the pixelization's pixels. These are fixed when the lens light + is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but + the intensity values will still be solved for during the inversion. """ self.mapper_indices = None self.source_pixel_zeroed_indices = None self.source_pixel_zeroed_indices_to_keep = None + self.linear_light_profile_blurred_mapping_matrix = None if mapper_indices is not None: @@ -58,3 +65,9 @@ def __init__( # Get the indices where values_to_solve is True self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0] + + if linear_light_profile_blurred_mapping_matrix is not None: + + self.linear_light_profile_blurred_mapping_matrix = jnp.array( + linear_light_profile_blurred_mapping_matrix + ) \ No newline at end of file diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index 1b26f7095..f189491fa 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -147,14 +147,11 @@ def test__constant_regularization_matrix_from(): def test__constant_zeroth_regularization_matrix_from(): neighbors = np.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]]) - neighbors_sizes = np.array([2, 1, 1]) - regularization_matrix = ( aa.util.regularization.constant_zeroth_regularization_matrix_from( coefficient=2.0, coefficient_zeroth=0.5, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) ) From a50901488020811c5be52e4174108769b4169bc4 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 23 Jul 2025 16:59:52 +0100 Subject: [PATCH 3/6] fix constant zeroth --- .../regularization/regularization_util.py | 44 +++++++++---------- .../test_regularization_util.py | 3 ++ 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index 2073ebdf7..44e2e326a 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -86,6 +86,7 @@ def constant_zeroth_regularization_matrix_from( coefficient: float, coefficient_zeroth: float, neighbors: np.ndarray, + neighbors_sizes: np.ndarray[[int], np.int64], ) -> np.ndarray: """ From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. @@ -111,33 +112,28 @@ class in the module ``autoarray.inversion.regularization``. coefficient of every source pixel is the same. """ S, P = neighbors.shape - reg1 = coefficient**2 - reg0 = coefficient_zeroth**2 - - # 1) Flatten (i,j) neighbor‐pairs - I = jnp.repeat(jnp.arange(S), P) # (S*P,) - J = neighbors.reshape(-1) # (S*P,) - - # 2) Remap “no neighbor” = -1 → OUT = S - OUT = S - J = jnp.where(J < 0, OUT, J) - - # 3) Start on an (S+1)x(S+1) zero canvas - M = jnp.zeros((S+1, S+1), dtype=jnp.float32) - - # 4) Diagonal baseline: 1e-8 + reg0 for i in [0..S-1] - diag_base = jnp.concatenate([jnp.full((S,), 1e-8 + reg0), jnp.zeros((1,))]) - M = M.at[jnp.diag_indices(S+1)].add(diag_base) + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient - # 5) Scatter the first-order reg1 into diag[i] for each neighbor (i→j): - # M[i,i] += reg1 - M = M.at[I, I].add(reg1) + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + const = ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) - # 6) Scatter the off-diagonals: M[i,j] -= reg1 - M = M.at[I, J].add(-reg1) + reg_coeff = coefficient_zeroth ** 2.0 + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + zeroth = jnp.eye(P) * reg_coeff - # 7) Return only the top‐left S×S block - return M[:S, :S] + return const + zeroth def adaptive_regularization_weights_from( inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index f189491fa..1b26f7095 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -147,11 +147,14 @@ def test__constant_regularization_matrix_from(): def test__constant_zeroth_regularization_matrix_from(): neighbors = np.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]]) + neighbors_sizes = np.array([2, 1, 1]) + regularization_matrix = ( aa.util.regularization.constant_zeroth_regularization_matrix_from( coefficient=2.0, coefficient_zeroth=0.5, neighbors=neighbors, + neighbors_sizes=neighbors_sizes, ) ) From 150a3603dd932e7e666a6ec9d422cb1d25d824c9 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 23 Jul 2025 17:09:11 +0100 Subject: [PATCH 4/6] move utils to their specific modules --- .../inversion/regularization/__init__.py | 1 + .../regularization/adaptive_brightness.py | 119 ++++++- .../regularization/brightness_zeroth.py | 64 +++- .../inversion/regularization/constant.py | 60 +++- .../regularization/constant_zeroth.py | 64 +++- .../regularization/regularization_util.py | 301 +----------------- autoarray/inversion/regularization/zeroth.py | 37 ++- 7 files changed, 325 insertions(+), 321 deletions(-) diff --git a/autoarray/inversion/regularization/__init__.py b/autoarray/inversion/regularization/__init__.py index e34d07b6b..c592cb51f 100644 --- a/autoarray/inversion/regularization/__init__.py +++ b/autoarray/inversion/regularization/__init__.py @@ -10,3 +10,4 @@ from .gaussian_kernel import GaussianKernel from .exponential_kernel import ExponentialKernel from .matern_kernel import MaternKernel + diff --git a/autoarray/inversion/regularization/adaptive_brightness.py b/autoarray/inversion/regularization/adaptive_brightness.py index 27f399fd0..cf5af2b31 100644 --- a/autoarray/inversion/regularization/adaptive_brightness.py +++ b/autoarray/inversion/regularization/adaptive_brightness.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,8 +7,115 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util +def adaptive_regularization_weights_from( + inner_coefficient: float, outer_coefficient: float, pixel_signals: jnp.ndarray +) -> jnp.ndarray: + """ + Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). + + The weights define the effective regularization coefficient of every mesh parameter (typically pixels + of a ``Mapper``). + + They are computed using an estimate of the expected signal in each pixel. + + Two regularization coefficients are used, corresponding to the: + + 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization). + 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization). + + Parameters + ---------- + inner_coefficient + The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction + in the inner regions of a mesh's reconstruction. + outer_coefficient + The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction + in the outer regions of a mesh's reconstruction. + pixel_signals + The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal + and low signal pixelizations. + + Returns + ------- + jnp.ndarray + The adaptive regularization weights which act as the effective regularization coefficients of + every source pixel. + """ + return ( + inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals) + ) ** 2.0 + + +def weighted_regularization_matrix_from( + regularization_weights: jnp.ndarray, + neighbors: jnp.ndarray, +) -> jnp.ndarray: + """ + Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). + + This matrix is computed using the regularization weights of every mesh pixel, which are computed using the + function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of + every mesh pixel. + + The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate + neighbor calculation of the corresponding ``Mapper`` class. + + Parameters + ---------- + regularization_weights + The regularization weight of each pixel, adaptively governing the degree of gradient regularization + applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). + neighbors + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the mesh grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + jnp.ndarray + The regularization matrix computed using an adaptive regularization scheme where the effective regularization + coefficient of every source pixel is different. + """ + S, P = neighbors.shape + reg_w = regularization_weights ** 2 + + # 1) Flatten the (i→j) neighbor pairs + I = jnp.repeat(jnp.arange(S), P) # (S*P,) + J = neighbors.reshape(-1) # (S*P,) + + # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0 + OUT = S + J = jnp.where(J < 0, OUT, J) + + # 3) Build an extended weight vector with a zero at index S + reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0) + w_ij = reg_w_ext[J] # (S*P,) + + # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely + mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype) + + # 5) Scatter into the diagonal: + # - the tiny 1e-8 floor on each i < S + # - sum_j reg_w[j] into diag[i] + # - sum contributions reg_w[j] into diag[j] + # (diagonal at OUT=S picks up zeros only) + diag_updates_i = jnp.concatenate([ + jnp.full((S,), 1e-8), + jnp.zeros((1,)) # out‐of‐bounds slot stays zero + ], axis=0) + mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i) + mat = mat.at[I, I].add(w_ij) + mat = mat.at[J, J].add(w_ij) + + # 6) Scatter the off‐diagonal subtractions: + mat = mat.at[I, J].add(-w_ij) + mat = mat.at[J, I].add(-w_ij) + + # 7) Drop the extra row/column S and return the S×S result + return mat[:S, :S] class AdaptiveBrightness(AbstractRegularization): def __init__( @@ -70,7 +177,7 @@ def __init__( self.outer_coefficient = outer_coefficient self.signal_scale = signal_scale - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -91,13 +198,13 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: """ pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale) - return regularization_util.adaptive_regularization_weights_from( + return adaptive_regularization_weights_from( inner_coefficient=self.inner_coefficient, outer_coefficient=self.outer_coefficient, pixel_signals=pixel_signals, ) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -112,7 +219,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ regularization_weights = self.regularization_weights_from(linear_obj=linear_obj) - return regularization_util.weighted_regularization_matrix_from( + return weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=linear_obj.source_plane_mesh_grid.neighbors, ) diff --git a/autoarray/inversion/regularization/brightness_zeroth.py b/autoarray/inversion/regularization/brightness_zeroth.py index 4cab4e6d2..713e66cf8 100644 --- a/autoarray/inversion/regularization/brightness_zeroth.py +++ b/autoarray/inversion/regularization/brightness_zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -10,6 +10,60 @@ from autoarray.inversion.regularization import regularization_util +def brightness_zeroth_regularization_weights_from( + coefficient: float, pixel_signals: jnp.ndarray +) -> jnp.ndarray: + """ + Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). + + The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels + of a ``Mapper``). + + They are computed using an estimate of the expected signal in each pixel. + + The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are + the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in + the pixelization). + + Parameters + ---------- + coefficient + The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``), + with the degree applied varying based on the ``pixel_signals``. + pixel_signals + The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal + and low signal pixelizations. + + Returns + ------- + jnp.ndarray + The zeroth order regularization weights which act as the effective level of zeroth order regularization + applied to every mesh parameter. + """ + return coefficient * (1.0 - pixel_signals) + +def brightness_zeroth_regularization_matrix_from( + regularization_weights: jnp.ndarray, +) -> jnp.ndarray: + """ + Returns the regularization matrix for the zeroth-order brightness regularization scheme. + + Parameters + ---------- + regularization_weights + The regularization weights for each pixel, governing the strength of zeroth-order + regularization applied per inversion parameter. + + Returns + ------- + A diagonal regularization matrix where each diagonal element is the squared regularization weight + for that pixel. + """ + regularization_weight_squared = regularization_weights**2.0 + return jnp.diag(regularization_weight_squared) + + + class BrightnessZeroth(AbstractRegularization): def __init__( self, @@ -45,7 +99,7 @@ def __init__( self.coefficient = coefficient self.signal_scale = signal_scale - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of the ``BrightnessZeroth`` regularization scheme. @@ -65,11 +119,11 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: """ pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale) - return regularization_util.brightness_zeroth_regularization_weights_from( + return brightness_zeroth_regularization_weights_from( coefficient=self.coefficient, pixel_signals=pixel_signals ) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -84,6 +138,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ regularization_weights = self.regularization_weights_from(linear_obj=linear_obj) - return regularization_util.brightness_zeroth_regularization_matrix_from( + return brightness_zeroth_regularization_matrix_from( regularization_weights=regularization_weights ) diff --git a/autoarray/inversion/regularization/constant.py b/autoarray/inversion/regularization/constant.py index 690b248bd..36d23e8dc 100644 --- a/autoarray/inversion/regularization/constant.py +++ b/autoarray/inversion/regularization/constant.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,55 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def constant_regularization_matrix_from( + coefficient: float, + neighbors: jnp.ndarray[[int, int], jnp.int64], + neighbors_sizes: jnp.ndarray[[int], jnp.int64], +) -> jnp.ndarray[[int, int], jnp.float64]: + """ + From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. + + A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` + class in the module `autoarray.inversion.regularization`. + + Memory requirement: 2SP + S^2 + FLOPS: 1 + 2S + 2SP + + Parameters + ---------- + coefficient + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + neighbors : ndarray, shape (S, P), dtype=int64 + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes : ndarray, shape (S,), dtype=int64 + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + regularization_matrix : ndarray, shape (S, S), dtype=float64 + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + return ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) class Constant(AbstractRegularization): @@ -38,7 +86,7 @@ def __init__(self, coefficient: float = 1.0): super().__init__() - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -57,9 +105,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -73,7 +121,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ - return regularization_util.constant_regularization_matrix_from( + return constant_regularization_matrix_from( coefficient=self.coefficient, neighbors=linear_obj.neighbors, neighbors_sizes=linear_obj.neighbors.sizes, diff --git a/autoarray/inversion/regularization/constant_zeroth.py b/autoarray/inversion/regularization/constant_zeroth.py index 45e1729cf..5ed35bd9f 100644 --- a/autoarray/inversion/regularization/constant_zeroth.py +++ b/autoarray/inversion/regularization/constant_zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,59 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def constant_zeroth_regularization_matrix_from( + coefficient: float, + coefficient_zeroth: float, + neighbors: jnp.ndarray, + neighbors_sizes: jnp.ndarray[[int], jnp.int64], +) -> jnp.ndarray: + """ + From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. + + A complete description of regularizatin and the ``regularization_matrix`` can be found in the ``Regularization`` + class in the module ``autoarray.inversion.regularization``. + + Parameters + ---------- + coefficients + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + neighbors + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + jnp.ndarray + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + const = ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) + + reg_coeff = coefficient_zeroth ** 2.0 + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + zeroth = jnp.eye(P) * reg_coeff + + return const + zeroth class ConstantZeroth(AbstractRegularization): @@ -17,7 +69,7 @@ def __init__(self, coefficient_neighbor=1.0, coefficient_zeroth=1.0): self.coefficient_neighbor = coefficient_neighbor self.coefficient_zeroth = coefficient_zeroth - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -36,9 +88,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient_neighbor * np.ones(linear_obj.params) + return self.coefficient_neighbor * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -51,7 +103,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization matrix. """ - return regularization_util.constant_zeroth_regularization_matrix_from( + return constant_zeroth_regularization_matrix_from( coefficient=self.coefficient_neighbor, coefficient_zeroth=self.coefficient_zeroth, neighbors=linear_obj.neighbors, diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index 44e2e326a..d340d5180 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -1,301 +1,16 @@ -import jax.numpy as jnp import numpy as np from typing import Tuple from autoarray import exc -from autoarray import numba_util - -def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> np.ndarray: - """ - Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms - to the regularization matrix. - - A complete description of regularization and the `regularization_matrix` can be found in the `Regularization` - class in the module `autoarray.inversion.regularization`. - - Parameters - ---------- - pixels - The number of pixels in the linear object which is to be regularized, being used to in the inversion. - coefficient - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - reg_coeff = coefficient ** 2.0 - # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T - return jnp.eye(pixels) * reg_coeff - - -def constant_regularization_matrix_from( - coefficient: float, - neighbors: np.ndarray[[int, int], np.int64], - neighbors_sizes: np.ndarray[[int], np.int64], -) -> np.ndarray[[int, int], np.float64]: - """ - From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. - - A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` - class in the module `autoarray.inversion.regularization`. - - Memory requirement: 2SP + S^2 - FLOPS: 1 + 2S + 2SP - - Parameters - ---------- - coefficient - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors : ndarray, shape (S, P), dtype=int64 - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes : ndarray, shape (S,), dtype=int64 - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - regularization_matrix : ndarray, shape (S, S), dtype=float64 - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - S, P = neighbors.shape - # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) - OUT_OF_BOUND_IDX = S - regularization_coefficient = coefficient * coefficient - - # flatten it for feeding into the matrix as j indices - neighbors = neighbors.flatten() - # now create the corresponding i indices - I_IDX = jnp.repeat(jnp.arange(S), P) - # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. - # This ensures that JAX can efficiently drop these entries during matrix updates. - neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) - return ( - jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] - # unique indices should be guranteed by neighbors-spec - .add(-regularization_coefficient, mode="drop", unique_indices=True) - ) - - -def constant_zeroth_regularization_matrix_from( - coefficient: float, - coefficient_zeroth: float, - neighbors: np.ndarray, - neighbors_sizes: np.ndarray[[int], np.int64], -) -> np.ndarray: - """ - From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. - - A complete description of regularizatin and the ``regularization_matrix`` can be found in the ``Regularization`` - class in the module ``autoarray.inversion.regularization``. - - Parameters - ---------- - coefficients - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - S, P = neighbors.shape - # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) - OUT_OF_BOUND_IDX = S - regularization_coefficient = coefficient * coefficient - - # flatten it for feeding into the matrix as j indices - neighbors = neighbors.flatten() - # now create the corresponding i indices - I_IDX = jnp.repeat(jnp.arange(S), P) - # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. - # This ensures that JAX can efficiently drop these entries during matrix updates. - neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) - const = ( - jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] - # unique indices should be guranteed by neighbors-spec - .add(-regularization_coefficient, mode="drop", unique_indices=True) - ) - - reg_coeff = coefficient_zeroth ** 2.0 - # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T - zeroth = jnp.eye(P) * reg_coeff - - return const + zeroth - -def adaptive_regularization_weights_from( - inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray -) -> np.ndarray: - """ - Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). - - The weights define the effective regularization coefficient of every mesh parameter (typically pixels - of a ``Mapper``). - - They are computed using an estimate of the expected signal in each pixel. - - Two regularization coefficients are used, corresponding to the: - - 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization). - 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization). - - Parameters - ---------- - inner_coefficient - The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction - in the inner regions of a mesh's reconstruction. - outer_coefficient - The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction - in the outer regions of a mesh's reconstruction. - pixel_signals - The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal - and low signal pixelizations. - - Returns - ------- - np.ndarray - The adaptive regularization weights which act as the effective regularization coefficients of - every source pixel. - """ - return ( - inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals) - ) ** 2.0 - - -def brightness_zeroth_regularization_weights_from( - coefficient: float, pixel_signals: np.ndarray -) -> np.ndarray: - """ - Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). - - The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels - of a ``Mapper``). - - They are computed using an estimate of the expected signal in each pixel. - - The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are - the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in - the pixelization). - - Parameters - ---------- - coefficient - The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``), - with the degree applied varying based on the ``pixel_signals``. - pixel_signals - The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal - and low signal pixelizations. - - Returns - ------- - np.ndarray - The zeroth order regularization weights which act as the effective level of zeroth order regularization - applied to every mesh parameter. - """ - return coefficient * (1.0 - pixel_signals) - - -def weighted_regularization_matrix_from( - regularization_weights: np.ndarray, - neighbors: np.ndarray, -) -> np.ndarray: - """ - Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). - - This matrix is computed using the regularization weights of every mesh pixel, which are computed using the - function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of - every mesh pixel. - - The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate - neighbor calculation of the corresponding ``Mapper`` class. - - Parameters - ---------- - regularization_weights - The regularization weight of each pixel, adaptively governing the degree of gradient regularization - applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the mesh grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using an adaptive regularization scheme where the effective regularization - coefficient of every source pixel is different. - """ - S, P = neighbors.shape - reg_w = regularization_weights ** 2 - - # 1) Flatten the (i→j) neighbor pairs - I = jnp.repeat(jnp.arange(S), P) # (S*P,) - J = neighbors.reshape(-1) # (S*P,) - - # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0 - OUT = S - J = jnp.where(J < 0, OUT, J) - - # 3) Build an extended weight vector with a zero at index S - reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0) - w_ij = reg_w_ext[J] # (S*P,) - - # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely - mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype) - - # 5) Scatter into the diagonal: - # - the tiny 1e-8 floor on each i < S - # - sum_j reg_w[j] into diag[i] - # - sum contributions reg_w[j] into diag[j] - # (diagonal at OUT=S picks up zeros only) - diag_updates_i = jnp.concatenate([ - jnp.full((S,), 1e-8), - jnp.zeros((1,)) # out‐of‐bounds slot stays zero - ], axis=0) - mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i) - mat = mat.at[I, I].add(w_ij) - mat = mat.at[J, J].add(w_ij) - - # 6) Scatter the off‐diagonal subtractions: - mat = mat.at[I, J].add(-w_ij) - mat = mat.at[J, I].add(-w_ij) - - # 7) Drop the extra row/column S and return the S×S result - return mat[:S, :S] - -def brightness_zeroth_regularization_matrix_from( - regularization_weights: np.ndarray, -) -> np.ndarray: - """ - Returns the regularization matrix for the zeroth-order brightness regularization scheme. - - Parameters - ---------- - regularization_weights - The regularization weights for each pixel, governing the strength of zeroth-order - regularization applied per inversion parameter. - - Returns - ------- - A diagonal regularization matrix where each diagonal element is the squared regularization weight - for that pixel. - """ - regularization_weight_squared = regularization_weights**2.0 - return np.diag(regularization_weight_squared) +from autoarray.inversion.regularization.adaptive_brightness import adaptive_regularization_weights_from +from autoarray.inversion.regularization.brightness_zeroth import brightness_zeroth_regularization_matrix_from +from autoarray.inversion.regularization.brightness_zeroth import brightness_zeroth_regularization_weights_from +from autoarray.inversion.regularization.constant import constant_regularization_matrix_from +from autoarray.inversion.regularization.exponential_kernel import exp_cov_matrix_from +from autoarray.inversion.regularization.gaussian_kernel import gauss_cov_matrix_from +from autoarray.inversion.regularization.matern_kernel import matern_kernel +from autoarray.inversion.regularization.zeroth import zeroth_regularization_matrix_from def reg_split_from( diff --git a/autoarray/inversion/regularization/zeroth.py b/autoarray/inversion/regularization/zeroth.py index e30b1222e..728776bc1 100644 --- a/autoarray/inversion/regularization/zeroth.py +++ b/autoarray/inversion/regularization/zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,34 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> jnp.ndarray: + """ + Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms + to the regularization matrix. + + A complete description of regularization and the `regularization_matrix` can be found in the `Regularization` + class in the module `autoarray.inversion.regularization`. + + Parameters + ---------- + pixels + The number of pixels in the linear object which is to be regularized, being used to in the inversion. + coefficient + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + + Returns + ------- + np.ndarray + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + + reg_coeff = coefficient ** 2.0 + + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + + return jnp.eye(pixels) * reg_coeff class Zeroth(AbstractRegularization): @@ -60,9 +87,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -75,6 +102,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization matrix. """ - return regularization_util.zeroth_regularization_matrix_from( + return zeroth_regularization_matrix_from( coefficient=self.coefficient, pixels=linear_obj.params ) From a63369de991325f948dc7d61489a84266d698941 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 23 Jul 2025 17:27:57 +0100 Subject: [PATCH 5/6] regulsirztion refactor complete and rectangular works --- .../inversion/pixelization/mesh/mesh_util.py | 9 ++- .../inversion/regularization/__init__.py | 1 - .../regularization/adaptive_brightness.py | 16 ++--- .../regularization/brightness_zeroth.py | 2 +- .../inversion/regularization/constant.py | 4 +- .../regularization/constant_zeroth.py | 6 +- .../regularization/exponential_kernel.py | 60 ++++++++----------- .../regularization/gaussian_kernel.py | 16 ++--- .../regularization/regularization_util.py | 22 +++++-- autoarray/inversion/regularization/zeroth.py | 2 +- autoarray/plot/wrap/two_d/grid_scatter.py | 11 +++- autoarray/preloads.py | 4 +- autoarray/structures/mesh/rectangular_2d.py | 4 +- .../test_regularization_util.py | 1 - 14 files changed, 90 insertions(+), 68 deletions(-) diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index 2160b8629..420d9c2ab 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -110,6 +110,7 @@ def rectangular_corner_neighbors( return neighbors, neighbors_sizes + def rectangular_top_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -145,6 +146,7 @@ def rectangular_top_edge_neighbors( return neighbors, neighbors_sizes + def rectangular_left_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -180,6 +182,7 @@ def rectangular_left_edge_neighbors( return neighbors, neighbors_sizes + def rectangular_right_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -215,6 +218,7 @@ def rectangular_right_edge_neighbors( return neighbors, neighbors_sizes + def rectangular_bottom_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -288,14 +292,15 @@ def rectangular_central_neighbors( # Compute neighbor indices neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up - neighbors[pixel_indices, 1] = pixel_indices - 1 # Left - neighbors[pixel_indices, 2] = pixel_indices + 1 # Right + neighbors[pixel_indices, 1] = pixel_indices - 1 # Left + neighbors[pixel_indices, 2] = pixel_indices + 1 # Right neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down neighbors_sizes[pixel_indices] = 4 return neighbors, neighbors_sizes + def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]: """ Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization, diff --git a/autoarray/inversion/regularization/__init__.py b/autoarray/inversion/regularization/__init__.py index c592cb51f..e34d07b6b 100644 --- a/autoarray/inversion/regularization/__init__.py +++ b/autoarray/inversion/regularization/__init__.py @@ -10,4 +10,3 @@ from .gaussian_kernel import GaussianKernel from .exponential_kernel import ExponentialKernel from .matern_kernel import MaternKernel - diff --git a/autoarray/inversion/regularization/adaptive_brightness.py b/autoarray/inversion/regularization/adaptive_brightness.py index cf5af2b31..c0ba845d0 100644 --- a/autoarray/inversion/regularization/adaptive_brightness.py +++ b/autoarray/inversion/regularization/adaptive_brightness.py @@ -80,11 +80,11 @@ def weighted_regularization_matrix_from( coefficient of every source pixel is different. """ S, P = neighbors.shape - reg_w = regularization_weights ** 2 + reg_w = regularization_weights**2 # 1) Flatten the (i→j) neighbor pairs - I = jnp.repeat(jnp.arange(S), P) # (S*P,) - J = neighbors.reshape(-1) # (S*P,) + I = jnp.repeat(jnp.arange(S), P) # (S*P,) + J = neighbors.reshape(-1) # (S*P,) # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0 OUT = S @@ -92,7 +92,7 @@ def weighted_regularization_matrix_from( # 3) Build an extended weight vector with a zero at index S reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0) - w_ij = reg_w_ext[J] # (S*P,) + w_ij = reg_w_ext[J] # (S*P,) # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype) @@ -102,10 +102,9 @@ def weighted_regularization_matrix_from( # - sum_j reg_w[j] into diag[i] # - sum contributions reg_w[j] into diag[j] # (diagonal at OUT=S picks up zeros only) - diag_updates_i = jnp.concatenate([ - jnp.full((S,), 1e-8), - jnp.zeros((1,)) # out‐of‐bounds slot stays zero - ], axis=0) + diag_updates_i = jnp.concatenate( + [jnp.full((S,), 1e-8), jnp.zeros((1,))], axis=0 # out‐of‐bounds slot stays zero + ) mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i) mat = mat.at[I, I].add(w_ij) mat = mat.at[J, J].add(w_ij) @@ -117,6 +116,7 @@ def weighted_regularization_matrix_from( # 7) Drop the extra row/column S and return the S×S result return mat[:S, :S] + class AdaptiveBrightness(AbstractRegularization): def __init__( self, diff --git a/autoarray/inversion/regularization/brightness_zeroth.py b/autoarray/inversion/regularization/brightness_zeroth.py index 713e66cf8..6cd765aec 100644 --- a/autoarray/inversion/regularization/brightness_zeroth.py +++ b/autoarray/inversion/regularization/brightness_zeroth.py @@ -42,6 +42,7 @@ def brightness_zeroth_regularization_weights_from( """ return coefficient * (1.0 - pixel_signals) + def brightness_zeroth_regularization_matrix_from( regularization_weights: jnp.ndarray, ) -> jnp.ndarray: @@ -63,7 +64,6 @@ def brightness_zeroth_regularization_matrix_from( return jnp.diag(regularization_weight_squared) - class BrightnessZeroth(AbstractRegularization): def __init__( self, diff --git a/autoarray/inversion/regularization/constant.py b/autoarray/inversion/regularization/constant.py index 36d23e8dc..d9737d075 100644 --- a/autoarray/inversion/regularization/constant.py +++ b/autoarray/inversion/regularization/constant.py @@ -52,7 +52,9 @@ class in the module `autoarray.inversion.regularization`. # This ensures that JAX can efficiently drop these entries during matrix updates. neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) return ( - jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[ + I_IDX, neighbors + ] # unique indices should be guranteed by neighbors-spec .add(-regularization_coefficient, mode="drop", unique_indices=True) ) diff --git a/autoarray/inversion/regularization/constant_zeroth.py b/autoarray/inversion/regularization/constant_zeroth.py index 5ed35bd9f..11d7b9808 100644 --- a/autoarray/inversion/regularization/constant_zeroth.py +++ b/autoarray/inversion/regularization/constant_zeroth.py @@ -50,12 +50,14 @@ class in the module ``autoarray.inversion.regularization``. # This ensures that JAX can efficiently drop these entries during matrix updates. neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) const = ( - jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors] + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[ + I_IDX, neighbors + ] # unique indices should be guranteed by neighbors-spec .add(-regularization_coefficient, mode="drop", unique_indices=True) ) - reg_coeff = coefficient_zeroth ** 2.0 + reg_coeff = coefficient_zeroth**2.0 # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T zeroth = jnp.eye(P) * reg_coeff diff --git a/autoarray/inversion/regularization/exponential_kernel.py b/autoarray/inversion/regularization/exponential_kernel.py index 73ead006d..cfb03186b 100644 --- a/autoarray/inversion/regularization/exponential_kernel.py +++ b/autoarray/inversion/regularization/exponential_kernel.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,52 +7,44 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray import numba_util - -@numba_util.jit() def exp_cov_matrix_from( scale: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points: jnp.ndarray, # shape (N, 2) +) -> jnp.ndarray: # shape (N, N) """ - Consutruct the source brightness covariance matrix, which is used to determined the regularization - pattern (i.e, how the different source pixels are smoothed). + Construct the source brightness covariance matrix using an exponential kernel: + + cov[i,j] = exp(- d_{ij} / scale) - The covariance matrix includes one non-linear parameters, the scale coefficient, which is used to determine - the typical scale of the regularization pattern. + with a tiny jitter 1e-8 added on the diagonal for numerical stability. Parameters ---------- scale - The typical scale of the regularization pattern . + The length‐scale of the exponential kernel. pixel_points - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). - Something like [[y1,x1], [y2,x2], ...] + Array of shape (N, 2) giving the (y,x) coordinates of each source‐plane pixel. Returns ------- - np.ndarray - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. + jnp.ndarray, shape (N, N) + The exponential covariance matrix. """ + # pairwise differences: shape (N, N, 2) + diff = pixel_points[:, None, :] - pixel_points[None, :, :] - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) + # Euclidean distances: shape (N, N) + d = jnp.linalg.norm(diff, axis=-1) - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j + # exponential kernel + cov = jnp.exp(-d / scale) - covariance_matrix[i, j] += np.exp(-1.0 * d_ij / scale) + # add a small jitter on the diagonal + N = pixel_points.shape[0] + cov = cov + jnp.eye(N) * 1e-8 - return covariance_matrix + return cov class ExponentialKernel(AbstractRegularization): @@ -83,7 +75,7 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0): super().__init__() - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -102,9 +94,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -119,7 +111,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ covariance_matrix = exp_cov_matrix_from( scale=self.scale, - pixel_points=np.array(linear_obj.source_plane_mesh_grid), + pixel_points=linear_obj.source_plane_mesh_grid.array, ) - return self.coefficient * np.linalg.inv(covariance_matrix) + return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/gaussian_kernel.py b/autoarray/inversion/regularization/gaussian_kernel.py index 27e248dd3..4b600fba5 100644 --- a/autoarray/inversion/regularization/gaussian_kernel.py +++ b/autoarray/inversion/regularization/gaussian_kernel.py @@ -8,6 +8,7 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization + def gauss_cov_matrix_from( scale: float, pixel_points: jnp.ndarray, # shape (N, 2) @@ -34,17 +35,17 @@ def gauss_cov_matrix_from( The Gaussian covariance matrix. """ # Ensure array: - pts = jnp.asarray(pixel_points) # (N, 2) + pts = jnp.asarray(pixel_points) # (N, 2) # Compute squared distances: ||p_i - p_j||^2 - diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2) - d2 = jnp.sum(diffs**2, axis=-1) # (N, N) + diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2) + d2 = jnp.sum(diffs**2, axis=-1) # (N, N) # Gaussian kernel - cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N) + cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N) # Add tiny jitter on the diagonal - N = pts.shape[0] - cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8 + N = pts.shape[0] + cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8 return cov @@ -111,8 +112,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ covariance_matrix = gauss_cov_matrix_from( - scale=self.scale, - pixel_points=linear_obj.source_plane_mesh_grid.array + scale=self.scale, pixel_points=linear_obj.source_plane_mesh_grid.array ) return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index d340d5180..8cedca034 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -3,10 +3,24 @@ from autoarray import exc -from autoarray.inversion.regularization.adaptive_brightness import adaptive_regularization_weights_from -from autoarray.inversion.regularization.brightness_zeroth import brightness_zeroth_regularization_matrix_from -from autoarray.inversion.regularization.brightness_zeroth import brightness_zeroth_regularization_weights_from -from autoarray.inversion.regularization.constant import constant_regularization_matrix_from +from autoarray.inversion.regularization.adaptive_brightness import ( + adaptive_regularization_weights_from, +) +from autoarray.inversion.regularization.adaptive_brightness import ( + weighted_regularization_matrix_from, +) +from autoarray.inversion.regularization.brightness_zeroth import ( + brightness_zeroth_regularization_matrix_from, +) +from autoarray.inversion.regularization.brightness_zeroth import ( + brightness_zeroth_regularization_weights_from, +) +from autoarray.inversion.regularization.constant import ( + constant_regularization_matrix_from, +) +from autoarray.inversion.regularization.constant_zeroth import ( + constant_zeroth_regularization_matrix_from, +) from autoarray.inversion.regularization.exponential_kernel import exp_cov_matrix_from from autoarray.inversion.regularization.gaussian_kernel import gauss_cov_matrix_from from autoarray.inversion.regularization.matern_kernel import matern_kernel diff --git a/autoarray/inversion/regularization/zeroth.py b/autoarray/inversion/regularization/zeroth.py index 728776bc1..04f61ad0e 100644 --- a/autoarray/inversion/regularization/zeroth.py +++ b/autoarray/inversion/regularization/zeroth.py @@ -30,7 +30,7 @@ class in the module `autoarray.inversion.regularization`. coefficient of every source pixel is the same. """ - reg_coeff = coefficient ** 2.0 + reg_coeff = coefficient**2.0 # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T diff --git a/autoarray/plot/wrap/two_d/grid_scatter.py b/autoarray/plot/wrap/two_d/grid_scatter.py index ad934f4c0..e9b9879d0 100644 --- a/autoarray/plot/wrap/two_d/grid_scatter.py +++ b/autoarray/plot/wrap/two_d/grid_scatter.py @@ -80,9 +80,16 @@ def scatter_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular] try: for grid in grid_list: try: - plt.scatter(y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict) + plt.scatter( + y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict + ) except ValueError: - plt.scatter(y=grid.array[:, 0], x=grid.array[:, 1], c=next(color), **config_dict) + plt.scatter( + y=grid.array[:, 0], + x=grid.array[:, 1], + c=next(color), + **config_dict, + ) except IndexError: return None diff --git a/autoarray/preloads.py b/autoarray/preloads.py index b3f0ca669..340d85bdd 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -14,7 +14,7 @@ def __init__( self, mapper_indices: np.ndarray = None, source_pixel_zeroed_indices: np.ndarray = None, - linear_light_profile_blurred_mapping_matrix = None, + linear_light_profile_blurred_mapping_matrix=None, ): """ Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance @@ -70,4 +70,4 @@ def __init__( self.linear_light_profile_blurred_mapping_matrix = jnp.array( linear_light_profile_blurred_mapping_matrix - ) \ No newline at end of file + ) diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index 28a9e2372..e8c7a8a82 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -136,7 +136,9 @@ def neighbors(self) -> Neighbors: @cached_property def edge_pixel_list(self) -> List: - return mesh_util.rectangular_edge_pixel_list_from(shape_native=self.shape_native) + return mesh_util.rectangular_edge_pixel_list_from( + shape_native=self.shape_native + ) @property def pixels(self) -> int: diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index 1b26f7095..05a4bd0d4 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -405,7 +405,6 @@ def test__weighted_regularization_matrix_from(): [[1, 2, -1, -1], [0, 2, 3, -1], [0, 1, -1, -1], [1, -1, -1, -1]] ) - regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, From f0a4a8d9bda0b656b5b12025a5ba2b4eba1f8e48 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 24 Jul 2025 14:07:20 +0100 Subject: [PATCH 6/6] updates for rectangular grid sorted --- autoarray/structures/mesh/triangulation_2d.py | 1 + pyproject.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/autoarray/structures/mesh/triangulation_2d.py b/autoarray/structures/mesh/triangulation_2d.py index 95ff46633..3ddb95c01 100644 --- a/autoarray/structures/mesh/triangulation_2d.py +++ b/autoarray/structures/mesh/triangulation_2d.py @@ -95,6 +95,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay": to compute the Voronoi mesh are ill posed. These exceptions are caught and combined into a single `MeshException`, which helps exception handling in the `inversion` package. """ + import scipy.spatial try: diff --git a/pyproject.toml b/pyproject.toml index 850141794..039c8d671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "astropy>=5.0,<=6.1.2", "decorator>=4.0.0", "dill>=0.3.1.1", + "jaxnnls==1.0.1", "matplotlib>=3.7.0", "scipy<=1.14.0", "scikit-image<=0.24.0",