Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 72 additions & 59 deletions autoarray/inversion/regularization/matern_kernel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations
import jax.numpy as jnp
import jax.scipy.special as jsp
import numpy as np

import math
import scipy.special as sc
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -12,35 +12,68 @@

from autoarray import numba_util

import jax.numpy as jnp


def kv_xp(v, z, xp=np):
"""
XP-compatible modified Bessel K_v(v, z).

NumPy backend:
-> scipy.special.kv

@numba_util.jit(cache=False)
def matern_kernel(r: float, l: float = 1.0, v: float = 0.5):
JAX backend:
-> jax.scipy.special.kv if available
-> else tfp.substrates.jax.math.bessel_kve * exp(-|z|)
"""
need to `pip install numba-scipy `
see https://gaussianprocess.org/gpml/chapters/RW4.pdf for more info

the distance r need to be scalar
l is the scale
v is the order, better < 30, otherwise may have numerical NaN issue.
# -------------------------
# NumPy backend
# -------------------------
if xp is np:
import scipy.special as sc

return sc.kv(v, z)

v control the smoothness level. the larger the v, the stronger smoothing condition (i.e., the solution is
v-th differentiable) imposed by the kernel.
# -------------------------
# JAX backend
# -------------------------
else:
try:
import tensorflow_probability.substrates.jax as tfp

return tfp.math.bessel_kve(v, z) * xp.exp(-xp.abs(z))
except ImportError:
raise ImportError(
"To use the JAX backend with the Matérn kernel, "
"please install tensorflow-probability via `pip install tensorflow-probability==0.25.0`."
)


def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
"""
r = abs(r)
if r == 0:
r = 0.00000001
part1 = 2 ** (1 - v) / math.gamma(v)
part2 = (math.sqrt(2 * v) * r / l) ** v
part3 = sc.kv(v, math.sqrt(2 * v) * r / l)
XP-compatible Matérn kernel.
Works with NumPy or JAX.
"""

# Avoid r = 0 singularity (JAX-safe)
r = xp.maximum(xp.abs(r), 1e-8)

z = xp.sqrt(2.0 * v) * r / l

part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant
part2 = z**v
part3 = kv_xp(v, z, xp)

return part1 * part2 * part3


@numba_util.jit(cache=False)
def matern_cov_matrix_from(
scale: float,
nu: float,
pixel_points: np.ndarray,
) -> np.ndarray:
pixel_points,
xp=np,
):
"""
Consutruct the regularization covariance matrix, which is used to determined the regularization pattern (i.e,
how the different pixels are correlated).
Expand All @@ -63,45 +96,27 @@ def matern_cov_matrix_from(
The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels].
"""

pixels = len(pixel_points)
covariance_matrix = np.zeros(shape=(pixels, pixels))

for i in range(pixels):
covariance_matrix[i, i] += 1e-8
for j in range(pixels):
xi = pixel_points[i, 1]
yi = pixel_points[i, 0]
xj = pixel_points[j, 1]
yj = pixel_points[j, 0]
d_ij = np.sqrt(
(xi - xj) ** 2 + (yi - yj) ** 2
) # distance between the pixel i and j

covariance_matrix[i, j] += matern_kernel(d_ij, l=scale, v=nu)

return covariance_matrix
# --------------------------------
# Pairwise distances (broadcasted)
# --------------------------------
# pixel_points[:, None, :] -> (N, 1, 2)
# pixel_points[None, :, :] -> (1, N, 2)
diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2)

d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)

class NumbaScipyPlaceholder:
pass
# --------------------------------
# Apply Matérn kernel elementwise
# --------------------------------
covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp)

# --------------------------------
# Add diagonal jitter (JAX-safe)
# --------------------------------
pixels = pixel_points.shape[0]
covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels)

try:
import numba_scipy

numba_scipy = object
except ModuleNotFoundError:
numba_scipy = NumbaScipyPlaceholder()


def numba_scipy_exception():
raise ModuleNotFoundError(
"\n--------------------\n"
"You are attempting to use the MaternKernel for Regularization.\n\n"
"However, the optional library numba_scipy (https://pypi.org/project/numba-scipy/) is not installed.\n\n"
"Install it via the command `pip install numba-scipy==0.3.1`.\n\n"
"----------------------"
)
return covariance_matrix


class MaternKernel(AbstractRegularization):
Expand Down Expand Up @@ -131,9 +146,6 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5
Controls the derivative of the regularization pattern (`nu=0.5` is a Gaussian).
"""

if isinstance(numba_scipy, NumbaScipyPlaceholder):
numba_scipy_exception()

self.coefficient = coefficient
self.scale = float(scale)
self.nu = float(nu)
Expand Down Expand Up @@ -175,8 +187,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
"""
covariance_matrix = matern_cov_matrix_from(
scale=self.scale,
pixel_points=xp.array(linear_obj.source_plane_mesh_grid),
pixel_points=linear_obj.source_plane_mesh_grid.array,
nu=self.nu,
xp=xp,
)

return self.coefficient * xp.linalg.inv(covariance_matrix)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ local_scheme = "no-local-version"
[project.optional-dependencies]
optional=[
"numba",
"pynufft"
"pynufft",
"ensorflow-probability==0.25.0"
]
test = ["pytest"]
dev = ["pytest", "black"]
Expand Down
29 changes: 14 additions & 15 deletions test_autoarray/inversion/regularizations/test_matern_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@


def test__regularization_matrix():
pass

# reg = aa.reg.MaternKernel(coefficient=1.0, scale=2.0, nu=2.0)
#
# source_plane_mesh_grid = aa.Grid2D.no_mask(
# values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]],
# shape_native=(3, 2),
# pixel_scales=1.0,
# )
#
# mapper = aa.m.MockMapper(source_plane_mesh_grid=source_plane_mesh_grid)

# regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)
#
# assert regularization_matrix[0, 0] == pytest.approx(3.540276762, 1.0e-4)

reg = aa.reg.MaternKernel(coefficient=1.0, scale=2.0, nu=2.0)

source_plane_mesh_grid = aa.Grid2D.no_mask(
values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]],
shape_native=(3, 2),
pixel_scales=1.0,
)

mapper = aa.m.MockMapper(source_plane_mesh_grid=source_plane_mesh_grid)

regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)

assert regularization_matrix[0, 0] == pytest.approx(3.540276762, 1.0e-4)
Loading