From 6f3378249f43c88769ba3b9a569d25c7843ffd28 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 14 Dec 2025 18:30:37 +0000 Subject: [PATCH] Matern kernel now owrks in JAX via tensorflow prob --- .../inversion/regularization/matern_kernel.py | 131 ++++++++++-------- pyproject.toml | 3 +- .../regularizations/test_matern_kernel.py | 29 ++-- 3 files changed, 88 insertions(+), 75 deletions(-) diff --git a/autoarray/inversion/regularization/matern_kernel.py b/autoarray/inversion/regularization/matern_kernel.py index 5684a8024..a0934c3a2 100644 --- a/autoarray/inversion/regularization/matern_kernel.py +++ b/autoarray/inversion/regularization/matern_kernel.py @@ -1,8 +1,8 @@ from __future__ import annotations +import jax.numpy as jnp +import jax.scipy.special as jsp import numpy as np - import math -import scipy.special as sc from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -12,35 +12,68 @@ from autoarray import numba_util +import jax.numpy as jnp + + +def kv_xp(v, z, xp=np): + """ + XP-compatible modified Bessel K_v(v, z). + + NumPy backend: + -> scipy.special.kv -@numba_util.jit(cache=False) -def matern_kernel(r: float, l: float = 1.0, v: float = 0.5): + JAX backend: + -> jax.scipy.special.kv if available + -> else tfp.substrates.jax.math.bessel_kve * exp(-|z|) """ - need to `pip install numba-scipy ` - see https://gaussianprocess.org/gpml/chapters/RW4.pdf for more info - the distance r need to be scalar - l is the scale - v is the order, better < 30, otherwise may have numerical NaN issue. + # ------------------------- + # NumPy backend + # ------------------------- + if xp is np: + import scipy.special as sc + + return sc.kv(v, z) - v control the smoothness level. the larger the v, the stronger smoothing condition (i.e., the solution is - v-th differentiable) imposed by the kernel. + # ------------------------- + # JAX backend + # ------------------------- + else: + try: + import tensorflow_probability.substrates.jax as tfp + + return tfp.math.bessel_kve(v, z) * xp.exp(-xp.abs(z)) + except ImportError: + raise ImportError( + "To use the JAX backend with the Matérn kernel, " + "please install tensorflow-probability via `pip install tensorflow-probability==0.25.0`." + ) + + +def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np): """ - r = abs(r) - if r == 0: - r = 0.00000001 - part1 = 2 ** (1 - v) / math.gamma(v) - part2 = (math.sqrt(2 * v) * r / l) ** v - part3 = sc.kv(v, math.sqrt(2 * v) * r / l) + XP-compatible Matérn kernel. + Works with NumPy or JAX. + """ + + # Avoid r = 0 singularity (JAX-safe) + r = xp.maximum(xp.abs(r), 1e-8) + + z = xp.sqrt(2.0 * v) * r / l + + part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant + part2 = z**v + part3 = kv_xp(v, z, xp) + return part1 * part2 * part3 -@numba_util.jit(cache=False) def matern_cov_matrix_from( scale: float, nu: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points, + xp=np, +): """ Consutruct the regularization covariance matrix, which is used to determined the regularization pattern (i.e, how the different pixels are correlated). @@ -63,45 +96,27 @@ def matern_cov_matrix_from( The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. """ - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) - - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j - - covariance_matrix[i, j] += matern_kernel(d_ij, l=scale, v=nu) - - return covariance_matrix + # -------------------------------- + # Pairwise distances (broadcasted) + # -------------------------------- + # pixel_points[:, None, :] -> (N, 1, 2) + # pixel_points[None, :, :] -> (1, N, 2) + diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2) + d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N) -class NumbaScipyPlaceholder: - pass + # -------------------------------- + # Apply Matérn kernel elementwise + # -------------------------------- + covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp) + # -------------------------------- + # Add diagonal jitter (JAX-safe) + # -------------------------------- + pixels = pixel_points.shape[0] + covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels) -try: - import numba_scipy - - numba_scipy = object -except ModuleNotFoundError: - numba_scipy = NumbaScipyPlaceholder() - - -def numba_scipy_exception(): - raise ModuleNotFoundError( - "\n--------------------\n" - "You are attempting to use the MaternKernel for Regularization.\n\n" - "However, the optional library numba_scipy (https://pypi.org/project/numba-scipy/) is not installed.\n\n" - "Install it via the command `pip install numba-scipy==0.3.1`.\n\n" - "----------------------" - ) + return covariance_matrix class MaternKernel(AbstractRegularization): @@ -131,9 +146,6 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5 Controls the derivative of the regularization pattern (`nu=0.5` is a Gaussian). """ - if isinstance(numba_scipy, NumbaScipyPlaceholder): - numba_scipy_exception() - self.coefficient = coefficient self.scale = float(scale) self.nu = float(nu) @@ -175,8 +187,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray """ covariance_matrix = matern_cov_matrix_from( scale=self.scale, - pixel_points=xp.array(linear_obj.source_plane_mesh_grid), + pixel_points=linear_obj.source_plane_mesh_grid.array, nu=self.nu, + xp=xp, ) return self.coefficient * xp.linalg.inv(covariance_matrix) diff --git a/pyproject.toml b/pyproject.toml index 5dedf5643..d88cf0cba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,8 @@ local_scheme = "no-local-version" [project.optional-dependencies] optional=[ "numba", - "pynufft" + "pynufft", + "ensorflow-probability==0.25.0" ] test = ["pytest"] dev = ["pytest", "black"] diff --git a/test_autoarray/inversion/regularizations/test_matern_kernel.py b/test_autoarray/inversion/regularizations/test_matern_kernel.py index ef25a9b83..2adc9482c 100644 --- a/test_autoarray/inversion/regularizations/test_matern_kernel.py +++ b/test_autoarray/inversion/regularizations/test_matern_kernel.py @@ -7,18 +7,17 @@ def test__regularization_matrix(): - pass - - # reg = aa.reg.MaternKernel(coefficient=1.0, scale=2.0, nu=2.0) - # - # source_plane_mesh_grid = aa.Grid2D.no_mask( - # values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], - # shape_native=(3, 2), - # pixel_scales=1.0, - # ) - # - # mapper = aa.m.MockMapper(source_plane_mesh_grid=source_plane_mesh_grid) - - # regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper) - # - # assert regularization_matrix[0, 0] == pytest.approx(3.540276762, 1.0e-4) + + reg = aa.reg.MaternKernel(coefficient=1.0, scale=2.0, nu=2.0) + + source_plane_mesh_grid = aa.Grid2D.no_mask( + values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]], + shape_native=(3, 2), + pixel_scales=1.0, + ) + + mapper = aa.m.MockMapper(source_plane_mesh_grid=source_plane_mesh_grid) + + regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper) + + assert regularization_matrix[0, 0] == pytest.approx(3.540276762, 1.0e-4)