diff --git a/autoarray/inversion/regularization/__init__.py b/autoarray/inversion/regularization/__init__.py index e34d07b6b..c5d696062 100644 --- a/autoarray/inversion/regularization/__init__.py +++ b/autoarray/inversion/regularization/__init__.py @@ -10,3 +10,4 @@ from .gaussian_kernel import GaussianKernel from .exponential_kernel import ExponentialKernel from .matern_kernel import MaternKernel +from .matern_adaptive_brightness_kernel import MaternAdaptiveBrightnessKernel diff --git a/autoarray/inversion/regularization/matern_adaptive_brightness_kernel.py b/autoarray/inversion/regularization/matern_adaptive_brightness_kernel.py new file mode 100644 index 000000000..6442d0343 --- /dev/null +++ b/autoarray/inversion/regularization/matern_adaptive_brightness_kernel.py @@ -0,0 +1,159 @@ +from __future__ import annotations +import numpy as np +from typing import TYPE_CHECKING + +from autoarray.inversion.regularization.matern_kernel import MaternKernel + +if TYPE_CHECKING: + from autoarray.inversion.linear_obj.linear_obj import LinearObj + +from autoarray.inversion.regularization.matern_kernel import matern_kernel + + +def matern_cov_matrix_from( + scale: float, + nu: float, + pixel_points, + weights=None, + xp=np, +): + """ + Construct the regularization covariance matrix (N x N) using a Matérn kernel, + optionally modulated by per-pixel weights. + + If `weights` is provided (shape [N]), the covariance is: + C_ij = K(d_ij; scale, nu) * w_i * w_j + with a small diagonal jitter added for numerical stability. + + Parameters + ---------- + scale + Typical correlation length of the Matérn kernel. + nu + Smoothness parameter of the Matérn kernel. + pixel_points + Array-like of shape [N, 2] with (y, x) coordinates (or any 2D coords; only distances matter). + weights + Optional array-like of shape [N]. If None, treated as all ones. + xp + Backend (numpy or jax.numpy). + + Returns + ------- + covariance_matrix + Array of shape [N, N]. + """ + + # -------------------------------- + # Pairwise distances (broadcasted) + # -------------------------------- + diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2) + d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N) + + # -------------------------------- + # Base Matérn covariance + # -------------------------------- + covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp) # (N, N) + + # -------------------------------- + # Apply weights: C_ij *= w_i * w_j + # (broadcasted outer product, JAX-safe) + # -------------------------------- + if weights is not None: + w = xp.asarray(weights) + # Ensure shape (N,) -> outer product (N,1)*(1,N) -> (N,N) + covariance_matrix = covariance_matrix * (w[:, None] * w[None, :]) + + # -------------------------------- + # Add diagonal jitter (JAX-safe) + # -------------------------------- + pixels = pixel_points.shape[0] + covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels) + + return covariance_matrix + + +class MaternAdaptiveBrightnessKernel(MaternKernel): + def __init__( + self, + coefficient: float = 1.0, + scale: float = 1.0, + nu: float = 0.5, + rho: float = 1.0, + ): + """ + Regularization which uses a Matern smoothing kernel to regularize the solution with regularization weights + that adapt to the brightness of the source being reconstructed. + + For this regularization scheme, every pixel is regularized with every other pixel. This contrasts many other + schemes, where regularization is based on neighboring (e.g. do the pixels share a Delaunay edge?) or computing + derivatives around the center of the pixel (where nearby pixels are regularization locally in similar ways). + + This makes the regularization matrix fully dense and therefore may change the run times of the solution. + It also leads to more overall smoothing which can lead to more stable linear inversions. + + For the weighted regularization scheme, each pixel is given an 'effective regularization weight', which is + applied when each set of pixel neighbors are regularized with one another. The motivation of this is that + different regions of a pixelization's mesh require different levels of regularization (e.g., high smoothing where the + no signal is present and less smoothing where it is, see (Nightingale, Dye and Massey 2018)). + + This scheme is not used by Vernardos et al. (2022): https://arxiv.org/abs/2202.09378, but it follows + a similar approach. + + A full description of regularization and this matrix can be found in the parent `AbstractRegularization` class. + + Parameters + ---------- + coefficient + The regularization coefficient which controls the degree of smooth of the inversion reconstruction. + scale + The typical scale (correlation length) of the Matérn regularization kernel. + nu + Controls the smoothness (differentiability) of the Matérn kernel; ``nu=0.5`` corresponds to an + exponential (Ornstein–Uhlenbeck) kernel, while a Gaussian covariance is obtained in the limit + as ``nu`` approaches infinity. + rho + Controls how strongly the kernel weights adapt to pixel brightness. Larger values make bright pixels + receive significantly higher weights (and faint pixels lower weights), while smaller values produce a + more uniform weighting. Typical values are of order unity (e.g. 0.5–2.0). + """ + super().__init__(coefficient=coefficient, scale=scale, nu=nu) + self.rho = rho + + def covariance_kernel_weights_from( + self, linear_obj: LinearObj, xp=np + ) -> np.ndarray: + """ + Returns per-pixel kernel weights that adapt to the reconstructed pixel brightness. + """ + # Assumes linear_obj.pixel_signals_from is xp-aware elsewhere in the codebase. + pixel_signals = linear_obj.pixel_signals_from(signal_scale=1.0, xp=xp) + + max_signal = xp.max(pixel_signals) + max_signal = xp.maximum(max_signal, 1e-8) # avoid divide-by-zero (JAX-safe) + + return xp.exp(-self.rho * (1.0 - pixel_signals / max_signal)) + + def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray: + kernel_weights = self.covariance_kernel_weights_from( + linear_obj=linear_obj, xp=xp + ) + + # Follow the xp pattern used in the Matérn kernel module (often `.array` for grids). + pixel_points = linear_obj.source_plane_mesh_grid.array + + covariance_matrix = matern_cov_matrix_from( + scale=self.scale, + pixel_points=pixel_points, + nu=self.nu, + weights=kernel_weights, + xp=xp, + ) + + return self.coefficient * xp.linalg.inv(covariance_matrix) + + def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray: + """ + Returns the regularization weights of this regularization scheme. + """ + return 1.0 / self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp) diff --git a/autoarray/inversion/regularization/matern_kernel.py b/autoarray/inversion/regularization/matern_kernel.py index c80c5d233..3e38757ec 100644 --- a/autoarray/inversion/regularization/matern_kernel.py +++ b/autoarray/inversion/regularization/matern_kernel.py @@ -44,6 +44,20 @@ def kv_xp(v, z, xp=np): ) +def gamma_xp(x, xp=np): + """ + XP-compatible Gamma(x). + """ + if xp is np: + import scipy.special as sc + + return sc.gamma(x) + else: + import jax.scipy.special as jsp + + return jsp.gamma(x) + + def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np): """ XP-compatible Matérn kernel. @@ -55,7 +69,7 @@ def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np): z = xp.sqrt(2.0 * v) * r / l - part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant + part1 = 2.0 ** (1.0 - v) / gamma_xp(v, xp) # scalar constant part2 = z**v part3 = kv_xp(v, z, xp) @@ -141,8 +155,8 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5 """ self.coefficient = coefficient - self.scale = float(scale) - self.nu = float(nu) + self.scale = scale + self.nu = nu super().__init__() def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray: diff --git a/autoarray/preloads.py b/autoarray/preloads.py index ba9e4e4dd..1b4e00d28 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -26,6 +26,7 @@ def __init__( linear_light_profile_blurred_mapping_matrix=None, use_voronoi_areas: bool = True, areas_factor: float = 0.5, + skip_areas: bool = False, ): """ Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance @@ -81,6 +82,16 @@ def __init__( inversion, with the other component being the pixelization's pixels. These are fixed when the lens light is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but the intensity values will still be solved for during the inversion. + use_voronoi_areas + Whether to use Voronoi areas during Delaunay triangulation. When True, computes areas for each Voronoi + region which can be used in certain regularization schemes. Default is True. + areas_factor + Factor used to scale the Voronoi areas during split point computation. Default is 0.5. + skip_areas + Whether to skip Voronoi area calculations and split point computations during Delaunay triangulation. + When True, the Delaunay interface returns only the minimal set of outputs (points, simplices, mappings) + without computing split_points or splitted_mappings. This optimization is useful for regularization + schemes like Matérn kernels that don't require area-based calculations. Default is False. """ self.mapper_indices = None self.source_pixel_zeroed_indices = None @@ -123,3 +134,4 @@ def __init__( self.use_voronoi_areas = use_voronoi_areas self.areas_factor = areas_factor + self.skip_areas = skip_areas diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 51c37975e..11aeaf927 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -612,7 +612,7 @@ def convolved_image_from( image, blurring_image, jax_method="direct", - use_mixed_precision : bool = False, + use_mixed_precision: bool = False, xp=np, ): """ diff --git a/autoarray/structures/mesh/delaunay_2d.py b/autoarray/structures/mesh/delaunay_2d.py index 74fb18112..961944b75 100644 --- a/autoarray/structures/mesh/delaunay_2d.py +++ b/autoarray/structures/mesh/delaunay_2d.py @@ -339,6 +339,67 @@ def pix_indexes_for_sub_slim_index_delaunay_from( return out +def scipy_delaunay_matern(points_np, query_points_np): + """ + Minimal SciPy Delaunay callback for Matérn regularization. + + Returns only what’s needed for mapping: + - points (tri.points) + - simplices_padded + - mappings: integer array of pixel indices for each query point, + typically of shape (Q, 3), where each row gives the indices of the + Delaunay mesh vertices ("pixels") associated with that query point. + """ + + max_simplices = 2 * points_np.shape[0] + + # --- Delaunay mesh --- + tri = Delaunay(points_np) + + points = tri.points.astype(points_np.dtype) + simplices = tri.simplices.astype(np.int32) + + # --- Pad simplices to fixed shape for JAX --- + simplices_padded = -np.ones((max_simplices, 3), dtype=np.int32) + simplices_padded[: simplices.shape[0]] = simplices + + # --- find_simplex for query points --- + simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,) + + mappings = pix_indexes_for_sub_slim_index_delaunay_from( + source_plane_data_grid=query_points_np, + simplex_index_for_sub_slim_index=simplex_idx, + pix_indexes_for_simplex_index=simplices, + delaunay_points=points_np, + ) + + return points, simplices_padded, mappings + + +def jax_delaunay_matern(points, query_points): + """ + JAX wrapper using pure_callback to run SciPy Delaunay on CPU, + returning only the minimal outputs needed for Matérn usage. + """ + import jax + import jax.numpy as jnp + + N = points.shape[0] + Q = query_points.shape[0] + max_simplices = 2 * N + + points_shape = jax.ShapeDtypeStruct((N, 2), points.dtype) + simplices_padded_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32) + mappings_shape = jax.ShapeDtypeStruct((Q, 3), jnp.int32) + + return jax.pure_callback( + lambda pts, qpts: scipy_delaunay_matern(np.asarray(pts), np.asarray(qpts)), + (points_shape, simplices_padded_shape, mappings_shape), + points, + query_points, + ) + + class DelaunayInterface: def __init__( @@ -466,33 +527,60 @@ def delaunay(self) -> "scipy.spatial.Delaunay": use_voronoi_areas = self.preloads.use_voronoi_areas areas_factor = self.preloads.areas_factor + skip_areas = self.preloads.skip_areas else: use_voronoi_areas = True areas_factor = 0.5 + skip_areas = False - if self._xp.__name__.startswith("jax"): + if not skip_areas: - import jax.numpy as jnp + if self._xp.__name__.startswith("jax"): - points, simplices, mappings, split_points, splitted_mappings = jax_delaunay( - points=self.mesh_grid_xy, - query_points=self._source_plane_data_grid_over_sampled, - use_voronoi_areas=use_voronoi_areas, - areas_factor=areas_factor, - ) + import jax.numpy as jnp + + points, simplices, mappings, split_points, splitted_mappings = ( + jax_delaunay( + points=self.mesh_grid_xy, + query_points=self._source_plane_data_grid_over_sampled, + use_voronoi_areas=use_voronoi_areas, + areas_factor=areas_factor, + ) + ) + + else: + + points, simplices, mappings, split_points, splitted_mappings = ( + scipy_delaunay( + points_np=self.mesh_grid_xy, + query_points_np=self._source_plane_data_grid_over_sampled, + use_voronoi_areas=use_voronoi_areas, + areas_factor=areas_factor, + ) + ) else: - points, simplices, mappings, split_points, splitted_mappings = ( - scipy_delaunay( + if self._xp.__name__.startswith("jax"): + + import jax.numpy as jnp + + points, simplices, mappings = jax_delaunay_matern( + points=self.mesh_grid_xy, + query_points=self._source_plane_data_grid_over_sampled, + ) + + else: + + points, simplices, mappings = scipy_delaunay_matern( points_np=self.mesh_grid_xy, query_points_np=self._source_plane_data_grid_over_sampled, - use_voronoi_areas=use_voronoi_areas, - areas_factor=areas_factor, ) - ) + + split_points = None + splitted_mappings = None return DelaunayInterface( points=points, diff --git a/test_autoarray/inversion/regularizations/test_matern_adaptive_brightness_kernel.py b/test_autoarray/inversion/regularizations/test_matern_adaptive_brightness_kernel.py new file mode 100644 index 000000000..de514975f --- /dev/null +++ b/test_autoarray/inversion/regularizations/test_matern_adaptive_brightness_kernel.py @@ -0,0 +1,45 @@ +import pytest + +import autoarray as aa +import numpy as np + +np.set_printoptions(threshold=np.inf) + + +def test__regularization_matrix(): + + reg = aa.reg.MaternAdaptiveBrightnessKernel( + coefficient=1.0, scale=2.0, nu=2.0, rho=1.0 + ) + + pixel_signals = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + source_plane_mesh_grid = aa.Grid2D.no_mask( + values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], + shape_native=(3, 2), + pixel_scales=1.0, + ) + + mapper = aa.m.MockMapper( + source_plane_mesh_grid=source_plane_mesh_grid, pixel_signals=pixel_signals + ) + + regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper) + + assert regularization_matrix[0, 0] == pytest.approx(18.7439565009, 1.0e-4) + assert regularization_matrix[0, 1] == pytest.approx(-8.786547368, 1.0e-4) + + reg = aa.reg.MaternAdaptiveBrightnessKernel( + coefficient=1.5, scale=2.5, nu=2.5, rho=1.5 + ) + + pixel_signals = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + mapper = aa.m.MockMapper( + source_plane_mesh_grid=source_plane_mesh_grid, pixel_signals=pixel_signals + ) + + regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper) + + assert regularization_matrix[0, 0] == pytest.approx(121.0190770, 1.0e-4) + assert regularization_matrix[0, 1] == pytest.approx(-66.9580331, 1.0e-4)