Skip to content
14 changes: 8 additions & 6 deletions autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:

return [
(
self.convolver.convolve_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix
self.psf.convolve_mapping_matrix(
mapping_matrix=linear_obj.mapping_matrix, mask=self.mask
)
if linear_obj.operated_mapping_matrix_override is None
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
Expand Down Expand Up @@ -131,8 +131,9 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
if linear_func.operated_mapping_matrix_override is not None:
operated_mapping_matrix = linear_func.operated_mapping_matrix_override
else:
operated_mapping_matrix = self.convolver.convolve_mapping_matrix(
mapping_matrix=linear_func.mapping_matrix
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
mapping_matrix=linear_func.mapping_matrix,
mask=self.mask,
)

linear_func_operated_mapping_matrix_dict[linear_func] = (
Expand Down Expand Up @@ -212,8 +213,9 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict:
mapper_operated_mapping_matrix_dict = {}

for mapper in self.cls_list_from(cls=AbstractMapper):
operated_mapping_matrix = self.convolver.convolve_mapping_matrix(
mapping_matrix=mapper.mapping_matrix
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
mapping_matrix=mapper.mapping_matrix,
mask=self.mask,
)

mapper_operated_mapping_matrix_dict[mapper] = operated_mapping_matrix
Expand Down
8 changes: 4 additions & 4 deletions autoarray/inversion/inversion/imaging/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _data_vector_mapper(self) -> np.ndarray:
mapper = mapper_list[i]
param_range = mapper_param_range_list[i]

operated_mapping_matrix = self.convolver.convolve_mapping_matrix(
mapping_matrix=mapper.mapping_matrix
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
mapping_matrix=mapper.mapping_matrix, mask=self.mask
)

data_vector_mapper = (
Expand Down Expand Up @@ -129,8 +129,8 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
mapper_i = mapper_list[i]
mapper_param_range_i = mapper_param_range_list[i]

operated_mapping_matrix = self.convolver.convolve_mapping_matrix(
mapping_matrix=mapper_i.mapping_matrix
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
mapping_matrix=mapper_i.mapping_matrix, mask=self.mask
)

diag = inversion_util.curvature_matrix_via_mapping_matrix_from(
Expand Down
14 changes: 3 additions & 11 deletions autoarray/inversion/inversion/imaging/w_tilde.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,22 +504,14 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]:
reconstruction=reconstruction,
)

# mapped_reconstructed_image = self.psf.convolve_image_no_blurring(
# image=mapped_reconstructed_image, mask=self.mask
# ).array
mapped_reconstructed_image = self.psf.convolve_image_no_blurring(
image=mapped_reconstructed_image, mask=self.mask
).array

mapped_reconstructed_image = Array2D(
values=mapped_reconstructed_image, mask=self.mask
)

mapped_reconstructed_image = self.convolver.convolve_image_no_blurring(
image=mapped_reconstructed_image
)

mapped_reconstructed_image = Array2D(
values=np.array(mapped_reconstructed_image), mask=self.mask
)

else:
operated_mapping_matrix = self.linear_func_operated_mapping_matrix_dict[
linear_obj
Expand Down
9 changes: 9 additions & 0 deletions autoarray/inversion/mock/mock_inversion_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class MockInversionImaging(InversionImagingMapping):
def __init__(
self,
mask=None,
data=None,
noise_map=None,
psf=None,
Expand All @@ -33,13 +34,21 @@ def __init__(
settings=settings,
)

self._mask = mask
self._operated_mapping_matrix = operated_mapping_matrix

self._linear_func_operated_mapping_matrix_dict = (
linear_func_operated_mapping_matrix_dict
)
self._data_linear_func_matrix_dict = data_linear_func_matrix_dict

@property
def mask(self) -> np.ndarray:
if self._mask is None:
return super().mask

return self._mask

@property
def operated_mapping_matrix(self) -> np.ndarray:
if self._operated_mapping_matrix is None:
Expand Down
92 changes: 92 additions & 0 deletions autoarray/inversion/pixelization/mappers/mapper_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np
from scipy.spatial import cKDTree
from typing import Tuple
Expand Down Expand Up @@ -144,6 +145,97 @@ def data_slim_to_pixelization_unique_from(
return data_to_pix_unique, data_weights, pix_lengths


def rectangular_mappings_weights_via_interpolation_from(
shape_native: Tuple[int, int],
source_plane_data_grid: jnp.ndarray,
source_plane_mesh_grid: jnp.ndarray,
):
"""
Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid.

Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function
determines for each irregular point:
- the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and
- the bilinear interpolation weights with respect to those pixels.

The function supports JAX and is compatible with JIT compilation.

Parameters
----------
shape_native
The shape (Ny, Nx) of the original rectangular mesh grid before flattening.
source_plane_data_grid
The irregular grid of (y, x) points to interpolate.
source_plane_mesh_grid
The flattened regular rectangular mesh grid of (y, x) coordinates.

Returns
-------
mappings : jnp.ndarray of shape (N, 4)
Indices of the four nearest rectangular mesh pixels in the flattened mesh grid.
Order is: top-left, top-right, bottom-left, bottom-right.
weights : jnp.ndarray of shape (N, 4)
Bilinear interpolation weights corresponding to the four nearest mesh pixels.

Notes
-----
- Assumes the mesh grid is uniformly spaced.
- The weights sum to 1 for each irregular point.
- Uses bilinear interpolation in the (y, x) coordinate system.
"""
source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2)

# Assume mesh is shaped (Ny, Nx, 2)
Ny, Nx = source_plane_mesh_grid.shape[:2]

# Get mesh spacings and lower corner
y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,)
x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,)

dy = y_coords[1] - y_coords[0]
dx = x_coords[1] - x_coords[0]

y_min = y_coords[0]
x_min = x_coords[0]

# shape (N_irregular, 2)
irregular = source_plane_data_grid

# Compute normalized mesh coordinates (floating indices)
fy = (irregular[:, 0] - y_min) / dy
fx = (irregular[:, 1] - x_min) / dx

# Integer indices of top-left corners
ix = jnp.floor(fx).astype(jnp.int32)
iy = jnp.floor(fy).astype(jnp.int32)

# Clip to stay within bounds
ix = jnp.clip(ix, 0, Nx - 2)
iy = jnp.clip(iy, 0, Ny - 2)

# Local coordinates inside the cell (0 <= tx, ty <= 1)
tx = fx - ix
ty = fy - iy

# Bilinear weights
w00 = (1 - tx) * (1 - ty)
w10 = tx * (1 - ty)
w01 = (1 - tx) * ty
w11 = tx * ty

weights = jnp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4)

# Compute indices of 4 surrounding pixels in the flattened mesh
i00 = iy * Nx + ix
i10 = iy * Nx + (ix + 1)
i01 = (iy + 1) * Nx + ix
i11 = (iy + 1) * Nx + (ix + 1)

mappings = jnp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4)

return mappings, weights


@numba_util.jit()
def pix_indexes_for_sub_slim_index_delaunay_from(
source_plane_data_grid,
Expand Down
28 changes: 15 additions & 13 deletions autoarray/inversion/pixelization/mappers/rectangular.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import jax.numpy as jnp
import numpy as np
from typing import Tuple

from autoconf import cached_property

from autoarray.structures.grids.irregular_2d import Grid2DIrregular
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights

from autoarray.geometry import geometry_util
from autoarray.inversion.pixelization.mappers import mapper_util


class MapperRectangular(AbstractMapper):
Expand Down Expand Up @@ -95,19 +97,19 @@ def pix_sub_weights(self) -> PixSubWeights:
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
are equal to 1.0.
"""
mappings = geometry_util.grid_pixel_indexes_2d_slim_from(
grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled),
shape_native=self.source_plane_mesh_grid.shape_native,
pixel_scales=self.source_plane_mesh_grid.pixel_scales,
origin=self.source_plane_mesh_grid.origin,
).astype("int")

mappings = mappings.reshape((len(mappings), 1))
mappings, weights = (
mapper_util.rectangular_mappings_weights_via_interpolation_from(
shape_native=self.shape_native,
source_plane_mesh_grid=self.source_plane_mesh_grid.array,
source_plane_data_grid=Grid2DIrregular(
self.source_plane_data_grid.over_sampled
).array,
)
)

return PixSubWeights(
mappings=mappings,
sizes=np.ones(len(mappings), dtype="int"),
weights=np.ones(
(len(self.source_plane_data_grid.over_sampled), 1), dtype="int"
),
mappings=np.array(mappings),
sizes=4 * np.ones(len(mappings), dtype="int"),
weights=np.array(weights),
)
8 changes: 6 additions & 2 deletions autoarray/operators/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def hull(
# cast JAX arrays to base numpy arrays
grid_convex = np.zeros((len(self.grid), 2))

grid_convex[:, 0] = np.array(self.grid[:, 1])
grid_convex[:, 1] = np.array(self.grid[:, 0])
try:
grid_convex[:, 0] = np.array(self.grid.array[:, 1])
grid_convex[:, 1] = np.array(self.grid.array[:, 0])
except AttributeError:
grid_convex[:, 0] = np.array(self.grid[:, 1])
grid_convex[:, 1] = np.array(self.grid[:, 0])

try:
hull = ConvexHull(grid_convex)
Expand Down
2 changes: 1 addition & 1 deletion autoarray/operators/mock/mock_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ class MockPSF:
def __init__(self, operated_mapping_matrix=None):
self.operated_mapping_matrix = operated_mapping_matrix

def convolve_mapping_matrix(self, mapping_matrix):
def convolve_mapping_matrix(self, mapping_matrix, mask):
return self.operated_mapping_matrix


Expand Down
26 changes: 17 additions & 9 deletions test_autoarray/inversion/inversion/imaging/test_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,32 @@


def test__operated_mapping_matrix_property(psf_3x3, rectangular_mapper_7x7_3x3):

inversion = aa.m.MockInversionImaging(
mask=rectangular_mapper_7x7_3x3.mapper_grids.mask,
psf=psf_3x3,
linear_obj_list=[rectangular_mapper_7x7_3x3],
convolver=aa.Convolver(
kernel=psf_3x3, mask=rectangular_mapper_7x7_3x3.mapper_grids.mask
),
)

assert inversion.operated_mapping_matrix_list[0][0, 0] == pytest.approx(1.0, 1e-4)
assert inversion.operated_mapping_matrix[0, 0] == pytest.approx(1.0, 1e-4)
assert inversion.operated_mapping_matrix_list[0][0, 0] == pytest.approx(
1.61999997, 1e-4
)
assert inversion.operated_mapping_matrix[0, 0] == pytest.approx(1.61999997408, 1e-4)

mask = aa.Mask2D(
[
[True, True, True, True],
[True, False, False, True],
[True, True, True, True],
],
pixel_scales=1.0,
)
psf = aa.m.MockPSF(operated_mapping_matrix=np.ones((2, 2)))

inversion = aa.m.MockInversionImaging(
mask=mask,
psf=psf,
linear_obj_list=[rectangular_mapper_7x7_3x3, rectangular_mapper_7x7_3x3],
convolver=aa.m.MockConvolver(operated_mapping_matrix=np.ones((2, 2))),
)

operated_mapping_matrix_0 = np.array([[1.0, 1.0], [1.0, 1.0]])
Expand Down Expand Up @@ -59,9 +68,9 @@ def test__operated_mapping_matrix_property__with_operated_mapping_matrix_overrid
)

inversion = aa.m.MockInversionImaging(
mask=rectangular_mapper_7x7_3x3.mapper_grids.mask,
psf=psf,
linear_obj_list=[rectangular_mapper_7x7_3x3, linear_obj],
convolver=aa.m.MockConvolver(operated_mapping_matrix=np.ones((2, 2))),
)

operated_mapping_matrix_0 = np.array([[1.0, 1.0], [1.0, 1.0]])
Expand Down Expand Up @@ -92,10 +101,9 @@ def test__curvature_matrix(rectangular_mapper_7x7_3x3):
)

dataset = aa.DatasetInterface(
data=np.ones(2),
data=aa.Array2D.ones(shape_native=(2, 10), pixel_scales=1.0),
noise_map=noise_map,
psf=psf,
convolver=aa.m.MockConvolver(operated_mapping_matrix=np.ones((2, 10))),
)

inversion = aa.InversionImagingMapping(
Expand Down
Loading
Loading