diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 789dac386..54dc7a84f 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -43,6 +43,7 @@ from .inversion.pixelization.mappers.rectangular import MapperRectangular from .inversion.pixelization.mappers.delaunay import MapperDelaunay from .inversion.pixelization.mappers.voronoi import MapperVoronoi +from .inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh from .inversion.pixelization.mesh.abstract import AbstractMesh from .inversion.inversion.imaging.mapping import InversionImagingMapping @@ -75,6 +76,7 @@ from .operators.over_sampling.over_sampler import OverSampler from .structures.grids.irregular_2d import Grid2DIrregular from .structures.mesh.rectangular_2d import Mesh2DRectangular +from .structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform from .structures.mesh.voronoi_2d import Mesh2DVoronoi from .structures.mesh.delaunay_2d import Mesh2DDelaunay from .structures.arrays.kernel_2d import Kernel2D diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index b851089bb..d5d146a46 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -252,6 +252,9 @@ def w_tilde(self): indexes=indexes.astype("int"), lengths=lengths.astype("int"), noise_map_value=self.noise_map[0], + noise_map=self.noise_map, + psf=self.psf, + mask=self.mask, ) @classmethod diff --git a/autoarray/dataset/imaging/w_tilde.py b/autoarray/dataset/imaging/w_tilde.py index 985caeeb8..2cb5a5778 100644 --- a/autoarray/dataset/imaging/w_tilde.py +++ b/autoarray/dataset/imaging/w_tilde.py @@ -2,8 +2,12 @@ import logging import numpy as np +from autoconf import cached_property + from autoarray.dataset.abstract.w_tilde import AbstractWTilde +from autoarray.inversion.inversion.imaging import inversion_imaging_util + logger = logging.getLogger(__name__) @@ -13,6 +17,9 @@ def __init__( curvature_preload: np.ndarray, indexes: np.ndim, lengths: np.ndarray, + noise_map: np.ndarray, + psf: np.ndarray, + mask: np.ndarray, noise_map_value: float, ): """ @@ -44,3 +51,56 @@ def __init__( self.indexes = indexes self.lengths = lengths + self.noise_map = noise_map + self.psf = psf + self.mask = mask + + @cached_property + def w_matrix(self): + """ + The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF + convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the + curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the + PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging + datasets. + + The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, + making it impossible to store in memory and its use in linear algebra calculations extremely. The method + `w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is + advised `w_tilde` and this method are only used for testing. + + Parameters + ---------- + noise_map_native + The two dimensional masked noise-map of values which w_tilde is computed from. + kernel_native + The two dimensional PSF kernel that w_tilde encodes the convolution of. + native_index_for_slim_index + An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. + + Returns + ------- + ndarray + A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of + the curvature matrix. + """ + + return inversion_imaging_util.w_tilde_curvature_imaging_from( + noise_map_native=np.array(self.noise_map.native.array).astype("float64"), + kernel_native=np.array(self.psf.native.array).astype("float64"), + native_index_for_slim_index=np.array( + self.mask.derive_indexes.native_for_slim + ).astype("int"), + ) + + @cached_property + def psf_operator_matrix_dense(self): + + return inversion_imaging_util.psf_operator_matrix_dense_from( + kernel_native=np.array(self.psf.native.array).astype("float64"), + native_index_for_slim_index=np.array( + self.mask.derive_indexes.native_for_slim + ).astype("int"), + native_shape=self.noise_map.shape_native, + correlate=False, + ) diff --git a/autoarray/fixtures.py b/autoarray/fixtures.py index 75957966f..6c1f14264 100644 --- a/autoarray/fixtures.py +++ b/autoarray/fixtures.py @@ -421,7 +421,7 @@ def make_rectangular_mapper_7x7_3x3(): adapt_data=aa.Array2D.ones(shape_native=(3, 3), pixel_scales=0.1), ) - return aa.MapperRectangular( + return aa.MapperRectangularUniform( mapper_grids=mapper_grids, border_relocator=make_border_relocator_2d_7x7(), regularization=make_regularization_constant(), diff --git a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py index 699361638..285fb5eeb 100644 --- a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py +++ b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py @@ -1,10 +1,69 @@ +from scipy.signal import convolve2d +import jax.numpy as jnp import numpy as np from typing import Tuple from autoarray import numba_util +from scipy.signal import correlate2d + +import numpy as np + + +def psf_operator_matrix_dense_from( + kernel_native: np.ndarray, + native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels + native_shape: tuple[int, int], + correlate: bool = True, +) -> np.ndarray: + """ + Construct a dense PSF operator W (N_pix x N_pix) that maps masked image pixels to masked image pixels. + + Parameters + ---------- + kernel_native : (Ky, Kx) PSF kernel. + native_index_for_slim_index : (N_pix, 2) array of int + Native (y, x) coords for each masked pixel. + native_shape : (Ny, Nx) + Native 2D image shape. + correlate : bool, default True + If True, use correlation convention (no kernel flip). + If False, use convolution convention (flip kernel). + + Returns + ------- + W : ndarray, shape (N_pix, N_pix) + Dense PSF operator. + """ + Ky, Kx = kernel_native.shape + ph, pw = Ky // 2, Kx // 2 + Ny, Nx = native_shape + N_pix = native_index_for_slim_index.shape[0] + + ker = kernel_native if correlate else kernel_native[::-1, ::-1] + + # Padded index grid: -1 everywhere, slim index where masked + index_padded = -np.ones((Ny + 2 * ph, Nx + 2 * pw), dtype=np.int64) + for p, (y, x) in enumerate(native_index_for_slim_index): + index_padded[y + ph, x + pw] = p + + # Neighborhood offsets + dy = np.arange(Ky) - ph + dx = np.arange(Kx) - pw + + W = np.zeros((N_pix, N_pix), dtype=float) + + for i, (y, x) in enumerate(native_index_for_slim_index): + yp = y + ph + xp = x + pw + for j, dy_ in enumerate(dy): + for k, dx_ in enumerate(dx): + neigh = index_padded[yp + dy_, xp + dx_] + if neigh >= 0: + W[i, neigh] += ker[j, k] + + return W -@numba_util.jit() def w_tilde_data_imaging_from( image_native: np.ndarray, noise_map_native: np.ndarray, @@ -44,32 +103,35 @@ def w_tilde_data_imaging_from( efficient calculation of the data vector. """ - kernel_shift_y = -(kernel_native.shape[1] // 2) - kernel_shift_x = -(kernel_native.shape[0] // 2) - - image_pixels = len(native_index_for_slim_index) - - w_tilde_data = np.zeros((image_pixels,)) + # 1) weight map = image / noise^2 (safe where noise==0) + weight_map = jnp.where( + noise_map_native > 0.0, image_native / (noise_map_native**2), 0.0 + ) - weight_map_native = image_native / noise_map_native**2.0 + Ky, Kx = kernel_native.shape + ph, pw = Ky // 2, Kx // 2 - for ip0 in range(image_pixels): - ip0_y, ip0_x = native_index_for_slim_index[ip0] - - value = 0.0 + # 2) pad so neighbourhood gathers never go OOB + padded = jnp.pad( + weight_map, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0 + ) - for k0_y in range(kernel_native.shape[0]): - for k0_x in range(kernel_native.shape[1]): - weight_value = weight_map_native[ - ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x - ] + # 3) build broadcasted neighbourhood indices for all requested pixels + # shift pixel coords into the padded frame + ys = native_index_for_slim_index[:, 0] + ph # (N,) + xs = native_index_for_slim_index[:, 1] + pw # (N,) - if not np.isnan(weight_value): - value += kernel_native[k0_y, k0_x] * weight_value + # kernel-relative offsets + dy = jnp.arange(Ky) - ph # (Ky,) + dx = jnp.arange(Kx) - pw # (Kx,) - w_tilde_data[ip0] = value + # broadcast to (N, Ky, Kx) + Y = ys[:, None, None] + dy[None, :, None] + X = xs[:, None, None] + dx[None, None, :] - return w_tilde_data + # 4) gather patches and correlate (no kernel flip) + patches = padded[Y, X] # (N, Ky, Kx) + return jnp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,) @numba_util.jit() diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index b725adb6c..1cb7a2f6d 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import Dict, List, Optional, Union @@ -75,12 +76,10 @@ def __init__( @cached_property def w_tilde_data(self): return inversion_imaging_util.w_tilde_data_imaging_from( - image_native=np.array(self.data.native.array).astype("float"), - noise_map_native=np.array(self.noise_map.native.array).astype("float"), - kernel_native=np.array(self.psf.native.array).astype("float"), - native_index_for_slim_index=np.array( - self.data.mask.derive_indexes.native_for_slim - ).astype("int"), + image_native=self.data.native.array, + noise_map_native=self.noise_map.native.array, + kernel_native=self.psf.native.array, + native_index_for_slim_index=self.data.mask.derive_indexes.native_for_slim, ) @property @@ -92,32 +91,33 @@ def _data_vector_mapper(self) -> np.ndarray: This method is used to compute part of the `data_vector` if there are also linear function list objects in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`. """ - - if not self.has(cls=AbstractMapper): - return None - - data_vector = np.zeros(self.total_params) - - mapper_list = self.cls_list_from(cls=AbstractMapper) - mapper_param_range = self.param_range_list_from(cls=AbstractMapper) - - for mapper_index, mapper in enumerate(mapper_list): - data_vector_mapper = ( - inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=self.w_tilde_data, - data_to_pix_unique=np.array( - mapper.unique_mappings.data_to_pix_unique - ), - data_weights=np.array(mapper.unique_mappings.data_weights), - pix_lengths=np.array(mapper.unique_mappings.pix_lengths), - pix_pixels=mapper.params, - ) - ) - param_range = mapper_param_range[mapper_index] - - data_vector[param_range[0] : param_range[1],] = data_vector_mapper - - return data_vector + return jnp.dot(self.mapping_matrix.T, self.w_tilde_data) + + # if not self.has(cls=AbstractMapper): + # return None + # + # data_vector = np.zeros(self.total_params) + # + # mapper_list = self.cls_list_from(cls=AbstractMapper) + # mapper_param_range = self.param_range_list_from(cls=AbstractMapper) + # + # for mapper_index, mapper in enumerate(mapper_list): + # data_vector_mapper = ( + # inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( + # w_tilde_data=np.array(self.w_tilde_data), + # data_to_pix_unique=np.array( + # mapper.unique_mappings.data_to_pix_unique + # ), + # data_weights=np.array(mapper.unique_mappings.data_weights), + # pix_lengths=np.array(mapper.unique_mappings.pix_lengths), + # pix_pixels=mapper.params, + # ) + # ) + # param_range = mapper_param_range[mapper_index] + # + # data_vector[param_range[0] : param_range[1],] = data_vector_mapper + # + # return data_vector @cached_property def data_vector(self) -> np.ndarray: @@ -148,16 +148,17 @@ def _data_vector_x1_mapper(self) -> np.ndarray: This method computes the `data_vector` whenthere is a single mapper object in the `Inversion`, which circumvents `np.concatenate` for speed up. """ - - linear_obj = self.linear_obj_list[0] - - return inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=self.w_tilde_data, - data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, - data_weights=linear_obj.unique_mappings.data_weights, - pix_lengths=linear_obj.unique_mappings.pix_lengths, - pix_pixels=linear_obj.params, - ) + return self._data_vector_mapper + + # linear_obj = self.linear_obj_list[0] + # + # return inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( + # w_tilde_data=self.w_tilde_data, + # data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, + # data_weights=linear_obj.unique_mappings.data_weights, + # pix_lengths=linear_obj.unique_mappings.pix_lengths, + # pix_pixels=linear_obj.params, + # ) @property def _data_vector_multi_mapper(self) -> np.ndarray: @@ -172,7 +173,7 @@ def _data_vector_multi_mapper(self) -> np.ndarray: return np.concatenate( [ inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=self.w_tilde_data, + w_tilde_data=np.array(self.w_tilde_data), data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, pix_lengths=linear_obj.unique_mappings.pix_lengths, @@ -195,7 +196,7 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray: separation of functions enables the `data_vector` to be preloaded in certain circumstances. """ - data_vector = self._data_vector_mapper + data_vector = np.array(self._data_vector_mapper) linear_func_param_range = self.param_range_list_from( cls=AbstractLinearObjFuncList @@ -359,7 +360,12 @@ def _curvature_matrix_x1_mapper(self) -> np.ndarray: This method computes the `curvature_matrix` when there is a single mapper object in the `Inversion`, which circumvents `block_diag` for speed up. """ - return self._curvature_matrix_mapper_diag + + return inversion_util.curvature_matrix_via_w_tilde_from( + w_tilde=self.w_tilde.w_matrix, mapping_matrix=self.mapping_matrix + ) + + # return self._curvature_matrix_mapper_diag @property def _curvature_matrix_multi_mapper(self) -> np.ndarray: @@ -511,7 +517,10 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: for linear_obj in self.linear_obj_list: reconstruction = reconstruction_dict[linear_obj] - if isinstance(linear_obj, AbstractMapper): + if isinstance(linear_obj, AbstractMapper) and self.has( + cls=AbstractLinearObjFuncList + ): + mapped_reconstructed_image = inversion_util.mapped_reconstructed_data_via_image_to_pix_unique_from( data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, @@ -527,7 +536,24 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: values=mapped_reconstructed_image, mask=self.mask ) + elif isinstance(linear_obj, AbstractMapper) and not self.has( + cls=AbstractLinearObjFuncList + ): + + mapped_reconstructed_image = ( + inversion_util.mapped_reconstructed_data_via_w_tilde_from( + w_tilde=self.w_tilde.psf_operator_matrix_dense, + mapping_matrix=self.mapping_matrix, + reconstruction=reconstruction, + ) + ) + + mapped_reconstructed_image = Array2D( + values=mapped_reconstructed_image, mask=self.mask + ) + else: + operated_mapping_matrix = self.linear_func_operated_mapping_matrix_dict[ linear_obj ] diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index b2eb4f26c..6f9d7d580 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -4,8 +4,6 @@ from typing import List, Optional, Type -from autoconf import conf - from autoarray.inversion.inversion.settings import SettingsInversion from autoarray import numba_util @@ -36,8 +34,7 @@ def curvature_matrix_via_w_tilde_from( ndarray The curvature matrix `F` (see Warren & Dye 2003). """ - - return np.dot(mapping_matrix.T, np.dot(w_tilde, mapping_matrix)) + return jnp.dot(mapping_matrix.T, jnp.dot(w_tilde, mapping_matrix)) def curvature_matrix_with_added_to_diag_from( @@ -190,6 +187,34 @@ def mapped_reconstructed_data_via_mapping_matrix_from( return jnp.dot(mapping_matrix, reconstruction) +def mapped_reconstructed_data_via_w_tilde_from( + w_tilde: np.ndarray, mapping_matrix: np.ndarray, reconstruction: np.ndarray +) -> np.ndarray: + """ + Returns the reconstructed data vector from the unblurred mapping matrix `M`, + the reconstruction vector `s`, and the PSF convolution operator `w_tilde`. + + Equivalent to: + reconstructed = (W @ M) @ s + = W @ (M @ s) + + Parameters + ---------- + w_tilde + Array of shape [image_pixels, image_pixels], the PSF convolution operator. + mapping_matrix + Array of shape [image_pixels, source_pixels], unblurred mapping matrix. + reconstruction + Array of shape [source_pixels], solution vector. + + Returns + ------- + ndarray + The reconstructed data vector of shape [image_pixels]. + """ + return w_tilde @ (mapping_matrix @ reconstruction) + + def reconstruction_positive_negative_from( data_vector: np.ndarray, curvature_reg_matrix: np.ndarray, diff --git a/autoarray/inversion/pixelization/mappers/factory.py b/autoarray/inversion/pixelization/mappers/factory.py index 965213b03..689f35011 100644 --- a/autoarray/inversion/pixelization/mappers/factory.py +++ b/autoarray/inversion/pixelization/mappers/factory.py @@ -4,6 +4,7 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.inversion.regularization.abstract import AbstractRegularization from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular +from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay from autoarray.structures.mesh.voronoi_2d import Mesh2DVoronoi @@ -39,10 +40,19 @@ def mapper_from( from autoarray.inversion.pixelization.mappers.rectangular import ( MapperRectangular, ) + from autoarray.inversion.pixelization.mappers.rectangular_uniform import ( + MapperRectangularUniform, + ) from autoarray.inversion.pixelization.mappers.delaunay import MapperDelaunay from autoarray.inversion.pixelization.mappers.voronoi import MapperVoronoi - if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): + if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangularUniform): + return MapperRectangularUniform( + mapper_grids=mapper_grids, + border_relocator=border_relocator, + regularization=regularization, + ) + elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): return MapperRectangular( mapper_grids=mapper_grids, border_relocator=border_relocator, diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index c3e8d470a..81b562bd2 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -105,6 +105,180 @@ def data_slim_to_pixelization_unique_from( return data_to_pix_unique, data_weights, pix_lengths +import jax +import jax.numpy as jnp + +from functools import partial + + +def forward_interp(xp, yp, x): + return jax.vmap(jnp.interp, in_axes=(1, 1, None, None, None))(x, xp, yp, 0, 1).T + + +def reverse_interp(xp, yp, x): + return jax.vmap(jnp.interp, in_axes=(1, None, 1))(x, xp, yp).T + + +def create_transforms(traced_points): + # make functions that takes a set of traced points + # stored in a (N, 2) array and return functions that + # take in (N, 2) arrays and transform the values into + # the range (0, 1) and the inverse transform + N = traced_points.shape[0] # // 2 + t = jnp.arange(1, N + 1) / (N + 1) + + sort_points = jnp.sort(traced_points, axis=0) # [::2] + + transform = partial(forward_interp, sort_points, t) + inv_transform = partial(reverse_interp, t, sort_points) + return transform, inv_transform + + +def adaptive_rectangular_transformed_grid_from(source_plane_data_grid, grid): + mu = source_plane_data_grid.mean(axis=0) + scale = source_plane_data_grid.std(axis=0).min() + source_grid_scaled = (source_plane_data_grid - mu) / scale + + transform, inv_transform = create_transforms(source_grid_scaled) + + def inv_full(U): + return inv_transform(U) * scale + mu + + return inv_full(grid) + + +def adaptive_rectangular_areas_from(source_grid_size, source_plane_data_grid): + + pixel_edges_1d = jnp.linspace(0, 1, source_grid_size + 1) + + mu = source_plane_data_grid.mean(axis=0) + scale = source_plane_data_grid.std(axis=0).min() + source_grid_scaled = (source_plane_data_grid - mu) / scale + + transform, inv_transform = create_transforms(source_grid_scaled) + + def inv_full(U): + return inv_transform(U) * scale + mu + + pixel_edges = inv_full(jnp.stack([pixel_edges_1d, pixel_edges_1d]).T) + pixel_lengths = jnp.diff(pixel_edges, axis=0).squeeze() # shape (N_source, 2) + + dy = pixel_lengths[:, 0] + dx = pixel_lengths[:, 1] + + return jnp.outer(dy, dx).flatten() + + +def adaptive_rectangular_mappings_weights_via_interpolation_from( + source_grid_size: int, + source_plane_data_grid, + source_plane_data_grid_over_sampled, +): + """ + Compute bilinear interpolation indices and weights for mapping an oversampled + source-plane grid onto a regular rectangular pixelization. + + This function takes a set of irregularly-sampled source-plane coordinates and + builds an adaptive mapping onto a `source_grid_size x source_grid_size` rectangular + pixelization using bilinear interpolation. The interpolation is expressed as: + + f(x, y) ≈ w_bl * f(ix_down, iy_down) + + w_br * f(ix_up, iy_down) + + w_tl * f(ix_down, iy_up) + + w_tr * f(ix_up, iy_up) + + where `(ix_down, ix_up, iy_down, iy_up)` are the integer grid coordinates + surrounding the continuous position `(x, y)`. + + Steps performed: + 1. Normalize the source-plane grid by subtracting its mean and dividing by + the minimum axis standard deviation (to balance scaling). + 2. Construct forward/inverse transforms which map the grid into the unit square [0,1]^2. + 3. Transform the oversampled source-plane grid into [0,1]^2, then scale it + to index space `[0, source_grid_size)`. + 4. Compute floor/ceil along x and y axes to find the enclosing rectangular cell. + 5. Build the four corner indices: bottom-left (bl), bottom-right (br), + top-left (tl), and top-right (tr). + 6. Flatten the 2D indices into 1D indices suitable for scatter operations, + with a flipped row-major convention: row = source_grid_size - i, col = j. + 7. Compute bilinear interpolation weights (`w_bl, w_br, w_tl, w_tr`). + 8. Return arrays of flattened indices and weights of shape `(N, 4)`, where + `N` is the number of oversampled coordinates. + + Parameters + ---------- + source_grid_size : int + The number of pixels along one dimension of the rectangular pixelization. + The grid is square: (source_grid_size x source_grid_size). + source_plane_data_grid : (M, 2) ndarray + The base source-plane coordinates, used to define normalization and transforms. + source_plane_data_grid_over_sampled : (N, 2) ndarray + Oversampled source-plane coordinates to be interpolated onto the rectangular grid. + + Returns + ------- + flat_indices : (N, 4) int ndarray + The flattened indices of the four neighboring pixel corners for each oversampled point. + Order: [bl, br, tl, tr]. + weights : (N, 4) float ndarray + The bilinear interpolation weights for each of the four neighboring pixels. + Order: [w_bl, w_br, w_tl, w_tr]. + """ + + # --- Step 1. Normalize grid --- + mu = source_plane_data_grid.mean(axis=0) + scale = source_plane_data_grid.std(axis=0).min() + source_grid_scaled = (source_plane_data_grid - mu) / scale + + # --- Step 2. Build transforms --- + transform, inv_transform = create_transforms(source_grid_scaled) + + # --- Step 3. Transform oversampled grid into index space --- + grid_over_sampled_scaled = (source_plane_data_grid_over_sampled - mu) / scale + grid_over_sampled_transformed = transform(grid_over_sampled_scaled) + grid_over_index = source_grid_size * grid_over_sampled_transformed + + # --- Step 4. Floor/ceil indices --- + ix_down = jnp.floor(grid_over_index[:, 0]) + ix_up = jnp.ceil(grid_over_index[:, 0]) + iy_down = jnp.floor(grid_over_index[:, 1]) + iy_up = jnp.ceil(grid_over_index[:, 1]) + + # --- Step 5. Four corners --- + idx_tl = jnp.stack([ix_up, iy_down], axis=1) + idx_tr = jnp.stack([ix_up, iy_up], axis=1) + idx_br = jnp.stack([ix_down, iy_up], axis=1) + idx_bl = jnp.stack([ix_down, iy_down], axis=1) + + # --- Step 6. Flatten indices --- + def flatten(idx, n): + row = n - idx[:, 0] + col = idx[:, 1] + return row * n + col + + flat_tl = flatten(idx_tl, source_grid_size) + flat_tr = flatten(idx_tr, source_grid_size) + flat_bl = flatten(idx_bl, source_grid_size) + flat_br = flatten(idx_br, source_grid_size) + + flat_indices = jnp.stack([flat_tl, flat_tr, flat_bl, flat_br], axis=1).astype( + "int64" + ) + + # --- Step 7. Bilinear interpolation weights --- + t_row = (grid_over_index[:, 0] - ix_down) / (ix_up - ix_down + 1e-12) + t_col = (grid_over_index[:, 1] - iy_down) / (iy_up - iy_down + 1e-12) + + # Weights + w_tl = (1 - t_row) * (1 - t_col) + w_tr = (1 - t_row) * t_col + w_bl = t_row * (1 - t_col) + w_br = t_row * t_col + weights = jnp.stack([w_tl, w_tr, w_bl, w_br], axis=1) + + return flat_indices, weights + + def rectangular_mappings_weights_via_interpolation_from( shape_native: Tuple[int, int], source_plane_data_grid: jnp.ndarray, diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index 79d2ad774..f9f398c69 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -97,30 +97,22 @@ def pix_sub_weights(self) -> PixSubWeights: dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index` are equal to 1.0. """ - # from autoarray.geometry import geometry_util - # - # mappings = geometry_util.grid_pixel_indexes_2d_slim_from( - # grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled), - # shape_native=self.source_plane_mesh_grid.shape_native, - # pixel_scales=self.source_plane_mesh_grid.pixel_scales, - # origin=self.source_plane_mesh_grid.origin, - # ).astype("int") - # - # mappings = mappings.reshape((len(mappings), 1)) - # - # return PixSubWeights( - # mappings=mappings, - # sizes=np.ones(len(mappings), dtype="int"), - # weights=np.ones( - # (len(self.source_plane_data_grid.over_sampled), 1), dtype="int" - # ), + + # mappings, weights = ( + # mapper_util.rectangular_mappings_weights_via_interpolation_from( + # shape_native=self.shape_native, + # source_plane_mesh_grid=self.source_plane_mesh_grid.array, + # source_plane_data_grid=self.source_plane_data_grid.over_sampled, + # ) # ) mappings, weights = ( - mapper_util.rectangular_mappings_weights_via_interpolation_from( - shape_native=self.shape_native, - source_plane_mesh_grid=self.source_plane_mesh_grid.array, - source_plane_data_grid=self.source_plane_data_grid.over_sampled, + mapper_util.adaptive_rectangular_mappings_weights_via_interpolation_from( + source_grid_size=self.shape_native[0], + source_plane_data_grid=self.source_plane_data_grid.array, + source_plane_data_grid_over_sampled=jnp.array( + self.source_plane_data_grid.over_sampled + ), ) ) @@ -129,3 +121,36 @@ def pix_sub_weights(self) -> PixSubWeights: sizes=4 * jnp.ones(len(mappings), dtype="int"), weights=weights, ) + + @cached_property + def areas_transformed(self): + """ + A class packing the ndarrays describing the neighbors of every pixel in the rectangular pixelization (see + `Neighbors` for a complete description of the neighboring scheme). + + The neighbors of a rectangular pixelization are computed by exploiting the uniform and symmetric nature of the + rectangular grid, as described in the method `mesh_util.rectangular_neighbors_from`. + """ + return mapper_util.adaptive_rectangular_areas_from( + source_grid_size=self.shape_native[0], + source_plane_data_grid=self.source_plane_data_grid.array, + ) + + @cached_property + def edges_transformed(self): + """ + A class packing the ndarrays describing the neighbors of every pixel in the rectangular pixelization (see + `Neighbors` for a complete description of the neighboring scheme). + + The neighbors of a rectangular pixelization are computed by exploiting the uniform and symmetric nature of the + rectangular grid, as described in the method `mesh_util.rectangular_neighbors_from`. + """ + + # edges defined in 0 -> 1 space, there is one more edge than pixel centers on each side + edges = jnp.linspace(0, 1, self.shape_native[0] + 1) + edges_reshaped = jnp.stack([edges, edges]).T + + return mapper_util.adaptive_rectangular_transformed_grid_from( + source_plane_data_grid=self.source_plane_data_grid.array, + grid=edges_reshaped, + ) diff --git a/autoarray/inversion/pixelization/mappers/rectangular_uniform.py b/autoarray/inversion/pixelization/mappers/rectangular_uniform.py new file mode 100644 index 000000000..3c58813cb --- /dev/null +++ b/autoarray/inversion/pixelization/mappers/rectangular_uniform.py @@ -0,0 +1,106 @@ +import jax.numpy as jnp + +from autoconf import cached_property + +from autoarray.inversion.pixelization.mappers.rectangular import MapperRectangular +from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights + +from autoarray.inversion.pixelization.mappers import mapper_util + + +class MapperRectangularUniform(MapperRectangular): + """ + To understand a `Mapper` one must be familiar `Mesh` objects and the `mesh` and `pixelization` packages, where + the four grids grouped in a `MapperGrids` object are explained (`image_plane_data_grid`, `source_plane_data_grid`, + `image_plane_mesh_grid`,`source_plane_mesh_grid`) + + If you are unfamliar withe above objects, read through the docstrings of the `pixelization`, `mesh` and + `mapper_grids` packages. + + A `Mapper` determines the mappings between the masked data grid's pixels (`image_plane_data_grid` and + `source_plane_data_grid`) and the mesh's pixels (`image_plane_mesh_grid` and `source_plane_mesh_grid`). + + The 1D Indexing of each grid is identical in the `data` and `source` frames (e.g. the transformation does not + change the indexing, such that `source_plane_data_grid[0]` corresponds to the transformed value + of `image_plane_data_grid[0]` and so on). + + A mapper therefore only needs to determine the index mappings between the `grid_slim` and `mesh_grid`, + noting that associations are made by pairing `source_plane_mesh_grid` with `source_plane_data_grid`. + + Mappings are represented in the 2D ndarray `pix_indexes_for_sub_slim_index`, whereby the index of + a pixel on the `mesh_grid` maps to the index of a pixel on the `grid_slim` as follows: + + - pix_indexes_for_sub_slim_index[0, 0] = 0: the data's 1st sub-pixel maps to the mesh's 1st pixel. + - pix_indexes_for_sub_slim_index[1, 0] = 3: the data's 2nd sub-pixel maps to the mesh's 4th pixel. + - pix_indexes_for_sub_slim_index[2, 0] = 1: the data's 3rd sub-pixel maps to the mesh's 2nd pixel. + + The second dimension of this array (where all three examples above are 0) is used for cases where a + single pixel on the `grid_slim` maps to multiple pixels on the `mesh_grid`. For example, a + `Delaunay` triangulation, where every `grid_slim` pixel maps to three Delaunay pixels (the corners of the + triangles) with varying interpolation weights . + + For a `Rectangular` mesh every pixel in the masked data maps to only one pixel, thus the second + dimension of `pix_indexes_for_sub_slim_index` is always of size 1. + + The mapper allows us to create a mapping matrix, which is a matrix representing the mapping between every + unmasked data pixel annd the pixels of a mesh. This matrix is the basis of performing an `Inversion`, + which reconstructs the data using the `source_plane_mesh_grid`. + + Parameters + ---------- + mapper_grids + An object containing the data grid and mesh grid in both the data-frame and source-frame used by the + mapper to map data-points to linear object parameters. + regularization + The regularization scheme which may be applied to this linear object in order to smooth its solution, + which for a mapper smooths neighboring pixels on the mesh. + """ + + @cached_property + def pix_sub_weights(self) -> PixSubWeights: + """ + Computes the following three quantities describing the mappings between of every sub-pixel in the masked data + and pixel in the `Rectangular` mesh. + + - `pix_indexes_for_sub_slim_index`: the mapping of every data pixel (given its `sub_slim_index`) + to mesh pixels (given their `pix_indexes`). + + - `pix_sizes_for_sub_slim_index`: the number of mappings of every data pixel to mesh pixels. + + - `pix_weights_for_sub_slim_index`: the interpolation weights of every data pixel's mesh + pixel mapping + + These are packaged into the class `PixSubWeights` with attributes `mappings`, `sizes` and `weights`. + + The `sub_slim_index` refers to the masked data sub-pixels and `pix_indexes` the mesh pixel indexes, + for example: + + - `pix_indexes_for_sub_slim_index[0, 0] = 2`: The data's first (index 0) sub-pixel maps to the Rectangular + mesh's third (index 2) pixel. + + - `pix_indexes_for_sub_slim_index[2, 0] = 4`: The data's third (index 2) sub-pixel maps to the Rectangular + mesh's fifth (index 4) pixel. + + The second dimension of the array `pix_indexes_for_sub_slim_index`, which is 0 in both examples above, is used + for cases where a data pixel maps to more than one mesh pixel (for example a `Delaunay` triangulation + where each data pixel maps to 3 Delaunay triangles with interpolation weights). The weights of multiple mappings + are stored in the array `pix_weights_for_sub_slim_index`. + + For a Rectangular pixelization each data sub-pixel maps to a single mesh pixel, thus the second + dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index` + are equal to 1.0. + """ + + mappings, weights = ( + mapper_util.rectangular_mappings_weights_via_interpolation_from( + shape_native=self.shape_native, + source_plane_mesh_grid=self.source_plane_mesh_grid.array, + source_plane_data_grid=self.source_plane_data_grid.over_sampled, + ) + ) + + return PixSubWeights( + mappings=mappings, + sizes=4 * jnp.ones(len(mappings), dtype="int"), + weights=weights, + ) diff --git a/autoarray/inversion/pixelization/mesh/__init__.py b/autoarray/inversion/pixelization/mesh/__init__.py index a14f53f69..28f35f116 100644 --- a/autoarray/inversion/pixelization/mesh/__init__.py +++ b/autoarray/inversion/pixelization/mesh/__init__.py @@ -1,4 +1,5 @@ from .abstract import AbstractMesh as Mesh from .rectangular import Rectangular +from .rectangular_uniform import RectangularUniform from .voronoi import Voronoi from .delaunay import Delaunay diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index 420d9c2ab..514fb5e1d 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import List, Tuple, Union @@ -301,6 +302,58 @@ def rectangular_central_neighbors( return neighbors, neighbors_sizes +def rectangular_edges_from(shape_native, pixel_scales): + """ + Returns all pixel edges for a rectangular grid as a JAX array of shape (N, 4, 2, 2), + where N = Ny * Nx. Edge order per pixel matches the user's convention: + + 0: (x1, y0) -> (x1, y1) + 1: (x1, y1) -> (x0, y1) + 2: (x0, y1) -> (x0, y0) + 3: (x0, y0) -> (x1, y0) + + Notes + ----- + - x is flipped so that the leftmost column has the largest +x (e.g. centres start at x=+1.0). + - y increases upward (top row has the most negative y when dy>0). + """ + Ny, Nx = shape_native + dy, dx = pixel_scales + + # Grid edge coordinates. Flip x so leftmost column has largest +x, matching your convention. + x_edges = ((jnp.arange(Nx + 1) - Nx / 2) * dx)[::-1] + y_edges = (jnp.arange(Ny + 1) - Ny / 2) * dy + + edges_list = [] + + # Pixel order: row-major (y outer, x inner). If you want column-major, swap the loop nesting. + for j in range(Ny): + for i in range(Nx): + y0, y1 = y_edges[i], y_edges[i + 1] + xa, xb = ( + x_edges[j], + x_edges[j + 1], + ) # xa is the "right" boundary in your convention + + # Edge order to match your pytest: [(xa,y0)->(xa,y1), (xa,y1)->(xb,y1), (xb,y1)->(xb,y0), (xb,y0)->(xa,y0)] + e0 = jnp.array( + [[xa, y0], [xa, y1]] + ) # "top" in your test (vertical at x=xa) + e1 = jnp.array( + [[xa, y1], [xb, y1]] + ) # "right" in your test (horizontal at y=y1) + e2 = jnp.array( + [[xb, y1], [xb, y0]] + ) # "bottom" in your test (vertical at x=xb) + e3 = jnp.array( + [[xb, y0], [xa, y0]] + ) # "left" in your test (horizontal at y=y0) + + edges_list.append(jnp.stack([e0, e1, e2, e3], axis=0)) + + return jnp.stack(edges_list, axis=0) + + def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]: """ Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization, diff --git a/autoarray/inversion/pixelization/mesh/rectangular.py b/autoarray/inversion/pixelization/mesh/rectangular.py index e95a05ec6..318ef0436 100644 --- a/autoarray/inversion/pixelization/mesh/rectangular.py +++ b/autoarray/inversion/pixelization/mesh/rectangular.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple from autoarray.structures.grids.uniform_2d import Grid2D diff --git a/autoarray/inversion/pixelization/mesh/rectangular_uniform.py b/autoarray/inversion/pixelization/mesh/rectangular_uniform.py new file mode 100644 index 000000000..acefb8288 --- /dev/null +++ b/autoarray/inversion/pixelization/mesh/rectangular_uniform.py @@ -0,0 +1,32 @@ +from autoarray.inversion.pixelization.mesh.rectangular import Rectangular + +from typing import Optional + + +from autoarray.structures.grids.uniform_2d import Grid2D +from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform + + +class RectangularUniform(Rectangular): + + def mesh_grid_from( + self, + source_plane_data_grid: Optional[Grid2D] = None, + source_plane_mesh_grid: Optional[Grid2D] = None, + ) -> Mesh2DRectangularUniform: + """ + Return the rectangular `source_plane_mesh_grid` as a `Mesh2DRectangular` object, which provides additional + functionality for perform operatons that exploit the geometry of a rectangular pixelization. + + Parameters + ---------- + source_plane_data_grid + The (y,x) grid of coordinates over which the rectangular pixelization is overlaid, where this grid may have + had exterior pixels relocated to its edge via the border. + source_plane_mesh_grid + Not used for a rectangular pixelization, because the pixelization grid in the `source` frame is computed + by overlaying the `source_plane_data_grid` with the rectangular pixelization. + """ + return Mesh2DRectangularUniform.overlay_grid( + shape_native=self.shape, grid=source_plane_data_grid.over_sampled + ) diff --git a/autoarray/plot/mat_plot/two_d.py b/autoarray/plot/mat_plot/two_d.py index 5e38bd896..340366151 100644 --- a/autoarray/plot/mat_plot/two_d.py +++ b/autoarray/plot/mat_plot/two_d.py @@ -564,18 +564,79 @@ def _plot_rectangular_mapper( else: ax = self.setup_subplot(aspect=aspect_inv) + shape_native = mapper.source_plane_mesh_grid.shape_native + if pixel_values is not None: - self.plot_array( - array=pixel_values, - visuals_2d=visuals_2d, - auto_labels=auto_labels, - bypass=True, + + from autoarray.inversion.pixelization.mappers.rectangular_uniform import ( + MapperRectangularUniform, + ) + from autoarray.inversion.pixelization.mappers.rectangular import ( + MapperRectangular, ) - self.axis.set(extent=extent, grid=mapper.source_plane_mesh_grid) + if isinstance(mapper, MapperRectangularUniform): - self.yticks.set(min_value=extent[2], max_value=extent[3], units=self.units) - self.xticks.set(min_value=extent[0], max_value=extent[1], units=self.units) + self.plot_array( + array=pixel_values, + visuals_2d=visuals_2d, + auto_labels=auto_labels, + bypass=True, + ) + + else: + + norm = self.cmap.norm_from( + array=pixel_values.array, use_log10=self.use_log10 + ) + + edges_transformed = mapper.edges_transformed + + edges_transformed_dense = np.moveaxis( + np.stack(np.meshgrid(*edges_transformed.T)), 0, 2 + ) + + plt.pcolormesh( + edges_transformed_dense[..., 0], + edges_transformed_dense[..., 1], + pixel_values.array.reshape(shape_native), + shading="flat", + norm=norm, + cmap=self.cmap.cmap, + ) + + if self.colorbar is not False: + + cb = self.colorbar.set( + units=self.units, + ax=ax, + norm=norm, + cb_unit=auto_labels.cb_unit, + use_log10=self.use_log10, + ) + self.colorbar_tickparams.set(cb=cb) + + extent_axis = self.axis.config_dict.get("extent") + + if extent_axis is None: + extent_axis = extent + + self.axis.set(extent=extent_axis) + + self.tickparams.set() + self.yticks.set( + min_value=extent_axis[2], + max_value=extent_axis[3], + units=self.units, + pixels=shape_native[0], + ) + + self.xticks.set( + min_value=extent_axis[0], + max_value=extent_axis[1], + units=self.units, + pixels=shape_native[1], + ) if not isinstance(self.text, list): self.text.set() @@ -587,13 +648,12 @@ def _plot_rectangular_mapper( else: [annotate.set() for annotate in self.annotate] - self.grid_plot.plot_rectangular_grid_lines( - extent=mapper.source_plane_mesh_grid.geometry.extent, - shape_native=mapper.shape_native, - ) + # self.grid_plot.plot_rectangular_grid_lines( + # extent=mapper.source_plane_mesh_grid.geometry.extent, + # shape_native=mapper.shape_native, + # ) self.title.set(auto_title=auto_labels.title) - self.tickparams.set() self.ylabel.set() self.xlabel.set() diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 3ae0bb1cc..b2992278a 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +from jax import lax import numpy as np from pathlib import Path from typing import List, Tuple, Union diff --git a/autoarray/structures/mesh/abstract_2d.py b/autoarray/structures/mesh/abstract_2d.py index 55fb7b9b8..cf630443e 100644 --- a/autoarray/structures/mesh/abstract_2d.py +++ b/autoarray/structures/mesh/abstract_2d.py @@ -5,6 +5,15 @@ class Abstract2DMesh(Structure): + + @property + def slim(self) -> "Structure": + raise NotImplementedError() + + @property + def native(self) -> Structure: + raise NotImplementedError() + @property def parameters(self) -> int: return self.pixels diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index e8c7a8a82..64f51379b 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -3,25 +3,22 @@ from typing import List, Optional, Tuple +from autoconf import cached_property + from autoarray import type as ty from autoarray.inversion.linear_obj.neighbors import Neighbors -from autoarray.inversion.pixelization.mesh import mesh_util from autoarray.mask.mask_2d import Mask2D from autoarray.structures.abstract_structure import Structure from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.structures.grids import grid_2d_util + from autoarray.structures.mesh.abstract_2d import Abstract2DMesh -from autoconf import cached_property +from autoarray.inversion.pixelization.mappers import mapper_util +from autoarray.inversion.pixelization.mesh import mesh_util +from autoarray.structures.grids import grid_2d_util -class Mesh2DRectangular(Abstract2DMesh): - @property - def slim(self) -> "Structure": - raise NotImplementedError() - @property - def native(self) -> Structure: - raise NotImplementedError() +class Mesh2DRectangular(Abstract2DMesh): def __init__( self, @@ -112,7 +109,7 @@ def overlay_grid( origin=origin, ) - return Mesh2DRectangular( + return cls( values=grid_slim, shape_native=shape_native, pixel_scales=pixel_scales, diff --git a/autoarray/structures/mesh/rectangular_2d_uniform.py b/autoarray/structures/mesh/rectangular_2d_uniform.py new file mode 100644 index 000000000..688b7c53b --- /dev/null +++ b/autoarray/structures/mesh/rectangular_2d_uniform.py @@ -0,0 +1,6 @@ +from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular + + +class Mesh2DRectangularUniform(Mesh2DRectangular): + + pass diff --git a/autoarray/structures/mesh/triangulation_2d.py b/autoarray/structures/mesh/triangulation_2d.py index 3ddb95c01..3eeaa249d 100644 --- a/autoarray/structures/mesh/triangulation_2d.py +++ b/autoarray/structures/mesh/triangulation_2d.py @@ -14,14 +14,6 @@ class Abstract2DMeshTriangulation(Abstract2DMesh): - @property - def slim(self) -> "Structure": - raise NotImplementedError() - - @property - def native(self) -> Structure: - raise NotImplementedError() - def __init__( self, values: Union[np.ndarray, List], diff --git a/test_autoarray/inversion/inversion/imaging/test_imaging.py b/test_autoarray/inversion/inversion/imaging/test_imaging.py index bd54c35f1..46bfe55ac 100644 --- a/test_autoarray/inversion/inversion/imaging/test_imaging.py +++ b/test_autoarray/inversion/inversion/imaging/test_imaging.py @@ -144,7 +144,13 @@ def test__w_tilde_checks_noise_map_and_raises_exception_if_preloads_dont_match_n grid = aa.Grid2D.from_mask(mask=mask) w_tilde = WTildeImaging( - curvature_preload=None, indexes=None, lengths=None, noise_map_value=2.0 + curvature_preload=None, + indexes=None, + lengths=None, + noise_map_value=2.0, + noise_map=None, + psf=None, + mask=mask, ) with pytest.raises(exc.InversionException): diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index f96731aa9..120a1da9c 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -183,7 +183,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): psf = kernel - pixelization = aa.mesh.Rectangular(shape=(20, 20)) + pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) # TODO : Use pytest.parameterize @@ -243,7 +243,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): data_vector_via_w_tilde = ( aa.util.inversion_imaging.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=w_tilde_data, + w_tilde_data=np.array(w_tilde_data), data_to_pix_unique=data_to_pix_unique.astype("int"), data_weights=data_weights, pix_lengths=pix_lengths.astype("int"), @@ -266,7 +266,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): psf = kernel - pixelization = aa.mesh.Rectangular(shape=(20, 20)) + pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) mapper_grids = pixelization.mapper_grids_from( mask=mask, @@ -313,7 +313,7 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): psf = kernel - pixelization = aa.mesh.Rectangular(shape=(20, 20)) + pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) for sub_size in range(1, 2, 3): grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=sub_size) diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 8880b544c..5f1a918cf 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -115,8 +115,8 @@ def test__curvature_matrix__via_w_tilde__identical_to_mapping(): grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=1) - mesh_0 = aa.mesh.Rectangular(shape=(3, 3)) - mesh_1 = aa.mesh.Rectangular(shape=(4, 4)) + mesh_0 = aa.mesh.RectangularUniform(shape=(3, 3)) + mesh_1 = aa.mesh.RectangularUniform(shape=(4, 4)) mapper_grids_0 = mesh_0.mapper_grids_from( mask=mask, @@ -426,17 +426,6 @@ def test__data_subtracted_dict(): assert (inversion.data_subtracted_dict[linear_obj_1] == 2.0 * np.ones(3)).all() -def test__reconstruction_raises_exception_for_linalg_error(): - # noinspection PyTypeChecker - inversion = aa.m.MockInversion( - data_vector=np.ones(3), curvature_reg_matrix=np.ones((3, 3)) - ) - - with pytest.raises(exc.InversionException): - # noinspection PyStatementEffect - inversion.reconstruction - - def test__regularization_term(): reconstruction = np.array([1.0, 1.0, 1.0]) diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index ed3e6fa53..b8cdaf31f 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -69,11 +69,14 @@ def test__inversion_imaging__via_mapper( settings=aa.SettingsInversion(use_w_tilde=False), ) - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingMapping) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( - 7.257175708246, 1.0e-4 + 7.2571757082, 1.0e-4 ) + # assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( + # 4.609440907938719, 1.0e-4 + # ) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) inversion = aa.Inversion( @@ -82,7 +85,7 @@ def test__inversion_imaging__via_mapper( settings=aa.SettingsInversion(use_w_tilde=True), ) - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingWTilde) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( 7.257175708246, 1.0e-4 @@ -233,7 +236,7 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper( ) assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) - assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingMapping) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( 7.2571757082469945, 1.0e-4 @@ -254,7 +257,7 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper( ) assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) - assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingWTilde) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( 7.2571757082469945, 1.0e-4 @@ -492,7 +495,7 @@ def test__inversion_interferometer__via_mapper( settings=aa.SettingsInversion(use_w_tilde=False), ) - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionInterferometerMapping) assert inversion.mapped_reconstructed_data == pytest.approx( 1.0 + 0.0j * np.ones(shape=(7,)), 1.0e-4 diff --git a/test_autoarray/inversion/inversion/test_inversion_util.py b/test_autoarray/inversion/inversion/test_inversion_util.py index 86b722812..cd4132f1d 100644 --- a/test_autoarray/inversion/inversion/test_inversion_util.py +++ b/test_autoarray/inversion/inversion/test_inversion_util.py @@ -87,22 +87,6 @@ def test__reconstruction_positive_negative_from(): assert reconstruction == pytest.approx(np.array([1.0, -1.0, 3.0]), 1.0e-4) -def test__reconstruction_positive_negative_from__check_solution_raises_error_cause_all_values_identical(): - data_vector = np.array([1.0, 1.0, 1.0]) - - curvature_reg_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - - # reconstruction = np.array([1.0, 1.0, 1.0]) - - with pytest.raises(aa.exc.InversionException): - aa.util.inversion.reconstruction_positive_negative_from( - data_vector=data_vector, - curvature_reg_matrix=curvature_reg_matrix, - mapper_param_range_list=[[0, 3]], - force_check_reconstruction=True, - ) - - def test__mapped_reconstructed_data_via_mapping_matrix_from(): mapping_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) diff --git a/test_autoarray/inversion/pixelization/mappers/test_factory.py b/test_autoarray/inversion/pixelization/mappers/test_factory.py index 31a24ba2a..d8f68507d 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_factory.py +++ b/test_autoarray/inversion/pixelization/mappers/test_factory.py @@ -24,7 +24,7 @@ def test__rectangular_mapper(): grid.over_sampled[0, 0] = -2.0 grid.over_sampled[0, 1] = 2.0 - mesh = aa.mesh.Rectangular(shape=(3, 3)) + mesh = aa.mesh.RectangularUniform(shape=(3, 3)) mapper_grids = mesh.mapper_grids_from( mask=mask, @@ -35,7 +35,7 @@ def test__rectangular_mapper(): mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) - assert isinstance(mapper, aa.MapperRectangular) + assert isinstance(mapper, aa.MapperRectangularUniform) assert mapper.image_plane_mesh_grid == None assert mapper.source_plane_mesh_grid.geometry.shape_native_scaled == pytest.approx( diff --git a/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py b/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py index a5b41a15f..fdbaabe88 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py +++ b/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py @@ -13,62 +13,6 @@ def make_five_pixels(): return np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]) -def _test__sub_slim_indexes_for_pix_index(): - pix_indexes_for_sub_slim_index = np.array( - [[0, 4], [1, 4], [2, 4], [0, 4], [1, 4], [3, 4], [0, 4], [3, 4]] - ).astype("int") - pix_pixels = 5 - pix_weights_for_sub_slim_index = np.array( - [ - [0.1, 0.9], - [0.2, 0.8], - [0.3, 0.7], - [0.4, 0.6], - [0.5, 0.5], - [0.6, 0.4], - [0.7, 0.3], - [0.8, 0.2], - ] - ) - - ( - sub_slim_indexes_for_pix_index, - sub_slim_sizes_for_pix_index, - sub_slim_weights_for_pix_index, - ) = aa.util.mapper.sub_slim_indexes_for_pix_index( - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - pix_pixels=pix_pixels, - ) - - assert ( - sub_slim_indexes_for_pix_index - == np.array( - [ - [0, 3, 6, -1, -1, -1, -1, -1], - [1, 4, -1, -1, -1, -1, -1, -1], - [2, -1, -1, -1, -1, -1, -1, -1], - [5, 7, -1, -1, -1, -1, -1, -1], - [0, 1, 2, 3, 4, 5, 6, 7], - ] - ) - ).all() - assert (sub_slim_sizes_for_pix_index == np.array([3, 2, 1, 2, 8])).all() - - assert ( - sub_slim_weights_for_pix_index - == np.array( - [ - [0.1, 0.4, 0.7, -1, -1, -1, -1, -1], - [0.2, 0.5, -1, -1, -1, -1, -1, -1], - [0.3, -1, -1, -1, -1, -1, -1, -1], - [0.6, 0.8, -1, -1, -1, -1, -1, -1], - [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2], - ] - ) - ).all() - - def test__mapping_matrix(three_pixels, five_pixels): pix_indexes_for_sub_slim_index = np.array([[0], [1], [2]]) slim_index_for_sub_slim_index = np.array([0, 1, 2]) diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index ef8123b99..f80c67cde 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import autoarray as aa @@ -21,7 +22,7 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): over_sample_size=1, ) - mesh_grid = aa.Mesh2DRectangular.overlay_grid( + mesh_grid = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid.over_sampled ) @@ -46,7 +47,7 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7): - mesh_grid = aa.Mesh2DRectangular.overlay_grid( + mesh_grid = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid_2d_sub_1_7x7.over_sampled ) @@ -72,3 +73,73 @@ def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7): ) assert (pixel_signals == pixel_signals_util).all() + + +def test__areas_transformed(mask_2d_7x7): + + grid = aa.Grid2DIrregular( + [ + [-1.5, -1.5], + [-1.5, 0.0], + [-1.5, 1.5], + [0.0, -1.5], + [0.0, 0.0], + [0.0, 1.5], + [1.5, -1.5], + [1.5, 0.0], + [1.5, 1.5], + ], + ) + + mesh = aa.Mesh2DRectangularUniform.overlay_grid( + shape_native=(3, 3), grid=grid, buffer=1e-8 + ) + + mapper_grids = aa.MapperGrids( + mask=mask_2d_7x7, + source_plane_data_grid=grid, + source_plane_mesh_grid=mesh, + ) + + mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) + + assert mapper.areas_transformed[4] == pytest.approx( + 4.0, + abs=1e-8, + ) + + +def test__edges_transformed(mask_2d_7x7): + + grid = aa.Grid2DIrregular( + [ + [-1.5, -1.5], + [-1.5, 0.0], + [-1.5, 1.5], + [0.0, -1.5], + [0.0, 0.0], + [0.0, 1.5], + [1.5, -1.5], + [1.5, 0.0], + [1.5, 1.5], + ], + ) + + mesh = aa.Mesh2DRectangularUniform.overlay_grid( + shape_native=(3, 3), grid=grid, buffer=1e-8 + ) + + mapper_grids = aa.MapperGrids( + mask=mask_2d_7x7, + source_plane_data_grid=grid, + source_plane_mesh_grid=mesh, + ) + + mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) + + assert mapper.edges_transformed[4] == pytest.approx( + np.array( + [1.5, 1.5], # left + ), + abs=1e-8, + ) diff --git a/test_autoarray/structures/mesh/test_rectangular.py b/test_autoarray/structures/mesh/test_rectangular.py index a63489733..f4838ae3d 100644 --- a/test_autoarray/structures/mesh/test_rectangular.py +++ b/test_autoarray/structures/mesh/test_rectangular.py @@ -12,7 +12,7 @@ def test__neighbors__compare_to_mesh_util(): # I8 I 9I10I11I # I12I13I14I15I - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(7, 5), grid=np.zeros((2, 2)), buffer=1e-8 ) @@ -39,7 +39,7 @@ def test__edge_pixel_list(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -61,7 +61,7 @@ def test__shape_native_and_pixel_scales(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -82,7 +82,7 @@ def test__shape_native_and_pixel_scales(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(5, 4), grid=grid, buffer=1e-8 ) @@ -91,7 +91,7 @@ def test__shape_native_and_pixel_scales(): grid = np.array([[2.0, 1.0], [4.0, 3.0], [6.0, 5.0], [8.0, 7.0]]) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -114,7 +114,7 @@ def test__pixel_centres__3x3_grid__pixel_centres(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -148,7 +148,7 @@ def test__pixel_centres__3x3_grid__pixel_centres(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(4, 3), grid=grid, buffer=1e-8 ) @@ -179,7 +179,7 @@ def test__interpolated_array_from(): pixel_scales=1.0, ) - grid_rectangular = aa.Mesh2DRectangular( + grid_rectangular = aa.Mesh2DRectangularUniform( values=grid, shape_native=grid.shape_native, pixel_scales=grid.pixel_scales )