Skip to content
Merged
2 changes: 2 additions & 0 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .inversion.pixelization.mappers.rectangular import MapperRectangular
from .inversion.pixelization.mappers.delaunay import MapperDelaunay
from .inversion.pixelization.mappers.voronoi import MapperVoronoi
from .inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform
from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh
from .inversion.pixelization.mesh.abstract import AbstractMesh
from .inversion.inversion.imaging.mapping import InversionImagingMapping
Expand Down Expand Up @@ -75,6 +76,7 @@
from .operators.over_sampling.over_sampler import OverSampler
from .structures.grids.irregular_2d import Grid2DIrregular
from .structures.mesh.rectangular_2d import Mesh2DRectangular
from .structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
from .structures.mesh.voronoi_2d import Mesh2DVoronoi
from .structures.mesh.delaunay_2d import Mesh2DDelaunay
from .structures.arrays.kernel_2d import Kernel2D
Expand Down
3 changes: 3 additions & 0 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def w_tilde(self):
indexes=indexes.astype("int"),
lengths=lengths.astype("int"),
noise_map_value=self.noise_map[0],
noise_map=self.noise_map,
psf=self.psf,
mask=self.mask,
)

@classmethod
Expand Down
60 changes: 60 additions & 0 deletions autoarray/dataset/imaging/w_tilde.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
import logging
import numpy as np

from autoconf import cached_property

from autoarray.dataset.abstract.w_tilde import AbstractWTilde

from autoarray.inversion.inversion.imaging import inversion_imaging_util

logger = logging.getLogger(__name__)


Expand All @@ -13,6 +17,9 @@ def __init__(
curvature_preload: np.ndarray,
indexes: np.ndim,
lengths: np.ndarray,
noise_map: np.ndarray,
psf: np.ndarray,
mask: np.ndarray,
noise_map_value: float,
):
"""
Expand Down Expand Up @@ -44,3 +51,56 @@ def __init__(

self.indexes = indexes
self.lengths = lengths
self.noise_map = noise_map
self.psf = psf
self.mask = mask

@cached_property
def w_matrix(self):
"""
The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF
convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the
curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the
PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging
datasets.

The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's,
making it impossible to store in memory and its use in linear algebra calculations extremely. The method
`w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is
advised `w_tilde` and this method are only used for testing.

Parameters
----------
noise_map_native
The two dimensional masked noise-map of values which w_tilde is computed from.
kernel_native
The two dimensional PSF kernel that w_tilde encodes the convolution of.
native_index_for_slim_index
An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array.

Returns
-------
ndarray
A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of
the curvature matrix.
"""

return inversion_imaging_util.w_tilde_curvature_imaging_from(
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(
self.mask.derive_indexes.native_for_slim
).astype("int"),
)

@cached_property
def psf_operator_matrix_dense(self):

return inversion_imaging_util.psf_operator_matrix_dense_from(
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(
self.mask.derive_indexes.native_for_slim
).astype("int"),
native_shape=self.noise_map.shape_native,
correlate=False,
)
2 changes: 1 addition & 1 deletion autoarray/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def make_rectangular_mapper_7x7_3x3():
adapt_data=aa.Array2D.ones(shape_native=(3, 3), pixel_scales=0.1),
)

return aa.MapperRectangular(
return aa.MapperRectangularUniform(
mapper_grids=mapper_grids,
border_relocator=make_border_relocator_2d_7x7(),
regularization=make_regularization_constant(),
Expand Down
104 changes: 83 additions & 21 deletions autoarray/inversion/inversion/imaging/inversion_imaging_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,69 @@
from scipy.signal import convolve2d
import jax.numpy as jnp
import numpy as np
from typing import Tuple

from autoarray import numba_util
from scipy.signal import correlate2d

import numpy as np


def psf_operator_matrix_dense_from(
kernel_native: np.ndarray,
native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels
native_shape: tuple[int, int],
correlate: bool = True,
) -> np.ndarray:
"""
Construct a dense PSF operator W (N_pix x N_pix) that maps masked image pixels to masked image pixels.

Parameters
----------
kernel_native : (Ky, Kx) PSF kernel.
native_index_for_slim_index : (N_pix, 2) array of int
Native (y, x) coords for each masked pixel.
native_shape : (Ny, Nx)
Native 2D image shape.
correlate : bool, default True
If True, use correlation convention (no kernel flip).
If False, use convolution convention (flip kernel).

Returns
-------
W : ndarray, shape (N_pix, N_pix)
Dense PSF operator.
"""
Ky, Kx = kernel_native.shape
ph, pw = Ky // 2, Kx // 2
Ny, Nx = native_shape
N_pix = native_index_for_slim_index.shape[0]

ker = kernel_native if correlate else kernel_native[::-1, ::-1]

# Padded index grid: -1 everywhere, slim index where masked
index_padded = -np.ones((Ny + 2 * ph, Nx + 2 * pw), dtype=np.int64)
for p, (y, x) in enumerate(native_index_for_slim_index):
index_padded[y + ph, x + pw] = p

# Neighborhood offsets
dy = np.arange(Ky) - ph
dx = np.arange(Kx) - pw

W = np.zeros((N_pix, N_pix), dtype=float)

for i, (y, x) in enumerate(native_index_for_slim_index):
yp = y + ph
xp = x + pw
for j, dy_ in enumerate(dy):
for k, dx_ in enumerate(dx):
neigh = index_padded[yp + dy_, xp + dx_]
if neigh >= 0:
W[i, neigh] += ker[j, k]

return W


@numba_util.jit()
def w_tilde_data_imaging_from(
image_native: np.ndarray,
noise_map_native: np.ndarray,
Expand Down Expand Up @@ -44,32 +103,35 @@ def w_tilde_data_imaging_from(
efficient calculation of the data vector.
"""

kernel_shift_y = -(kernel_native.shape[1] // 2)
kernel_shift_x = -(kernel_native.shape[0] // 2)

image_pixels = len(native_index_for_slim_index)

w_tilde_data = np.zeros((image_pixels,))
# 1) weight map = image / noise^2 (safe where noise==0)
weight_map = jnp.where(
noise_map_native > 0.0, image_native / (noise_map_native**2), 0.0
)

weight_map_native = image_native / noise_map_native**2.0
Ky, Kx = kernel_native.shape
ph, pw = Ky // 2, Kx // 2

for ip0 in range(image_pixels):
ip0_y, ip0_x = native_index_for_slim_index[ip0]

value = 0.0
# 2) pad so neighbourhood gathers never go OOB
padded = jnp.pad(
weight_map, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0
)

for k0_y in range(kernel_native.shape[0]):
for k0_x in range(kernel_native.shape[1]):
weight_value = weight_map_native[
ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x
]
# 3) build broadcasted neighbourhood indices for all requested pixels
# shift pixel coords into the padded frame
ys = native_index_for_slim_index[:, 0] + ph # (N,)
xs = native_index_for_slim_index[:, 1] + pw # (N,)

if not np.isnan(weight_value):
value += kernel_native[k0_y, k0_x] * weight_value
# kernel-relative offsets
dy = jnp.arange(Ky) - ph # (Ky,)
dx = jnp.arange(Kx) - pw # (Kx,)

w_tilde_data[ip0] = value
# broadcast to (N, Ky, Kx)
Y = ys[:, None, None] + dy[None, :, None]
X = xs[:, None, None] + dx[None, None, :]

return w_tilde_data
# 4) gather patches and correlate (no kernel flip)
patches = padded[Y, X] # (N, Ky, Kx)
return jnp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,)


@numba_util.jit()
Expand Down
Loading
Loading