diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index d9792fe01..6f2d2ded3 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from scipy.spatial import cKDTree from typing import Tuple @@ -144,6 +145,97 @@ def data_slim_to_pixelization_unique_from( return data_to_pix_unique, data_weights, pix_lengths +def rectangular_mappings_weights_via_interpolation_from( + shape_native: Tuple[int, int], + source_plane_data_grid: jnp.ndarray, + source_plane_mesh_grid: jnp.ndarray, +): + """ + Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid. + + Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function + determines for each irregular point: + - the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and + - the bilinear interpolation weights with respect to those pixels. + + The function supports JAX and is compatible with JIT compilation. + + Parameters + ---------- + shape_native + The shape (Ny, Nx) of the original rectangular mesh grid before flattening. + source_plane_data_grid + The irregular grid of (y, x) points to interpolate. + source_plane_mesh_grid + The flattened regular rectangular mesh grid of (y, x) coordinates. + + Returns + ------- + mappings : jnp.ndarray of shape (N, 4) + Indices of the four nearest rectangular mesh pixels in the flattened mesh grid. + Order is: top-left, top-right, bottom-left, bottom-right. + weights : jnp.ndarray of shape (N, 4) + Bilinear interpolation weights corresponding to the four nearest mesh pixels. + + Notes + ----- + - Assumes the mesh grid is uniformly spaced. + - The weights sum to 1 for each irregular point. + - Uses bilinear interpolation in the (y, x) coordinate system. + """ + source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2) + + # Assume mesh is shaped (Ny, Nx, 2) + Ny, Nx = source_plane_mesh_grid.shape[:2] + + # Get mesh spacings and lower corner + y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,) + x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,) + + dy = y_coords[1] - y_coords[0] + dx = x_coords[1] - x_coords[0] + + y_min = y_coords[0] + x_min = x_coords[0] + + # shape (N_irregular, 2) + irregular = source_plane_data_grid + + # Compute normalized mesh coordinates (floating indices) + fy = (irregular[:, 0] - y_min) / dy + fx = (irregular[:, 1] - x_min) / dx + + # Integer indices of top-left corners + ix = jnp.floor(fx).astype(jnp.int32) + iy = jnp.floor(fy).astype(jnp.int32) + + # Clip to stay within bounds + ix = jnp.clip(ix, 0, Nx - 2) + iy = jnp.clip(iy, 0, Ny - 2) + + # Local coordinates inside the cell (0 <= tx, ty <= 1) + tx = fx - ix + ty = fy - iy + + # Bilinear weights + w00 = (1 - tx) * (1 - ty) + w10 = tx * (1 - ty) + w01 = (1 - tx) * ty + w11 = tx * ty + + weights = jnp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4) + + # Compute indices of 4 surrounding pixels in the flattened mesh + i00 = iy * Nx + ix + i10 = iy * Nx + (ix + 1) + i01 = (iy + 1) * Nx + ix + i11 = (iy + 1) * Nx + (ix + 1) + + mappings = jnp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4) + + return mappings, weights + + @numba_util.jit() def pix_indexes_for_sub_slim_index_delaunay_from( source_plane_data_grid, diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index 357ee9956..878ab8233 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -1,12 +1,14 @@ +import jax.numpy as jnp import numpy as np from typing import Tuple from autoconf import cached_property +from autoarray.structures.grids.irregular_2d import Grid2DIrregular from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights -from autoarray.geometry import geometry_util +from autoarray.inversion.pixelization.mappers import mapper_util class MapperRectangular(AbstractMapper): @@ -95,19 +97,19 @@ def pix_sub_weights(self) -> PixSubWeights: dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index` are equal to 1.0. """ - mappings = geometry_util.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled), - shape_native=self.source_plane_mesh_grid.shape_native, - pixel_scales=self.source_plane_mesh_grid.pixel_scales, - origin=self.source_plane_mesh_grid.origin, - ).astype("int") - mappings = mappings.reshape((len(mappings), 1)) + mappings, weights = ( + mapper_util.rectangular_mappings_weights_via_interpolation_from( + shape_native=self.shape_native, + source_plane_mesh_grid=self.source_plane_mesh_grid.array, + source_plane_data_grid=Grid2DIrregular( + self.source_plane_data_grid.over_sampled + ).array, + ) + ) return PixSubWeights( - mappings=mappings, - sizes=np.ones(len(mappings), dtype="int"), - weights=np.ones( - (len(self.source_plane_data_grid.over_sampled), 1), dtype="int" - ), + mappings=np.array(mappings), + sizes=4 * np.ones(len(mappings), dtype="int"), + weights=np.array(weights), ) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index 6519537fd..52a510dfd 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -92,8 +92,12 @@ def hull( # cast JAX arrays to base numpy arrays grid_convex = np.zeros((len(self.grid), 2)) - grid_convex[:, 0] = np.array(self.grid[:, 1]) - grid_convex[:, 1] = np.array(self.grid[:, 0]) + try: + grid_convex[:, 0] = np.array(self.grid.array[:, 1]) + grid_convex[:, 1] = np.array(self.grid.array[:, 0]) + except AttributeError: + grid_convex[:, 0] = np.array(self.grid[:, 1]) + grid_convex[:, 1] = np.array(self.grid[:, 0]) try: hull = ConvexHull(grid_convex) diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index 984aab946..1b46169df 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -487,14 +487,8 @@ def test__inversion_matrices__x2_mappers( settings=aa.SettingsInversion(use_positive_only_solver=True), ) - assert ( - inversion.operated_mapping_matrix[0:9, 0:9] - == rectangular_mapper_7x7_3x3.mapping_matrix - ).all() - assert ( - inversion.operated_mapping_matrix[0:9, 9:18] - == delaunay_mapper_9_3x3.mapping_matrix - ).all() + assert inversion.operated_mapping_matrix[0:9, 0:9] == pytest.approx(rectangular_mapper_7x7_3x3.mapping_matrix, abs=1.0e-4) + assert inversion.operated_mapping_matrix[0:9, 9:18] == pytest.approx(delaunay_mapper_9_3x3.mapping_matrix, abs=1.0e-4) operated_mapping_matrix = np.hstack( [ diff --git a/test_autoarray/inversion/pixelization/mappers/test_factory.py b/test_autoarray/inversion/pixelization/mappers/test_factory.py index c08bca937..31a24ba2a 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_factory.py +++ b/test_autoarray/inversion/pixelization/mappers/test_factory.py @@ -42,18 +42,18 @@ def test__rectangular_mapper(): (5.0, 5.0), 1.0e-4 ) assert mapper.source_plane_mesh_grid.origin == pytest.approx((0.5, 0.5), 1.0e-4) - assert ( - mapper.mapping_matrix - == np.array( + assert mapper.mapping_matrix == pytest.approx( + np.array( [ - [0.0, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0675, 0.5775, 0.18, 0.0075, -0.065, -0.1425, 0.0, 0.0375, 0.3375], + [0.18, -0.03, 0.0, 0.84, -0.14, 0.0, 0.18, -0.03, 0.0], + [0.0225, 0.105, 0.0225, 0.105, 0.49, 0.105, 0.0225, 0.105, 0.0225], + [0.0, -0.03, 0.18, 0.0, -0.14, 0.84, 0.0, -0.03, 0.18], + [0.0, 0.0, 0.0, -0.03, -0.14, -0.03, 0.18, 0.84, 0.18], ] - ) - ).all() + ), + 1.0e-4, + ) assert mapper.shape_native == (3, 3) diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index 026e230df..edfa53722 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -31,20 +31,18 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) - pix_indexes_for_sub_slim_index_util = np.array( - [ - aa.util.geometry.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(grid.over_sampled), - shape_native=mesh_grid.shape_native, - pixel_scales=mesh_grid.pixel_scales, - origin=mesh_grid.origin, - ).astype("int") - ] - ).T + mappings, weights = ( + aa.util.mapper.rectangular_mappings_weights_via_interpolation_from( + shape_native=(3, 3), + source_plane_mesh_grid=mesh_grid.array, + source_plane_data_grid=aa.Grid2DIrregular( + mapper_grids.source_plane_data_grid.over_sampled + ).array, + ) + ) - assert ( - mapper.pix_indexes_for_sub_slim_index == pix_indexes_for_sub_slim_index_util - ).all() + assert (mapper.pix_sub_weights.mappings == mappings).all() + assert (mapper.pix_sub_weights.weights == weights).all() def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7):