From fd11b178e6a64855650ac39fcb48d5697a5ab66d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 24 Jun 2025 11:35:58 +0100 Subject: [PATCH 1/5] rectangular uses intterpolation with JAX support now --- .../pixelization/mappers/mapper_util.py | 92 +++++++++++++++++++ .../pixelization/mappers/rectangular.py | 24 +++-- autoarray/operators/contour.py | 4 +- 3 files changed, 105 insertions(+), 15 deletions(-) diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index d9792fe01..9efb049dd 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from scipy.spatial import cKDTree from typing import Tuple @@ -144,6 +145,97 @@ def data_slim_to_pixelization_unique_from( return data_to_pix_unique, data_weights, pix_lengths +def rectangular_mappings_weights_via_interpolation_from( + shape_native : Tuple[int, int], + source_plane_data_grid: jnp.ndarray, + source_plane_mesh_grid: jnp.ndarray +): + """ + Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid. + + Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function + determines for each irregular point: + - the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and + - the bilinear interpolation weights with respect to those pixels. + + The function supports JAX and is compatible with JIT compilation. + + Parameters + ---------- + shape_native + The shape (Ny, Nx) of the original rectangular mesh grid before flattening. + source_plane_data_grid + The irregular grid of (y, x) points to interpolate. + source_plane_mesh_grid + The flattened regular rectangular mesh grid of (y, x) coordinates. + + Returns + ------- + mappings : jnp.ndarray of shape (N, 4) + Indices of the four nearest rectangular mesh pixels in the flattened mesh grid. + Order is: top-left, top-right, bottom-left, bottom-right. + weights : jnp.ndarray of shape (N, 4) + Bilinear interpolation weights corresponding to the four nearest mesh pixels. + + Notes + ----- + - Assumes the mesh grid is uniformly spaced. + - The weights sum to 1 for each irregular point. + - Uses bilinear interpolation in the (y, x) coordinate system. + """ + source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2) + + # Assume mesh is shaped (Ny, Nx, 2) + Ny, Nx = source_plane_mesh_grid.shape[:2] + + # Get mesh spacings and lower corner + y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,) + x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,) + + dy = y_coords[1] - y_coords[0] + dx = x_coords[1] - x_coords[0] + + y_min = y_coords[0] + x_min = x_coords[0] + + # shape (N_irregular, 2) + irregular = source_plane_data_grid + + # Compute normalized mesh coordinates (floating indices) + fy = (irregular[:, 0] - y_min) / dy + fx = (irregular[:, 1] - x_min) / dx + + # Integer indices of top-left corners + ix = jnp.floor(fx).astype(jnp.int32) + iy = jnp.floor(fy).astype(jnp.int32) + + # Clip to stay within bounds + ix = jnp.clip(ix, 0, Nx - 2) + iy = jnp.clip(iy, 0, Ny - 2) + + # Local coordinates inside the cell (0 <= tx, ty <= 1) + tx = fx - ix + ty = fy - iy + + # Bilinear weights + w00 = (1 - tx) * (1 - ty) + w10 = tx * (1 - ty) + w01 = (1 - tx) * ty + w11 = tx * ty + + weights = jnp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4) + + # Compute indices of 4 surrounding pixels in the flattened mesh + i00 = iy * Nx + ix + i10 = iy * Nx + (ix + 1) + i01 = (iy + 1) * Nx + ix + i11 = (iy + 1) * Nx + (ix + 1) + + mappings = jnp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4) + + return mappings, weights + + @numba_util.jit() def pix_indexes_for_sub_slim_index_delaunay_from( source_plane_data_grid, diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index 357ee9956..3a6629244 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -1,12 +1,14 @@ +import jax.numpy as jnp import numpy as np from typing import Tuple from autoconf import cached_property +from autoarray.structures.grids.irregular_2d import Grid2DIrregular from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights -from autoarray.geometry import geometry_util +from autoarray.inversion.pixelization.mappers import mapper_util class MapperRectangular(AbstractMapper): @@ -95,19 +97,15 @@ def pix_sub_weights(self) -> PixSubWeights: dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index` are equal to 1.0. """ - mappings = geometry_util.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled), - shape_native=self.source_plane_mesh_grid.shape_native, - pixel_scales=self.source_plane_mesh_grid.pixel_scales, - origin=self.source_plane_mesh_grid.origin, - ).astype("int") - mappings = mappings.reshape((len(mappings), 1)) + mappings, weights = mapper_util.rectangular_mappings_weights_via_interpolation_from( + shape_native=self.shape_native, + source_plane_mesh_grid=self.source_plane_mesh_grid.array, + source_plane_data_grid=Grid2DIrregular(self.source_plane_data_grid.over_sampled).array, + ) return PixSubWeights( - mappings=mappings, - sizes=np.ones(len(mappings), dtype="int"), - weights=np.ones( - (len(self.source_plane_data_grid.over_sampled), 1), dtype="int" - ), + mappings=np.array(mappings), + sizes=4*np.ones(len(mappings), dtype="int"), + weights=np.array(weights), ) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index 6519537fd..ae1e62f1b 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -92,8 +92,8 @@ def hull( # cast JAX arrays to base numpy arrays grid_convex = np.zeros((len(self.grid), 2)) - grid_convex[:, 0] = np.array(self.grid[:, 1]) - grid_convex[:, 1] = np.array(self.grid[:, 0]) + grid_convex[:, 0] = np.array(self.grid.array[:, 1]) + grid_convex[:, 1] = np.array(self.grid.array[:, 0]) try: hull = ConvexHull(grid_convex) From 8a69509173b6d3b88da06171f7ea6ccb080e607e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 24 Jun 2025 11:37:15 +0100 Subject: [PATCH 2/5] fix visualization unit tests --- autoarray/operators/contour.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index ae1e62f1b..52a510dfd 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -92,8 +92,12 @@ def hull( # cast JAX arrays to base numpy arrays grid_convex = np.zeros((len(self.grid), 2)) - grid_convex[:, 0] = np.array(self.grid.array[:, 1]) - grid_convex[:, 1] = np.array(self.grid.array[:, 0]) + try: + grid_convex[:, 0] = np.array(self.grid.array[:, 1]) + grid_convex[:, 1] = np.array(self.grid.array[:, 0]) + except AttributeError: + grid_convex[:, 0] = np.array(self.grid[:, 1]) + grid_convex[:, 1] = np.array(self.grid[:, 0]) try: hull = ConvexHull(grid_convex) From 0e1f23b67e9d0382334186d9f688fc6eee3b51b5 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 24 Jun 2025 11:54:48 +0100 Subject: [PATCH 3/5] test_autoarray/inversion/pixelization/mappers/test_rectangular.py --- .../pixelization/mappers/test_rectangular.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index 026e230df..4d708723e 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -31,19 +31,17 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) - pix_indexes_for_sub_slim_index_util = np.array( - [ - aa.util.geometry.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(grid.over_sampled), - shape_native=mesh_grid.shape_native, - pixel_scales=mesh_grid.pixel_scales, - origin=mesh_grid.origin, - ).astype("int") - ] - ).T + mappings, weights = aa.util.mapper.rectangular_mappings_weights_via_interpolation_from( + shape_native=(3,3), + source_plane_mesh_grid=mesh_grid.array, + source_plane_data_grid=aa.Grid2DIrregular(mapper_grids.source_plane_data_grid.over_sampled).array, + ) assert ( - mapper.pix_indexes_for_sub_slim_index == pix_indexes_for_sub_slim_index_util + mapper.pix_sub_weights.mappings == mappings + ).all() + assert ( + mapper.pix_sub_weights.weights == weights ).all() From a95464fa1458804d98b84bc8a2fffffadba4080c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 24 Jun 2025 11:58:40 +0100 Subject: [PATCH 4/5] test_autoarray/inversion/pixelization/mappers/test_factory.py -> rectanguilar mapping matrix now has interpoltion --- .../pixelization/mappers/rectangular.py | 14 ++++++++----- .../pixelization/mappers/test_factory.py | 21 ++++++++++--------- .../pixelization/mappers/test_rectangular.py | 20 +++++++++--------- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index 3a6629244..878ab8233 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -98,14 +98,18 @@ def pix_sub_weights(self) -> PixSubWeights: are equal to 1.0. """ - mappings, weights = mapper_util.rectangular_mappings_weights_via_interpolation_from( - shape_native=self.shape_native, - source_plane_mesh_grid=self.source_plane_mesh_grid.array, - source_plane_data_grid=Grid2DIrregular(self.source_plane_data_grid.over_sampled).array, + mappings, weights = ( + mapper_util.rectangular_mappings_weights_via_interpolation_from( + shape_native=self.shape_native, + source_plane_mesh_grid=self.source_plane_mesh_grid.array, + source_plane_data_grid=Grid2DIrregular( + self.source_plane_data_grid.over_sampled + ).array, + ) ) return PixSubWeights( mappings=np.array(mappings), - sizes=4*np.ones(len(mappings), dtype="int"), + sizes=4 * np.ones(len(mappings), dtype="int"), weights=np.array(weights), ) diff --git a/test_autoarray/inversion/pixelization/mappers/test_factory.py b/test_autoarray/inversion/pixelization/mappers/test_factory.py index c08bca937..727e2ab97 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_factory.py +++ b/test_autoarray/inversion/pixelization/mappers/test_factory.py @@ -42,18 +42,19 @@ def test__rectangular_mapper(): (5.0, 5.0), 1.0e-4 ) assert mapper.source_plane_mesh_grid.origin == pytest.approx((0.5, 0.5), 1.0e-4) - assert ( - mapper.mapping_matrix - == np.array( + print(mapper.mapping_matrix) + assert mapper.mapping_matrix == pytest.approx( + np.array( [ - [0.0, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0675, 0.5775, 0.18, 0.0075, -0.065, -0.1425, 0.0, 0.0375, 0.3375], + [0.18, -0.03, 0.0, 0.84, -0.14, 0.0, 0.18, -0.03, 0.0], + [0.0225, 0.105, 0.0225, 0.105, 0.49, 0.105, 0.0225, 0.105, 0.0225], + [0.0, -0.03, 0.18, 0.0, -0.14, 0.84, 0.0, -0.03, 0.18], + [0.0, 0.0, 0.0, -0.03, -0.14, -0.03, 0.18, 0.84, 0.18], ] - ) - ).all() + ), + 1.0e-4, + ) assert mapper.shape_native == (3, 3) diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index 4d708723e..edfa53722 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -31,18 +31,18 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) - mappings, weights = aa.util.mapper.rectangular_mappings_weights_via_interpolation_from( - shape_native=(3,3), - source_plane_mesh_grid=mesh_grid.array, - source_plane_data_grid=aa.Grid2DIrregular(mapper_grids.source_plane_data_grid.over_sampled).array, + mappings, weights = ( + aa.util.mapper.rectangular_mappings_weights_via_interpolation_from( + shape_native=(3, 3), + source_plane_mesh_grid=mesh_grid.array, + source_plane_data_grid=aa.Grid2DIrregular( + mapper_grids.source_plane_data_grid.over_sampled + ).array, + ) ) - assert ( - mapper.pix_sub_weights.mappings == mappings - ).all() - assert ( - mapper.pix_sub_weights.weights == weights - ).all() + assert (mapper.pix_sub_weights.mappings == mappings).all() + assert (mapper.pix_sub_weights.weights == weights).all() def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7): From ecfe98d1bf7b4f3ccaacf90c09434cae59a908ad Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 24 Jun 2025 13:46:21 +0100 Subject: [PATCH 5/5] interpolate works but now need to remove convolver --- .../inversion/pixelization/mappers/mapper_util.py | 6 +++--- test_autoarray/inversion/inversion/test_factory.py | 10 ++-------- .../inversion/pixelization/mappers/test_factory.py | 1 - 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index 9efb049dd..6f2d2ded3 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -146,9 +146,9 @@ def data_slim_to_pixelization_unique_from( def rectangular_mappings_weights_via_interpolation_from( - shape_native : Tuple[int, int], - source_plane_data_grid: jnp.ndarray, - source_plane_mesh_grid: jnp.ndarray + shape_native: Tuple[int, int], + source_plane_data_grid: jnp.ndarray, + source_plane_mesh_grid: jnp.ndarray, ): """ Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid. diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index 984aab946..1b46169df 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -487,14 +487,8 @@ def test__inversion_matrices__x2_mappers( settings=aa.SettingsInversion(use_positive_only_solver=True), ) - assert ( - inversion.operated_mapping_matrix[0:9, 0:9] - == rectangular_mapper_7x7_3x3.mapping_matrix - ).all() - assert ( - inversion.operated_mapping_matrix[0:9, 9:18] - == delaunay_mapper_9_3x3.mapping_matrix - ).all() + assert inversion.operated_mapping_matrix[0:9, 0:9] == pytest.approx(rectangular_mapper_7x7_3x3.mapping_matrix, abs=1.0e-4) + assert inversion.operated_mapping_matrix[0:9, 9:18] == pytest.approx(delaunay_mapper_9_3x3.mapping_matrix, abs=1.0e-4) operated_mapping_matrix = np.hstack( [ diff --git a/test_autoarray/inversion/pixelization/mappers/test_factory.py b/test_autoarray/inversion/pixelization/mappers/test_factory.py index 727e2ab97..31a24ba2a 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_factory.py +++ b/test_autoarray/inversion/pixelization/mappers/test_factory.py @@ -42,7 +42,6 @@ def test__rectangular_mapper(): (5.0, 5.0), 1.0e-4 ) assert mapper.source_plane_mesh_grid.origin == pytest.approx((0.5, 0.5), 1.0e-4) - print(mapper.mapping_matrix) assert mapper.mapping_matrix == pytest.approx( np.array( [