diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 99d021359..16d8e7952 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - uses: actions/cache@v2 + - uses: actions/cache@v3 id: cache-pip with: path: ~/.cache/pip @@ -36,9 +36,9 @@ jobs: pip3 install setuptools pip3 install wheel pip3 install pytest coverage pytest-cov - pip3 install -r PyAutoConf/requirements.txt - pip3 install -r PyAutoArray/requirements.txt - pip3 install -r PyAutoArray/optional_requirements.txt + pip install ./PyAutoConf + pip install ./PyAutoArray + pip install ./PyAutoArray[optional] cd PyAutoArray/autoarray/util/nn/src/nn export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/runner/work/PyAutoArray/PyAutoArray/PyAutoArray/autoarray/util/nn/src/nn diff --git a/autoarray/__init__.py b/autoarray/__init__.py index 1a6d79c53..3aa146ef7 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -1,9 +1,13 @@ +from autoconf.dictable import register_parser +from autoconf import conf + +conf.instance.register(__file__) + from . import exc from . import type from . import util from . import fixtures from . import mock as m -from .numba_util import profile_func from .dataset import preprocess from .dataset.abstract.dataset import AbstractDataset from .dataset.abstract.w_tilde import AbstractWTilde @@ -38,13 +42,13 @@ from .inversion.pixelization.mappers.rectangular import MapperRectangular from .inversion.pixelization.mappers.delaunay import MapperDelaunay from .inversion.pixelization.mappers.voronoi import MapperVoronoi +from .inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh from .inversion.pixelization.mesh.abstract import AbstractMesh from .inversion.inversion.imaging.mapping import InversionImagingMapping from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping -from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops from .inversion.linear_obj.linear_obj import LinearObj from .inversion.linear_obj.func_list import AbstractLinearObjFuncList from .mask.derive.indexes_2d import DeriveIndexes2D @@ -52,6 +56,7 @@ from .mask.derive.mask_2d import DeriveMask2D from .mask.derive.grid_1d import DeriveGrid1D from .mask.derive.grid_2d import DeriveGrid2D +from .mask.derive.zoom_2d import Zoom2D from .mask.mask_1d import Mask1D from .mask.mask_2d import Mask2D from .operators.transformer import TransformerDFT @@ -60,14 +65,17 @@ from .operators.contour import Grid2DContour from .layout.layout import Layout1D from .layout.layout import Layout2D +from .preloads import Preloads from .structures.arrays.uniform_1d import Array1D from .structures.arrays.uniform_2d import Array2D +from .structures.arrays.rgb import Array2DRGB from .structures.arrays.irregular import ArrayIrregular from .structures.grids.uniform_1d import Grid1D from .structures.grids.uniform_2d import Grid2D from .operators.over_sampling.over_sampler import OverSampler from .structures.grids.irregular_2d import Grid2DIrregular from .structures.mesh.rectangular_2d import Mesh2DRectangular +from .structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform from .structures.mesh.voronoi_2d import Mesh2DVoronoi from .structures.mesh.delaunay_2d import Mesh2DDelaunay from .structures.arrays.kernel_2d import Kernel2D diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py index ded8c5452..44febb46b 100644 --- a/autoarray/abstract_ndarray.py +++ b/autoarray/abstract_ndarray.py @@ -282,7 +282,7 @@ def output_to_fits(self, file_path: str, overwrite: bool = False): If a file already exists at the path, if overwrite=True it is overwritten else an error is raised. """ output_to_fits( - values=self.native.array, + values=self.native.array.astype("float"), file_path=file_path, overwrite=overwrite, header_dict=self.mask.header_dict, diff --git a/autoarray/config/general.yaml b/autoarray/config/general.yaml index a6cb8d5dc..7b9112e81 100644 --- a/autoarray/config/general.yaml +++ b/autoarray/config/general.yaml @@ -1,3 +1,5 @@ +jax: + use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy. fits: flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9. inversion: diff --git a/autoarray/config/visualize/general.yaml b/autoarray/config/visualize/general.yaml index 8bbf29e06..b6cecf50f 100644 --- a/autoarray/config/visualize/general.yaml +++ b/autoarray/config/visualize/general.yaml @@ -4,7 +4,6 @@ general: log10_min_value: 1.0e-4 # If negative values are being plotted on a log10 scale, values below this value are rounded up to it (e.g. to remove negative values). log10_max_value: 1.0e99 # If positive values are being plotted on a log10 scale, values above this value are rounded down to it (e.g. to prevent white blobs). zoom_around_mask: true # If True, plots of data structures with a mask automatically zoom in the masked region. - disable_zoom_for_fits: true # If True, the zoom-in around the masked region is disabled when outputting .fits files, which is useful to retain the same dimensions as the input data. inversion: reconstruction_vmax_factor: 0.5 total_mappings_pixels : 8 # The number of source pixels used when plotting the subplot_mappings of a pixelization. diff --git a/autoarray/config/visualize/include.yaml b/autoarray/config/visualize/include.yaml deleted file mode 100644 index f010d9f8b..000000000 --- a/autoarray/config/visualize/include.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# The `include` settings customize every feature that appears on plotted images by default (e.g. a mask, the -# coordinate system's origin, etc.). - -# For example, if `include_2d -> mask:true`, the mask will not be plotted on any applicable figure by default. - -include_1d: - mask: false # Include a Mask ? - origin: false # Include the (x,) origin of the data's coordinate system ? -include_2d: - border: false # Include the border of the mask (all pixels on the outside of the mask) ? - grid: false # Include the data's 2D grid of (y,x) coordinates ? - mapper_image_plane_mesh_grid: false # For an Inversion, include the pixel centres computed in the image-plane / data frame? - mapper_source_plane_data_grid: false # For an Inversion, include the centres of the image-plane grid mapped to the source-plane / frame in source-plane figures? - mapper_source_plane_mesh_grid: false # For an Inversion, include the centres of the mesh pixels in the source-plane / source-plane? - mask: true # Include a mask ? - origin: false # Include the (y,x) origin of the data's coordinate system ? - positions: true # Include (y,x) coordinates specified via `Visuals2d.positions` ? - parallel_overscan: true - serial_overscan: true - serial_prescan: true \ No newline at end of file diff --git a/autoarray/dataset/grids.py b/autoarray/dataset/grids.py index c460bf820..d97fd3f4d 100644 --- a/autoarray/dataset/grids.py +++ b/autoarray/dataset/grids.py @@ -3,11 +3,11 @@ from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.arrays.kernel_2d import Kernel2D -from autoarray.structures.grids.uniform_1d import Grid1D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.inversion.pixelization.border_relocator import BorderRelocator -from autoconf import cached_property + +from autoarray import exc class GridsDataset: @@ -24,7 +24,7 @@ def __init__( The following grids are contained: - - `uniform`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data, + - `lp`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data, which is used for most normal calculations (e.g. evaluating the amount of light that falls in an pixel from a light profile). @@ -60,72 +60,30 @@ def __init__( self.over_sample_size_pixelization = over_sample_size_pixelization self.psf = psf - @cached_property - def lp(self) -> Union[Grid1D, Grid2D]: - """ - Returns the grid of (y,x) Cartesian coordinates at the centre of every pixel in the masked data, which is used - to perform most normal calculations (e.g. evaluating the amount of light that falls in an pixel from a light - profile). - - This grid is computed based on the mask, in particular its pixel-scale and sub-grid size. - - Returns - ------- - The (y,x) coordinates of every pixel in the data. - """ - return Grid2D.from_mask( + self.lp = Grid2D.from_mask( mask=self.mask, over_sample_size=self.over_sample_size_lp, ) + self.lp.over_sampled - @cached_property - def pixelization(self) -> Grid2D: - """ - Returns the grid of (y,x) Cartesian coordinates of every pixel in the masked data which is used - specifically for calculations associated with a pixelization. - - The `pixelization` grid is identical to the `uniform` grid but often uses a different over sampling scheme - when performing calculations. For example, the pixelization may benefit from using a a higher `sub_size` than - the `uniform` grid, in order to better prevent aliasing effects. - - This grid is computed based on the mask, in particular its pixel-scale and sub-grid size. - - Returns - ------- - The (y,x) coordinates of every pixel in the data, used for pixelization / inversion calculations. - """ - return Grid2D.from_mask( + self.pixelization = Grid2D.from_mask( mask=self.mask, over_sample_size=self.over_sample_size_pixelization, ) - - @cached_property - def blurring(self) -> Optional[Grid2D]: - """ - Returns a blurring-grid from a mask and the 2D shape of the PSF kernel. - - A blurring grid consists of all pixels that are masked (and therefore have their values set to (0.0, 0.0)), - but are close enough to the unmasked pixels that their values will be convolved into the unmasked those pixels. - This when computing images from light profile objects. - - This uses lazy allocation such that the calculation is only performed when the blurring grid is used, ensuring - efficient set up of the `Imaging` class. - - Returns - ------- - The blurring grid given the mask of the imaging data. - """ + self.pixelization.over_sampled if self.psf is None: - return None - - return self.lp.blurring_grid_via_kernel_shape_from( - kernel_shape_native=self.psf.shape_native, - ) - - @cached_property - def border_relocator(self) -> BorderRelocator: - return BorderRelocator( + self.blurring = None + else: + try: + self.blurring = self.lp.blurring_grid_via_kernel_shape_from( + kernel_shape_native=self.psf.shape_native, + ) + self.blurring.over_sampled + except exc.MaskException: + self.blurring = None + + self.border_relocator = BorderRelocator( mask=self.mask, sub_size=self.over_sample_size_pixelization ) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 8d84ee1b8..b69e7dd6d 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -15,7 +15,7 @@ from autoarray import exc from autoarray.operators.over_sampling import over_sample_util -from autoarray.inversion.inversion.imaging import inversion_imaging_util +from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util logger = logging.getLogger(__name__) @@ -159,9 +159,26 @@ def __init__( """ ) - if psf is not None and use_normalized_psf: + if psf is not None: + + if not data.mask.is_all_false: + + image_mask = data.mask + blurring_mask = data.mask.derive_mask.blurring_from( + kernel_shape_native=psf.shape_native + ) + + else: + + image_mask = None + blurring_mask = None + psf = Kernel2D.no_mask( - values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True + values=psf.native._array, + pixel_scales=psf.pixel_scales, + normalize=use_normalized_psf, + image_mask=image_mask, + blurring_mask=blurring_mask, ) self.psf = psf @@ -170,9 +187,7 @@ def __init__( if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: raise exc.KernelException("Kernel2D Kernel2D must be odd") - @cached_property - def grids(self): - return GridsDataset( + self.grids = GridsDataset( mask=self.data.mask, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, @@ -203,10 +218,12 @@ def w_tilde(self): curvature_preload, indexes, lengths, - ) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from( + ) = inversion_imaging_numba_util.w_tilde_curvature_preload_imaging_from( noise_map_native=np.array(self.noise_map.native.array).astype("float64"), kernel_native=np.array(self.psf.native.array).astype("float64"), - native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"), + native_index_for_slim_index=np.array( + self.mask.derive_indexes.native_for_slim + ).astype("int"), ) return WTildeImaging( @@ -214,6 +231,9 @@ def w_tilde(self): indexes=indexes.astype("int"), lengths=lengths.astype("int"), noise_map_value=self.noise_map[0], + noise_map=self.noise_map, + psf=self.psf, + mask=self.mask, ) @classmethod @@ -409,12 +429,12 @@ def apply_noise_scaling( """ if signal_to_noise_value is None: - noise_map = np.array(self.noise_map.native.array) + noise_map = self.noise_map.native noise_map[mask.array == False] = noise_value else: noise_map = np.where( mask == False, - np.median(self.data.native.array[mask.derive_mask.edge == False]) + np.median(self.data.native[mask.derive_mask.edge == False]) / signal_to_noise_value, self.noise_map.native.array, ) @@ -488,7 +508,7 @@ def apply_over_sampling( passed into the calculations performed in the `inversion` module. """ - return Imaging( + dataset = Imaging( data=self.data, noise_map=self.noise_map, psf=self.psf, @@ -499,6 +519,8 @@ def apply_over_sampling( check_noise_map=False, ) + return dataset + def output_to_fits( self, data_path: Union[Path, str], diff --git a/autoarray/dataset/imaging/simulator.py b/autoarray/dataset/imaging/simulator.py index 449a3d4b6..576dc6017 100644 --- a/autoarray/dataset/imaging/simulator.py +++ b/autoarray/dataset/imaging/simulator.py @@ -151,7 +151,7 @@ def via_image_from( pixel_scales=image.pixel_scales, ) - if np.isnan(noise_map).any(): + if np.isnan(noise_map.array).any(): raise exc.DatasetException( "The noise-map has NaN values in it. This suggests your exposure time and / or" "background sky levels are too low, creating signal counts at or close to 0.0." @@ -161,7 +161,9 @@ def via_image_from( image = image - background_sky_map mask = Mask2D.all_false( - shape_native=image.shape_native, pixel_scales=image.pixel_scales + shape_native=image.shape_native, + pixel_scales=image.pixel_scales, + origin=image.origin, ) image = Array2D(values=image, mask=mask) diff --git a/autoarray/dataset/imaging/w_tilde.py b/autoarray/dataset/imaging/w_tilde.py index 985caeeb8..95ebe5291 100644 --- a/autoarray/dataset/imaging/w_tilde.py +++ b/autoarray/dataset/imaging/w_tilde.py @@ -1,9 +1,13 @@ -import copy import logging import numpy as np +from autoconf import cached_property + from autoarray.dataset.abstract.w_tilde import AbstractWTilde +from autoarray.inversion.inversion.imaging import inversion_imaging_util +from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util + logger = logging.getLogger(__name__) @@ -13,6 +17,9 @@ def __init__( curvature_preload: np.ndarray, indexes: np.ndim, lengths: np.ndarray, + noise_map: np.ndarray, + psf: np.ndarray, + mask: np.ndarray, noise_map_value: float, ): """ @@ -44,3 +51,56 @@ def __init__( self.indexes = indexes self.lengths = lengths + self.noise_map = noise_map + self.psf = psf + self.mask = mask + + @cached_property + def w_matrix(self): + """ + The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF + convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the + curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the + PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging + datasets. + + The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, + making it impossible to store in memory and its use in linear algebra calculations extremely. The method + `w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is + advised `w_tilde` and this method are only used for testing. + + Parameters + ---------- + noise_map_native + The two dimensional masked noise-map of values which w_tilde is computed from. + kernel_native + The two dimensional PSF kernel that w_tilde encodes the convolution of. + native_index_for_slim_index + An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. + + Returns + ------- + ndarray + A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of + the curvature matrix. + """ + + return inversion_imaging_numba_util.w_tilde_curvature_imaging_from( + noise_map_native=np.array(self.noise_map.native.array).astype("float64"), + kernel_native=np.array(self.psf.native.array).astype("float64"), + native_index_for_slim_index=np.array( + self.mask.derive_indexes.native_for_slim + ).astype("int"), + ) + + @cached_property + def psf_operator_matrix_dense(self): + + return inversion_imaging_util.psf_operator_matrix_dense_from( + kernel_native=np.array(self.psf.native.array).astype("float64"), + native_index_for_slim_index=np.array( + self.mask.derive_indexes.native_for_slim + ).astype("int"), + native_shape=self.noise_map.shape_native, + correlate=False, + ) diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 06892ea68..7966ef74d 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -1,4 +1,3 @@ -from astropy.io import fits import logging import numpy as np from pathlib import Path @@ -10,7 +9,7 @@ from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer from autoarray.dataset.grids import GridsDataset from autoarray.operators.transformer import TransformerNUFFT - +from autoarray.mask.mask_2d import Mask2D from autoarray.structures.visibilities import Visibilities from autoarray.structures.visibilities import VisibilitiesNoiseMap @@ -25,8 +24,9 @@ def __init__( data: Visibilities, noise_map: VisibilitiesNoiseMap, uv_wavelengths: np.ndarray, - real_space_mask, + real_space_mask: Mask2D, transformer_class=TransformerNUFFT, + dft_preload_transform: bool = True, preprocessing_directory=None, ): """ @@ -73,6 +73,9 @@ def __init__( transformer_class The class of the Fourier Transform which maps images from real space to Fourier space visibilities and the uv-plane. + dft_preload_transform + If True, precomputes and stores the cosine and sine terms for the Fourier transform. + This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets). """ self.real_space_mask = real_space_mask @@ -86,7 +89,9 @@ def __init__( self.uv_wavelengths = uv_wavelengths self.transformer = transformer_class( - uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask + uv_wavelengths=uv_wavelengths, + real_space_mask=real_space_mask, + preload_transform=dft_preload_transform, ) self.preprocessing_directory = ( @@ -95,9 +100,7 @@ def __init__( else None ) - @cached_property - def grids(self): - return GridsDataset( + self.grids = GridsDataset( mask=self.real_space_mask, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, @@ -114,6 +117,7 @@ def from_fits( noise_map_hdu=0, uv_wavelengths_hdu=0, transformer_class=TransformerNUFFT, + dft_preload_transform: bool = True, ): """ Factory for loading the interferometer data_type from .fits files, as well as computing properties like the @@ -139,9 +143,13 @@ def from_fits( noise_map=noise_map, uv_wavelengths=uv_wavelengths, transformer_class=transformer_class, + dft_preload_transform=dft_preload_transform, ) def w_tilde_preprocessing(self): + + from astropy.io import fits + if self.preprocessing_directory.is_dir(): filename = "{}/curvature_preload.fits".format(self.preprocessing_directory) @@ -193,7 +201,9 @@ def w_tilde(self): w_matrix = inversion_interferometer_util.w_tilde_via_preload_from( w_tilde_preload=curvature_preload, - native_index_for_slim_index=self.real_space_mask.derive_indexes.native_for_slim, + native_index_for_slim_index=np.array( + self.real_space_mask.derive_indexes.native_for_slim + ).astype("int"), ) dirty_image = self.transformer.image_from( @@ -205,7 +215,7 @@ def w_tilde(self): return WTildeInterferometer( w_matrix=w_matrix, curvature_preload=curvature_preload, - dirty_image=dirty_image, + dirty_image=np.array(dirty_image.array), real_space_mask=self.real_space_mask, noise_map_value=self.noise_map[0], ) diff --git a/autoarray/dataset/plot/imaging_plotters.py b/autoarray/dataset/plot/imaging_plotters.py index 8fed6f601..e0c0772e3 100644 --- a/autoarray/dataset/plot/imaging_plotters.py +++ b/autoarray/dataset/plot/imaging_plotters.py @@ -2,21 +2,18 @@ from typing import Callable, Optional from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.dataset.imaging.dataset import Imaging -class ImagingPlotterMeta(Plotter): +class ImagingPlotterMeta(AbstractPlotter): def __init__( self, dataset: Imaging, - get_visuals_2d: Callable, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -27,29 +24,21 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Imaging` and plotted via the visuals object. Parameters ---------- dataset The imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `Imaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Imaging` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.dataset = dataset - self.get_visuals_2d = get_visuals_2d @property def imaging(self): @@ -91,21 +80,21 @@ def figures_2d( if data: self.mat_plot_2d.plot_array( array=self.dataset.data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title=title_str or f" Data", filename="data"), ) if noise_map: self.mat_plot_2d.plot_array( array=self.dataset.noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title_str or f"Noise-Map", filename="noise_map"), ) if psf: self.mat_plot_2d.plot_array( array=self.dataset.psf, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Point Spread Function", filename="psf", @@ -116,7 +105,7 @@ def figures_2d( if signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.dataset.signal_to_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Signal-To-Noise Map", filename="signal_to_noise_map", @@ -127,7 +116,7 @@ def figures_2d( if over_sample_size_lp: self.mat_plot_2d.plot_array( array=self.dataset.grids.over_sample_size_lp, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Over Sample Size (Light Profiles)", filename="over_sample_size_lp", @@ -138,7 +127,7 @@ def figures_2d( if over_sample_size_pixelization: self.mat_plot_2d.plot_array( array=self.dataset.grids.over_sample_size_pixelization, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title=title_str or f"Over Sample Size (Pixelization)", filename="over_sample_size_pixelization", @@ -227,13 +216,12 @@ def subplot_dataset(self): self.mat_plot_2d.use_log10 = use_log10_original -class ImagingPlotter(Plotter): +class ImagingPlotter(AbstractPlotter): def __init__( self, dataset: Imaging, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -244,8 +232,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Imaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Imaging` and plotted via the visuals object. Parameters ---------- @@ -255,27 +242,18 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Imaging` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.dataset = dataset self._imaging_meta_plotter = ImagingPlotterMeta( dataset=self.dataset, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) self.figures_2d = self._imaging_meta_plotter.figures_2d self.subplot = self._imaging_meta_plotter.subplot self.subplot_dataset = self._imaging_meta_plotter.subplot_dataset - - def get_visuals_2d(self): - return self.get_2d.via_mask_from(mask=self.dataset.mask) diff --git a/autoarray/dataset/plot/interferometer_plotters.py b/autoarray/dataset/plot/interferometer_plotters.py index 944ba51bd..e69f53d38 100644 --- a/autoarray/dataset/plot/interferometer_plotters.py +++ b/autoarray/dataset/plot/interferometer_plotters.py @@ -1,298 +1,284 @@ -from autoarray.plot.abstract_plotters import Plotter -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D -from autoarray.plot.mat_plot.one_d import MatPlot1D -from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.auto_labels import AutoLabels -from autoarray.dataset.interferometer.dataset import Interferometer -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - - -class InterferometerPlotter(Plotter): - def __init__( - self, - dataset: Interferometer, - mat_plot_1d: MatPlot1D = MatPlot1D(), - visuals_1d: Visuals1D = Visuals1D(), - include_1d: Include1D = Include1D(), - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), - ): - """ - Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and - `imshow()` and other matplotlib functions which customize the plot's appearance. - - The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, - the settings passed to every matplotlib function called are those specified in - the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to - customize the figure's appearance. - - Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `LightProfile` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. - - Parameters - ---------- - dataset - The interferometer dataset the plotter plots. - mat_plot_1d - Contains objects which wrap the matplotlib function calls that make 1D plots. - visuals_1d - Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `Interferometer` are extracted and plotted as visuals for 1D plots. - mat_plot_2d - Contains objects which wrap the matplotlib function calls that make 2D plots. - visuals_2d - Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Interferometer` are extracted and plotted as visuals for 2D plots. - """ - self.dataset = dataset - - super().__init__( - mat_plot_1d=mat_plot_1d, - include_1d=include_1d, - visuals_1d=visuals_1d, - mat_plot_2d=mat_plot_2d, - include_2d=include_2d, - visuals_2d=visuals_2d, - ) - - @property - def interferometer(self): - return self.dataset - - def get_visuals_2d_real_space(self): - return self.get_2d.via_mask_from(mask=self.dataset.real_space_mask) - - def figures_2d( - self, - data: bool = False, - noise_map: bool = False, - u_wavelengths: bool = False, - v_wavelengths: bool = False, - uv_wavelengths: bool = False, - amplitudes_vs_uv_distances: bool = False, - phases_vs_uv_distances: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - ): - """ - Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is plotted. - - Parameters - ---------- - data - Whether to make a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to make a 2D plot (via `scatter`) of the noise-map. - u_wavelengths - Whether to make a 1D plot (via `plot`) of the u-wavelengths. - v_wavelengths - Whether to make a 1D plot (via `plot`) of the v-wavelengths. - amplitudes_vs_uv_distances - Whether to make a 1D plot (via `plot`) of the amplitudes versis the uv distances. - phases_vs_uv_distances - Whether to make a 1D plot (via `plot`) of the phases versis the uv distances. - dirty_image - Whether to make a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty noise map. - dirty_signal_to_noise_map - Whether to make a 2D plot (via `imshow`) of the dirty signal-to-noise map. - """ - - if data: - self.mat_plot_2d.plot_grid( - grid=self.dataset.data.in_grid, - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels(title="Visibilities", filename="data"), - ) - - if noise_map: - self.mat_plot_2d.plot_grid( - grid=self.dataset.data.in_grid, - visuals_2d=self.visuals_2d, - color_array=self.dataset.noise_map.real, - auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), - ) - - if u_wavelengths: - self.mat_plot_1d.plot_yx( - y=self.dataset.uv_wavelengths[:, 0], - x=None, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="U-Wavelengths", - filename="u_wavelengths", - ylabel="Wavelengths", - ), - plot_axis_type_override="linear", - ) - - if v_wavelengths: - self.mat_plot_1d.plot_yx( - y=self.dataset.uv_wavelengths[:, 1], - x=None, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="V-Wavelengths", - filename="v_wavelengths", - ylabel="Wavelengths", - ), - plot_axis_type_override="linear", - ) - - if uv_wavelengths: - self.mat_plot_2d.plot_grid( - grid=Grid2DIrregular.from_yx_1d( - y=self.dataset.uv_wavelengths[:, 1] / 10**3.0, - x=self.dataset.uv_wavelengths[:, 0] / 10**3.0, - ), - visuals_2d=self.visuals_2d, - auto_labels=AutoLabels( - title="UV-Wavelengths", filename="uv_wavelengths" - ), - ) - - if amplitudes_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.dataset.amplitudes, - x=self.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Amplitudes vs UV-distances", - filename="amplitudes_vs_uv_distances", - yunit="Jy", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if phases_vs_uv_distances: - self.mat_plot_1d.plot_yx( - y=self.dataset.phases, - x=self.dataset.uv_distances / 10**3.0, - visuals_1d=self.visuals_1d, - auto_labels=AutoLabels( - title="Phases vs UV-distances", - filename="phases_vs_uv_distances", - yunit="deg", - xunit="k$\lambda$", - ), - plot_axis_type_override="scatter", - ) - - if dirty_image: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_image, - visuals_2d=self.get_visuals_2d_real_space(), - auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), - ) - - if dirty_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), - auto_labels=AutoLabels( - title="Dirty Noise Map", filename="dirty_noise_map" - ), - ) - - if dirty_signal_to_noise_map: - self.mat_plot_2d.plot_array( - array=self.dataset.dirty_signal_to_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), - auto_labels=AutoLabels( - title="Dirty Signal-To-Noise Map", - filename="dirty_signal_to_noise_map", - ), - ) - - def subplot( - self, - data: bool = False, - noise_map: bool = False, - u_wavelengths: bool = False, - v_wavelengths: bool = False, - uv_wavelengths: bool = False, - amplitudes_vs_uv_distances: bool = False, - phases_vs_uv_distances: bool = False, - dirty_image: bool = False, - dirty_noise_map: bool = False, - dirty_signal_to_noise_map: bool = False, - auto_filename: str = "subplot_dataset", - ): - """ - Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D on a subplot. - - The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type - bool of the function, which if switched to `True` means that it is included on the subplot. - - Parameters - ---------- - data - Whether to include a 2D plot (via `scatter`) of the visibility data. - noise_map - Whether to include a 2D plot (via `scatter`) of the noise-map. - u_wavelengths - Whether to include a 1D plot (via `plot`) of the u-wavelengths. - v_wavelengths - Whether to include a 1D plot (via `plot`) of the v-wavelengths. - amplitudes_vs_uv_distances - Whether to include a 1D plot (via `plot`) of the amplitudes versis the uv distances. - phases_vs_uv_distances - Whether to include a 1D plot (via `plot`) of the phases versis the uv distances. - dirty_image - Whether to include a 2D plot (via `imshow`) of the dirty image. - dirty_noise_map - Whether to include a 2D plot (via `imshow`) of the dirty noise map. - dirty_signal_to_noise_map - Whether to include a 2D plot (via `imshow`) of the dirty signal-to-noise map. - """ - self._subplot_custom_plot( - data=data, - noise_map=noise_map, - u_wavelengths=u_wavelengths, - v_wavelengths=v_wavelengths, - uv_wavelengths=uv_wavelengths, - amplitudes_vs_uv_distances=amplitudes_vs_uv_distances, - phases_vs_uv_distances=phases_vs_uv_distances, - dirty_image=dirty_image, - dirty_noise_map=dirty_noise_map, - dirty_signal_to_noise_map=dirty_signal_to_noise_map, - auto_labels=AutoLabels(filename=auto_filename), - ) - - def subplot_dataset(self): - """ - Standard subplot of the attributes of the plotter's `Interferometer` object. - """ - return self.subplot( - data=True, - uv_wavelengths=True, - amplitudes_vs_uv_distances=True, - phases_vs_uv_distances=True, - dirty_image=True, - dirty_signal_to_noise_map=True, - auto_filename="subplot_dataset", - ) - - def subplot_dirty_images(self): - """ - Standard subplot of the dirty attributes of the plotter's `Interferometer` object. - """ - return self.subplot( - dirty_image=True, - dirty_noise_map=True, - dirty_signal_to_noise_map=True, - auto_filename="subplot_dirty_images", - ) +from autoarray.plot.abstract_plotters import AbstractPlotter +from autoarray.plot.visuals.one_d import Visuals1D +from autoarray.plot.visuals.two_d import Visuals2D +from autoarray.plot.mat_plot.one_d import MatPlot1D +from autoarray.plot.mat_plot.two_d import MatPlot2D +from autoarray.plot.auto_labels import AutoLabels +from autoarray.dataset.interferometer.dataset import Interferometer +from autoarray.structures.grids.irregular_2d import Grid2DIrregular + + +class InterferometerPlotter(AbstractPlotter): + def __init__( + self, + dataset: Interferometer, + mat_plot_1d: MatPlot1D = None, + visuals_1d: Visuals1D = None, + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, + ): + """ + Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and + `imshow()` and other matplotlib functions which customize the plot's appearance. + + The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default, + the settings passed to every matplotlib function called are those specified in + the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot2d` to + customize the figure's appearance. + + Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be + extracted from the `LightProfile` and plotted via the visuals object. + + Parameters + ---------- + dataset + The interferometer dataset the plotter plots. + mat_plot_1d + Contains objects which wrap the matplotlib function calls that make 1D plots. + visuals_1d + Contains 1D visuals that can be overlaid on 1D plots. + mat_plot_2d + Contains objects which wrap the matplotlib function calls that make 2D plots. + visuals_2d + Contains 2D visuals that can be overlaid on 2D plots. + """ + self.dataset = dataset + + super().__init__( + mat_plot_1d=mat_plot_1d, + visuals_1d=visuals_1d, + mat_plot_2d=mat_plot_2d, + visuals_2d=visuals_2d, + ) + + @property + def interferometer(self): + return self.dataset + + def figures_2d( + self, + data: bool = False, + noise_map: bool = False, + u_wavelengths: bool = False, + v_wavelengths: bool = False, + uv_wavelengths: bool = False, + amplitudes_vs_uv_distances: bool = False, + phases_vs_uv_distances: bool = False, + dirty_image: bool = False, + dirty_noise_map: bool = False, + dirty_signal_to_noise_map: bool = False, + ): + """ + Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D. + + The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type + bool of the function, which if switched to `True` means that it is plotted. + + Parameters + ---------- + data + Whether to make a 2D plot (via `scatter`) of the visibility data. + noise_map + Whether to make a 2D plot (via `scatter`) of the noise-map. + u_wavelengths + Whether to make a 1D plot (via `plot`) of the u-wavelengths. + v_wavelengths + Whether to make a 1D plot (via `plot`) of the v-wavelengths. + amplitudes_vs_uv_distances + Whether to make a 1D plot (via `plot`) of the amplitudes versis the uv distances. + phases_vs_uv_distances + Whether to make a 1D plot (via `plot`) of the phases versis the uv distances. + dirty_image + Whether to make a 2D plot (via `imshow`) of the dirty image. + dirty_noise_map + Whether to make a 2D plot (via `imshow`) of the dirty noise map. + dirty_signal_to_noise_map + Whether to make a 2D plot (via `imshow`) of the dirty signal-to-noise map. + """ + + if data: + self.mat_plot_2d.plot_grid( + grid=self.dataset.data.in_grid, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels(title="Visibilities", filename="data"), + ) + + if noise_map: + self.mat_plot_2d.plot_grid( + grid=self.dataset.data.in_grid, + visuals_2d=self.visuals_2d, + color_array=self.dataset.noise_map.real, + auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), + ) + + if u_wavelengths: + self.mat_plot_1d.plot_yx( + y=self.dataset.uv_wavelengths[:, 0], + x=None, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="U-Wavelengths", + filename="u_wavelengths", + ylabel="Wavelengths", + ), + plot_axis_type_override="linear", + ) + + if v_wavelengths: + self.mat_plot_1d.plot_yx( + y=self.dataset.uv_wavelengths[:, 1], + x=None, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="V-Wavelengths", + filename="v_wavelengths", + ylabel="Wavelengths", + ), + plot_axis_type_override="linear", + ) + + if uv_wavelengths: + self.mat_plot_2d.plot_grid( + grid=Grid2DIrregular.from_yx_1d( + y=self.dataset.uv_wavelengths[:, 1] / 10**3.0, + x=self.dataset.uv_wavelengths[:, 0] / 10**3.0, + ), + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="UV-Wavelengths", filename="uv_wavelengths" + ), + ) + + if amplitudes_vs_uv_distances: + self.mat_plot_1d.plot_yx( + y=self.dataset.amplitudes, + x=self.dataset.uv_distances / 10**3.0, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="Amplitudes vs UV-distances", + filename="amplitudes_vs_uv_distances", + yunit="Jy", + xunit="k$\lambda$", + ), + plot_axis_type_override="scatter", + ) + + if phases_vs_uv_distances: + self.mat_plot_1d.plot_yx( + y=self.dataset.phases, + x=self.dataset.uv_distances / 10**3.0, + visuals_1d=self.visuals_1d, + auto_labels=AutoLabels( + title="Phases vs UV-distances", + filename="phases_vs_uv_distances", + yunit="deg", + xunit="k$\lambda$", + ), + plot_axis_type_override="scatter", + ) + + if dirty_image: + self.mat_plot_2d.plot_array( + array=self.dataset.dirty_image, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), + ) + + if dirty_noise_map: + self.mat_plot_2d.plot_array( + array=self.dataset.dirty_noise_map, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Dirty Noise Map", filename="dirty_noise_map" + ), + ) + + if dirty_signal_to_noise_map: + self.mat_plot_2d.plot_array( + array=self.dataset.dirty_signal_to_noise_map, + visuals_2d=self.visuals_2d, + auto_labels=AutoLabels( + title="Dirty Signal-To-Noise Map", + filename="dirty_signal_to_noise_map", + ), + ) + + def subplot( + self, + data: bool = False, + noise_map: bool = False, + u_wavelengths: bool = False, + v_wavelengths: bool = False, + uv_wavelengths: bool = False, + amplitudes_vs_uv_distances: bool = False, + phases_vs_uv_distances: bool = False, + dirty_image: bool = False, + dirty_noise_map: bool = False, + dirty_signal_to_noise_map: bool = False, + auto_filename: str = "subplot_dataset", + ): + """ + Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D on a subplot. + + The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type + bool of the function, which if switched to `True` means that it is included on the subplot. + + Parameters + ---------- + data + Whether to include a 2D plot (via `scatter`) of the visibility data. + noise_map + Whether to include a 2D plot (via `scatter`) of the noise-map. + u_wavelengths + Whether to include a 1D plot (via `plot`) of the u-wavelengths. + v_wavelengths + Whether to include a 1D plot (via `plot`) of the v-wavelengths. + amplitudes_vs_uv_distances + Whether to include a 1D plot (via `plot`) of the amplitudes versis the uv distances. + phases_vs_uv_distances + Whether to include a 1D plot (via `plot`) of the phases versis the uv distances. + dirty_image + Whether to include a 2D plot (via `imshow`) of the dirty image. + dirty_noise_map + Whether to include a 2D plot (via `imshow`) of the dirty noise map. + dirty_signal_to_noise_map + Whether to include a 2D plot (via `imshow`) of the dirty signal-to-noise map. + """ + self._subplot_custom_plot( + data=data, + noise_map=noise_map, + u_wavelengths=u_wavelengths, + v_wavelengths=v_wavelengths, + uv_wavelengths=uv_wavelengths, + amplitudes_vs_uv_distances=amplitudes_vs_uv_distances, + phases_vs_uv_distances=phases_vs_uv_distances, + dirty_image=dirty_image, + dirty_noise_map=dirty_noise_map, + dirty_signal_to_noise_map=dirty_signal_to_noise_map, + auto_labels=AutoLabels(filename=auto_filename), + ) + + def subplot_dataset(self): + """ + Standard subplot of the attributes of the plotter's `Interferometer` object. + """ + return self.subplot( + data=True, + uv_wavelengths=True, + amplitudes_vs_uv_distances=True, + phases_vs_uv_distances=True, + dirty_image=True, + dirty_signal_to_noise_map=True, + auto_filename="subplot_dataset", + ) + + def subplot_dirty_images(self): + """ + Standard subplot of the dirty attributes of the plotter's `Interferometer` object. + """ + return self.subplot( + dirty_image=True, + dirty_noise_map=True, + dirty_signal_to_noise_map=True, + auto_filename="subplot_dirty_images", + ) diff --git a/autoarray/dataset/preprocess.py b/autoarray/dataset/preprocess.py index f13af3184..c113c16aa 100644 --- a/autoarray/dataset/preprocess.py +++ b/autoarray/dataset/preprocess.py @@ -1,5 +1,4 @@ import numpy as np -from scipy.stats import norm from autoarray import exc @@ -149,7 +148,8 @@ def noise_map_via_data_eps_and_exposure_time_map_from(data_eps, exposure_time_ma The exposure time at every data-point of the data. """ return data_eps.with_new_array( - np.abs(data_eps * exposure_time_map) ** 0.5 / exposure_time_map + np.abs(data_eps.array * exposure_time_map.array) ** 0.5 + / exposure_time_map.array ) @@ -263,7 +263,9 @@ def edges_from(image, no_edges): edges = [] for edge_no in range(no_edges): - top_edge = image.native.array[edge_no, edge_no : image.shape_native[1] - edge_no] + top_edge = image.native.array[ + edge_no, edge_no : image.shape_native[1] - edge_no + ] bottom_edge = image.native.array[ image.shape_native[0] - 1 - edge_no, edge_no : image.shape_native[1] - edge_no, @@ -313,6 +315,7 @@ def background_noise_map_via_edges_from(image, no_edges): no_edges Number of edges used to estimate the background level. """ + from scipy.stats import norm from autoarray.structures.arrays.uniform_2d import Array2D @@ -406,9 +409,10 @@ def poisson_noise_via_data_eps_from(data_eps, exposure_time_map, seed=-1): An array describing simulated poisson noise_maps """ setup_random_seed(seed) - image_counts = np.multiply(data_eps, exposure_time_map) + + image_counts = np.multiply(data_eps.array, exposure_time_map.array) return data_eps - np.divide( - np.random.poisson(image_counts, data_eps.shape), exposure_time_map + np.random.poisson(image_counts, data_eps.shape), exposure_time_map.array ) @@ -506,8 +510,6 @@ def noise_map_with_signal_to_noise_limit_from( from autoarray.structures.arrays.uniform_1d import Array1D from autoarray.structures.arrays.uniform_2d import Array2D - # TODO : Refacotr into a util - signal_to_noise_map = data / noise_map signal_to_noise_map[signal_to_noise_map < 0] = 0 @@ -522,7 +524,9 @@ def noise_map_with_signal_to_noise_limit_from( ) mask = Mask2D.all_false( - shape_native=data.shape_native, pixel_scales=data.pixel_scales + shape_native=data.shape_native, + pixel_scales=data.pixel_scales, + origin=data.origin, ) if len(noise_map.native) == 1: diff --git a/autoarray/exc.py b/autoarray/exc.py index eed76b04e..3929820eb 100644 --- a/autoarray/exc.py +++ b/autoarray/exc.py @@ -104,11 +104,3 @@ class PlottingException(Exception): """ pass - - -class ProfilingException(Exception): - """ - Raises exceptions associated with in-built profiling tools (e.g. the `profile_func` decorator). - """ - - pass diff --git a/autoarray/fit/fit_dataset.py b/autoarray/fit/fit_dataset.py index c2f498949..7c6234b75 100644 --- a/autoarray/fit/fit_dataset.py +++ b/autoarray/fit/fit_dataset.py @@ -13,7 +13,6 @@ from autoarray.inversion.inversion.abstract import AbstractInversion from autoarray.mask.mask_2d import Mask2D -from autoarray.numba_util import profile_func from autoarray import type as ty @@ -86,7 +85,7 @@ def chi_squared(self) -> float: """ Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map. """ - return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map) + return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array) @property def noise_normalization(self) -> float: @@ -95,7 +94,7 @@ def noise_normalization(self) -> float: [Noise_Term] = sum(log(2*pi*[Noise]**2.0)) """ - return fit_util.noise_normalization_from(noise_map=self.noise_map) + return fit_util.noise_normalization_from(noise_map=self.noise_map.array) @property def log_likelihood(self) -> float: @@ -116,7 +115,6 @@ def __init__( dataset, use_mask_in_fit: bool = False, dataset_model: DatasetModel = None, - run_time_dict: Optional[Dict] = None, ): """Class to fit a masked dataset where the dataset's data structures are any dimension. @@ -149,7 +147,6 @@ def __init__( self.dataset = dataset self.use_mask_in_fit = use_mask_in_fit self.dataset_model = dataset_model or DatasetModel() - self.run_time_dict = run_time_dict @property def mask(self) -> Mask2D: @@ -157,15 +154,6 @@ def mask(self) -> Mask2D: @cached_property def grids(self) -> GridsInterface: - offset = self.dataset_model.grid_offset - - if offset[0] == 0.0 and offset[1] == 0.0: - return GridsInterface( - lp=self.dataset.grids.lp, - pixelization=self.dataset.grids.pixelization, - blurring=self.dataset.grids.blurring, - border_relocator=self.dataset.grids.border_relocator, - ) def subtracted_from(grid, offset): if grid is None: @@ -320,7 +308,6 @@ def log_evidence(self) -> float: ) @property - @profile_func def figure_of_merit(self) -> float: if self.inversion is not None: return self.log_evidence diff --git a/autoarray/fit/fit_imaging.py b/autoarray/fit/fit_imaging.py index aa55b4e35..a8b1f2297 100644 --- a/autoarray/fit/fit_imaging.py +++ b/autoarray/fit/fit_imaging.py @@ -14,7 +14,6 @@ def __init__( dataset: Imaging, use_mask_in_fit: bool = False, dataset_model: DatasetModel = None, - run_time_dict: Optional[Dict] = None, ): """ Class to fit a masked imaging dataset. @@ -50,7 +49,6 @@ def __init__( dataset=dataset, use_mask_in_fit=use_mask_in_fit, dataset_model=dataset_model, - run_time_dict=run_time_dict, ) @property diff --git a/autoarray/fit/fit_interferometer.py b/autoarray/fit/fit_interferometer.py index ec4c1d99d..8ff3d3684 100644 --- a/autoarray/fit/fit_interferometer.py +++ b/autoarray/fit/fit_interferometer.py @@ -18,7 +18,6 @@ def __init__( dataset: Interferometer, dataset_model: DatasetModel = None, use_mask_in_fit: bool = False, - run_time_dict: Optional[Dict] = None, ): """ Class to fit a masked interferometer dataset. @@ -59,7 +58,6 @@ def __init__( dataset=dataset, dataset_model=dataset_model, use_mask_in_fit=use_mask_in_fit, - run_time_dict=run_time_dict, ) @property @@ -113,7 +111,7 @@ def chi_squared(self) -> float: Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map. """ return fit_util.chi_squared_complex_from( - chi_squared_map=self.chi_squared_map, + chi_squared_map=self.chi_squared_map.array, ) @property @@ -124,7 +122,7 @@ def noise_normalization(self) -> float: [Noise_Term] = sum(log(2*pi*[Noise]**2.0)) """ return fit_util.noise_normalization_complex_from( - noise_map=self.noise_map, + noise_map=self.noise_map.array, ) @property diff --git a/autoarray/fit/fit_util.py b/autoarray/fit/fit_util.py index 10f24f9a7..190788efc 100644 --- a/autoarray/fit/fit_util.py +++ b/autoarray/fit/fit_util.py @@ -84,7 +84,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float: chi_squared_map The chi-squared-map of values of the model-data fit to the dataset. """ - return jnp.sum(np.array(chi_squared_map)) + return jnp.sum(chi_squared_map) def noise_normalization_from(*, noise_map: ty.DataLike) -> float: @@ -98,7 +98,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float: noise_map The masked noise-map of the dataset. """ - return jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map)**2.0)) + return jnp.sum(jnp.log(2 * jnp.pi * noise_map**2.0)) def normalized_residual_map_complex_from( @@ -224,6 +224,7 @@ def normalized_residual_map_with_mask_from( """ return jnp.where(jnp.asarray(mask) == 0, jnp.divide(residual_map, noise_map), 0) + @to_new_array def chi_squared_map_with_mask_from( *, residual_map: ty.DataLike, noise_map: ty.DataLike, mask: Mask @@ -244,11 +245,7 @@ def chi_squared_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return jnp.where( - jnp.asarray(mask) == 0, - jnp.square(residual_map / noise_map), - 0 - ) + return jnp.where(jnp.asarray(mask) == 0, jnp.square(residual_map / noise_map), 0) def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask) -> float: @@ -321,7 +318,9 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask) -> mask The mask applied to the noise-map, where `False` entries are included in the calculation. """ - return float(jnp.sum(jnp.log(2 * jnp.pi * noise_map[jnp.asarray(mask) == 0] ** 2.0))) + return float( + jnp.sum(jnp.log(2 * jnp.pi * noise_map[jnp.asarray(mask) == 0] ** 2.0)) + ) def chi_squared_with_noise_covariance_from( @@ -457,4 +456,4 @@ def residual_flux_fraction_map_with_mask_from( mask The mask applied to the residual-map, where `False` entries are included in the calculation. """ - return jnp.where(mask == 0, residual_map / data, 0) \ No newline at end of file + return jnp.where(mask == 0, residual_map / data, 0) diff --git a/autoarray/fit/mock/mock_fit_imaging.py b/autoarray/fit/mock/mock_fit_imaging.py index 181d3d56e..4fa8253ff 100644 --- a/autoarray/fit/mock/mock_fit_imaging.py +++ b/autoarray/fit/mock/mock_fit_imaging.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from autoarray.dataset.mock.mock_dataset import MockDataset from autoarray.dataset.dataset_model import DatasetModel @@ -15,13 +15,11 @@ def __init__( model_data=None, inversion=None, blurred_image=None, - run_time_dict: Optional[Dict] = None, ): super().__init__( dataset=dataset or MockDataset(), dataset_model=dataset_model, use_mask_in_fit=use_mask_in_fit, - run_time_dict=run_time_dict, ) self._noise_map = noise_map diff --git a/autoarray/fit/plot/fit_imaging_plotters.py b/autoarray/fit/plot/fit_imaging_plotters.py index 1c0eb9e67..86aa0d34d 100644 --- a/autoarray/fit/plot/fit_imaging_plotters.py +++ b/autoarray/fit/plot/fit_imaging_plotters.py @@ -1,21 +1,18 @@ from typing import Callable -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.fit.fit_imaging import FitImaging -class FitImagingPlotterMeta(Plotter): +class FitImagingPlotterMeta(AbstractPlotter): def __init__( self, fit, - get_visuals_2d: Callable, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, residuals_symmetric_cmap: bool = True, ): """ @@ -27,31 +24,23 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- fit The fit to an imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `FitImaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. residuals_symmetric_cmap If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such that `abs(vmin) = abs(vmax)`. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit - self.get_visuals_2d = get_visuals_2d self.residuals_symmetric_cmap = residuals_symmetric_cmap def figures_2d( @@ -95,14 +84,14 @@ def figures_2d( if data: self.mat_plot_2d.plot_array( array=self.fit.data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Data", filename=f"data{suffix}"), ) if noise_map: self.mat_plot_2d.plot_array( array=self.fit.noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Noise-Map", filename=f"noise_map{suffix}" ), @@ -111,7 +100,7 @@ def figures_2d( if signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.fit.signal_to_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Signal-To-Noise Map", filename=f"signal_to_noise_map{suffix}" ), @@ -120,7 +109,7 @@ def figures_2d( if model_image: self.mat_plot_2d.plot_array( array=self.fit.model_data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Model Image", filename=f"model_image{suffix}" ), @@ -134,7 +123,7 @@ def figures_2d( if residual_map: self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Residual Map", filename=f"residual_map{suffix}" ), @@ -143,7 +132,7 @@ def figures_2d( if normalized_residual_map: self.mat_plot_2d.plot_array( array=self.fit.normalized_residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Normalized Residual Map", filename=f"normalized_residual_map{suffix}", @@ -155,7 +144,7 @@ def figures_2d( if chi_squared_map: self.mat_plot_2d.plot_array( array=self.fit.chi_squared_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Chi-Squared Map", filename=f"chi_squared_map{suffix}" ), @@ -164,7 +153,7 @@ def figures_2d( if residual_flux_fraction_map: self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Residual Flux Fraction Map", filename=f"residual_flux_fraction_map{suffix}", @@ -238,13 +227,12 @@ def subplot_fit(self): ) -class FitImagingPlotter(Plotter): +class FitImagingPlotter(AbstractPlotter): def __init__( self, fit: FitImaging, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -255,8 +243,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- @@ -266,26 +253,17 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit self._fit_imaging_meta_plotter = FitImagingPlotterMeta( fit=self.fit, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) self.figures_2d = self._fit_imaging_meta_plotter.figures_2d self.subplot = self._fit_imaging_meta_plotter.subplot self.subplot_fit = self._fit_imaging_meta_plotter.subplot_fit - - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_fit_imaging_from(fit=self.fit) diff --git a/autoarray/fit/plot/fit_interferometer_plotters.py b/autoarray/fit/plot/fit_interferometer_plotters.py index 93e7b212e..3ab2bd1e6 100644 --- a/autoarray/fit/plot/fit_interferometer_plotters.py +++ b/autoarray/fit/plot/fit_interferometer_plotters.py @@ -1,28 +1,22 @@ import numpy as np -from typing import Callable -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.fit.fit_interferometer import FitInterferometer -class FitInterferometerPlotterMeta(Plotter): +class FitInterferometerPlotterMeta(AbstractPlotter): def __init__( self, fit, - get_visuals_2d_real_space: Callable, mat_plot_1d: MatPlot1D, visuals_1d: Visuals1D, - include_1d: Include1D, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, residuals_symmetric_cmap: bool = True, ): """ @@ -35,42 +29,32 @@ def __init__( customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects. Attributes may be - extracted from the `FitInterferometer` and plotted via the visuals object, if the corresponding entry is `True` in - the `Include1D` or `Include2D` object or the `config/visualize/include.ini` file. + extracted from the `FitInterferometer` and plotted via the visuals object. Parameters ---------- fit The fit to an interferometer dataset the plotter plots. - get_visuals_2d - A function which extracts from the `FitInterferometer` the 2D visuals which are plotted on figures. mat_plot_1d Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `FitInterferometer` are extracted and plotted as visuals for 1D plots. mat_plot_2d Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `FitInterferometer` are extracted and plotted as visuals for 2D plots. residuals_symmetric_cmap If true, the `residual_map` and `normalized_residual_map` are plotted with a symmetric color map such that `abs(vmin) = abs(vmax)`. """ super().__init__( mat_plot_1d=mat_plot_1d, - include_1d=include_1d, visuals_1d=visuals_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) self.fit = fit - self.get_visuals_2d_real_space = get_visuals_2d_real_space self.residuals_symmetric_cmap = residuals_symmetric_cmap def figures_2d( @@ -183,7 +167,7 @@ def figures_2d( auto_labels=AutoLabels( title="Model Visibilities", filename="model_data" ), - color_array=np.real(self.fit.model_data), + color_array=np.real(self.fit.model_data.array), ) if residual_map_real: @@ -268,14 +252,14 @@ def figures_2d( if dirty_image: self.mat_plot_2d.plot_array( array=self.fit.dirty_image, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"), ) if dirty_noise_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Noise Map", filename="dirty_noise_map" ), @@ -284,7 +268,7 @@ def figures_2d( if dirty_signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_signal_to_noise_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Signal-To-Noise Map", filename="dirty_signal_to_noise_map", @@ -294,7 +278,7 @@ def figures_2d( if dirty_model_image: self.mat_plot_2d.plot_array( array=self.fit.dirty_model_image, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Model Image", filename="dirty_model_image_2d" ), @@ -308,7 +292,7 @@ def figures_2d( if dirty_residual_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_residual_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Residual Map", filename="dirty_residual_map_2d" ), @@ -317,7 +301,7 @@ def figures_2d( if dirty_normalized_residual_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_normalized_residual_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Normalized Residual Map", filename="dirty_normalized_residual_map_2d", @@ -330,7 +314,7 @@ def figures_2d( if dirty_chi_squared_map: self.mat_plot_2d.plot_array( array=self.fit.dirty_chi_squared_map, - visuals_2d=self.get_visuals_2d_real_space(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Dirty Chi-Squared Map", filename="dirty_chi_squared_map_2d" ), @@ -451,16 +435,14 @@ def subplot_fit_dirty_images(self): ) -class FitInterferometerPlotter(Plotter): +class FitInterferometerPlotter(AbstractPlotter): def __init__( self, fit: FitInterferometer, - mat_plot_1d: MatPlot1D = MatPlot1D(), - visuals_1d: Visuals1D = Visuals1D(), - include_1d: Include1D = Include1D(), - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_1d: MatPlot1D = None, + visuals_1d: Visuals1D = None, + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `FitInterferometer` objects using the matplotlib method `imshow()` and many other @@ -471,8 +453,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitInterferometer` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitInterferometer` and plotted via the visuals object. Parameters ---------- @@ -482,15 +463,11 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ super().__init__( mat_plot_1d=mat_plot_1d, - include_1d=include_1d, visuals_1d=visuals_1d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, visuals_2d=visuals_2d, ) @@ -498,12 +475,9 @@ def __init__( self._fit_interferometer_meta_plotter = FitInterferometerPlotterMeta( fit=self.fit, - get_visuals_2d_real_space=self.get_visuals_2d_real_space, mat_plot_1d=self.mat_plot_1d, - include_1d=self.include_1d, visuals_1d=self.visuals_1d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) @@ -513,6 +487,3 @@ def __init__( self.subplot_fit_dirty_images = ( self._fit_interferometer_meta_plotter.subplot_fit_dirty_images ) - - def get_visuals_2d_real_space(self) -> Visuals2D: - return self.get_2d.via_mask_from(mask=self.fit.dataset.real_space_mask) diff --git a/autoarray/fit/plot/fit_vector_yx_plotters.py b/autoarray/fit/plot/fit_vector_yx_plotters.py index 3351cbaaf..9691e5680 100644 --- a/autoarray/fit/plot/fit_vector_yx_plotters.py +++ b/autoarray/fit/plot/fit_vector_yx_plotters.py @@ -1,22 +1,19 @@ from typing import Callable -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.fit.fit_imaging import FitImaging from autoarray.fit.plot.fit_imaging_plotters import FitImagingPlotterMeta -class FitVectorYXPlotterMeta(Plotter): +class FitVectorYXPlotterMeta(AbstractPlotter): def __init__( self, fit, - get_visuals_2d: Callable, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -27,28 +24,20 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- fit The fit to an imaging dataset the plotter plots. - get_visuals_2d - A function which extracts from the `FitImaging` the 2D visuals which are plotted on figures. mat_plot_2d Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit - self.get_visuals_2d = get_visuals_2d def figures_2d( self, @@ -84,26 +73,24 @@ def figures_2d( Whether to make a 2D plot (via `imshow`) of the chi-squared map. """ - fit_plotter_y = FitImaging(self.fit.data.y_array) - if image: self.mat_plot_2d.plot_array( array=self.fit.data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Data", filename="image_2d"), ) if noise_map: self.mat_plot_2d.plot_array( array=self.fit.noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"), ) if signal_to_noise_map: self.mat_plot_2d.plot_array( array=self.fit.signal_to_noise_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Signal-To-Noise Map", filename="signal_to_noise_map" ), @@ -112,21 +99,21 @@ def figures_2d( if model_image: self.mat_plot_2d.plot_array( array=self.fit.model_data, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Model Image", filename="model_image"), ) if residual_map: self.mat_plot_2d.plot_array( array=self.fit.residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Residual Map", filename="residual_map"), ) if normalized_residual_map: self.mat_plot_2d.plot_array( array=self.fit.normalized_residual_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Normalized Residual Map", filename="normalized_residual_map" ), @@ -135,7 +122,7 @@ def figures_2d( if chi_squared_map: self.mat_plot_2d.plot_array( array=self.fit.chi_squared_map, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Chi-Squared Map", filename="chi_squared_map" ), @@ -204,13 +191,12 @@ def subplot_fit(self): ) -class FitImagingPlotter(Plotter): +class FitImagingPlotter(AbstractPlotter): def __init__( self, fit: FitImaging, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib @@ -221,8 +207,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `FitImaging` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `FitImaging` and plotted via the visuals object. Parameters ---------- @@ -232,26 +217,17 @@ def __init__( Contains objects which wrap the matplotlib function calls that make the plot. visuals_2d Contains visuals that can be overlaid on the plot. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals. """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.fit = fit self._fit_imaging_meta_plotter = FitImagingPlotterMeta( fit=self.fit, - get_visuals_2d=self.get_visuals_2d, mat_plot_2d=self.mat_plot_2d, - include_2d=self.include_2d, visuals_2d=self.visuals_2d, ) self.figures_2d = self._fit_imaging_meta_plotter.figures_2d self.subplot = self._fit_imaging_meta_plotter.subplot self.subplot_fit = self._fit_imaging_meta_plotter.subplot_fit - - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_fit_imaging_from(fit=self.fit) diff --git a/autoarray/fixtures.py b/autoarray/fixtures.py index e546fb497..6c1f14264 100644 --- a/autoarray/fixtures.py +++ b/autoarray/fixtures.py @@ -68,6 +68,10 @@ def make_array_2d_7x7(): return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0)) +def make_array_2d_rgb_7x7(): + return aa.Array2DRGB(values=np.ones((7, 7, 3)), mask=make_mask_2d_7x7()) + + def make_layout_2d_7x7(): return aa.Layout2D( shape_2d=(7, 7), @@ -417,7 +421,7 @@ def make_rectangular_mapper_7x7_3x3(): adapt_data=aa.Array2D.ones(shape_native=(3, 3), pixel_scales=0.1), ) - return aa.MapperRectangular( + return aa.MapperRectangularUniform( mapper_grids=mapper_grids, border_relocator=make_border_relocator_2d_7x7(), regularization=make_regularization_constant(), diff --git a/autoarray/geometry/geometry_2d.py b/autoarray/geometry/geometry_2d.py index 9eea7e9f2..fa5d4b24a 100644 --- a/autoarray/geometry/geometry_2d.py +++ b/autoarray/geometry/geometry_2d.py @@ -154,7 +154,7 @@ def pixel_coordinates_2d_from( A 2D (y,x) pixel-value coordinate. """ return geometry_util.pixel_coordinates_2d_from( - scaled_coordinates_2d=np.array(scaled_coordinates_2d), + scaled_coordinates_2d=scaled_coordinates_2d, shape_native=self.shape_native, pixel_scales=self.pixel_scales, origins=self.origin, @@ -184,7 +184,7 @@ def scaled_coordinates_2d_from( """ return geometry_util.scaled_coordinates_2d_from( - pixel_coordinates_2d=np.array(pixel_coordinates_2d), + pixel_coordinates_2d=pixel_coordinates_2d, shape_native=self.shape_native, pixel_scales=self.pixel_scales, origins=self.origin, @@ -235,7 +235,7 @@ def grid_pixels_2d_from(self, grid_scaled_2d: Grid2D) -> Grid2D: from autoarray.structures.grids.uniform_2d import Grid2D grid_pixels_2d = geometry_util.grid_pixels_2d_slim_from( - grid_scaled_2d_slim=np.array(grid_scaled_2d.array), + grid_scaled_2d_slim=grid_scaled_2d.array, shape_native=self.shape_native, pixel_scales=self.pixel_scales, origin=self.origin, @@ -261,7 +261,7 @@ def grid_pixel_centres_2d_from(self, grid_scaled_2d: Grid2D) -> Grid2D: from autoarray.structures.grids.uniform_2d import Grid2D grid_pixel_centres_1d = geometry_util.grid_pixel_centres_2d_slim_from( - grid_scaled_2d_slim=np.array(grid_scaled_2d), + grid_scaled_2d_slim=grid_scaled_2d, shape_native=self.shape_native, pixel_scales=self.pixel_scales, origin=self.origin, @@ -294,7 +294,7 @@ def grid_pixel_indexes_2d_from(self, grid_scaled_2d: Grid2D) -> Array2D: from autoarray.structures.arrays.uniform_2d import Array2D grid_pixel_indexes_2d = geometry_util.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(grid_scaled_2d), + grid_scaled_2d_slim=grid_scaled_2d, shape_native=self.shape_native, pixel_scales=self.pixel_scales, origin=self.origin, @@ -320,7 +320,7 @@ def grid_scaled_2d_from(self, grid_pixels_2d: Grid2D) -> Grid2D: from autoarray.structures.grids.uniform_2d import Grid2D grid_scaled_1d = geometry_util.grid_scaled_2d_slim_from( - grid_pixels_2d_slim=np.array(grid_pixels_2d), + grid_pixels_2d_slim=grid_pixels_2d, shape_native=self.shape_native, pixel_scales=self.pixel_scales, origin=self.origin, diff --git a/autoarray/geometry/geometry_util.py b/autoarray/geometry/geometry_util.py index e2c6c8898..e0dc8bdcc 100644 --- a/autoarray/geometry/geometry_util.py +++ b/autoarray/geometry/geometry_util.py @@ -2,8 +2,6 @@ import numpy as np from typing import Tuple, Union - -from autoarray import numba_util from autoarray import type as ty @@ -57,7 +55,6 @@ def convert_pixel_scales_1d(pixel_scales: ty.PixelScales) -> Tuple[float]: return pixel_scales -@numba_util.jit() def central_pixel_coordinates_1d_from( shape_slim: Tuple[int], ) -> Union[Tuple[float], Tuple[float]]: @@ -85,7 +82,6 @@ def central_pixel_coordinates_1d_from( return (float(shape_slim[0] - 1) / 2,) -@numba_util.jit() def central_scaled_coordinate_1d_from( shape_slim: Tuple[float], pixel_scales: ty.PixelScales, @@ -121,7 +117,6 @@ def central_scaled_coordinate_1d_from( return (x_pixel,) -@numba_util.jit() def pixel_coordinates_1d_from( scaled_coordinates_1d: Tuple[float], shape_slim: Tuple[int], @@ -139,7 +134,6 @@ def pixel_coordinates_1d_from( return (x_pixel,) -@numba_util.jit() def scaled_coordinates_1d_from( pixel_coordinates_1d: Tuple[float], shape_slim: Tuple[int], @@ -182,70 +176,6 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] return pixel_scales -@numba_util.jit() -def central_pixel_coordinates_2d_numba_from( - shape_native: Tuple[int, int], -) -> Tuple[float, float]: - """ - Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) - from the shape of that data structure. - - Examples of the central pixels are as follows: - - - For a 3x3 image, the central pixel is pixel [1, 1]. - - For a 4x4 image, the central pixel is [1.5, 1.5]. - - Parameters - ---------- - shape_native - The dimensions of the data structure, which can be in 1D, 2D or higher dimensions. - - Returns - ------- - The central pixel coordinates of the data structure. - """ - return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) - - -@numba_util.jit() -def central_scaled_coordinate_2d_numba_from( - shape_native: Tuple[int, int], - pixel_scales: ty.PixelScales, - origin: Tuple[float, float] = (0.0, 0.0), -) -> Tuple[float, float]: - """ - Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) - from the shape of that data structure. - - This is computed by using the data structure's shape and converting it to scaled units using an input - pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`. - - The origin of the scaled grid can also be input and moved from (0.0, 0.0). - - Parameters - ---------- - shape_native - The 2D shape of the data structure whose central scaled coordinates are computed. - pixel_scales - The (y,x) scaled units to pixel units conversion factor of the 2D data structure. - origin - The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on. - - Returns - ------- - The central coordinates of the 2D data structure in scaled units. - """ - - central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( - shape_native=shape_native - ) - - y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0]) - x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1]) - - return (y_pixel, x_pixel) - - def central_pixel_coordinates_2d_from( shape_native: Tuple[int, int], ) -> Tuple[float, float]: @@ -298,7 +228,7 @@ def central_scaled_coordinate_2d_from( The central coordinates of the 2D data structure in scaled units. """ - central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( + central_pixel_coordinates = central_pixel_coordinates_2d_from( shape_native=shape_native ) @@ -371,7 +301,6 @@ def pixel_coordinates_2d_from( return (y_pixel, x_pixel) -@numba_util.jit() def scaled_coordinates_2d_from( pixel_coordinates_2d: Tuple[float, float], shape_native: Tuple[int, int], @@ -415,7 +344,7 @@ def scaled_coordinates_2d_from( origin=(0.0, 0.0) ) """ - central_scaled_coordinates = central_scaled_coordinate_2d_numba_from( + central_scaled_coordinates = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origins ) @@ -445,7 +374,12 @@ def transform_grid_2d_to_reference_frame( grid The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - shifted_grid_2d = np.array(grid_2d) - jnp.array(centre) + try: + grid_2d = grid_2d.array + except AttributeError: + pass + + shifted_grid_2d = grid_2d - jnp.array(centre) radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) theta_coordinate_to_profile = jnp.arctan2( @@ -477,23 +411,24 @@ def transform_grid_2d_from_reference_frame( The 2d grid of (y, x) coordinates which are transformed to a new reference frame. """ - cos_angle = np.cos(np.radians(angle)) - sin_angle = np.sin(np.radians(angle)) + cos_angle = jnp.cos(jnp.radians(angle)) + sin_angle = jnp.sin(jnp.radians(angle)) - y = np.add( - np.add( - np.multiply(grid_2d[:, 1], sin_angle), np.multiply(grid_2d[:, 0], cos_angle) + y = jnp.add( + jnp.add( + jnp.multiply(grid_2d[:, 1], sin_angle), + jnp.multiply(grid_2d[:, 0], cos_angle), ), centre[0], ) - x = np.add( - np.add( - np.multiply(grid_2d[:, 1], cos_angle), - -np.multiply(grid_2d[:, 0], sin_angle), + x = jnp.add( + jnp.add( + jnp.multiply(grid_2d[:, 1], cos_angle), + -jnp.multiply(grid_2d[:, 0], sin_angle), ), centre[1], ) - return np.vstack((y, x)).T + return jnp.vstack((y, x)).T def grid_pixels_2d_slim_from( @@ -539,8 +474,8 @@ def grid_pixels_2d_slim_from( centres_scaled = central_scaled_coordinate_2d_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin ) - centres_scaled = np.array(centres_scaled) - pixel_scales = np.array(pixel_scales) + centres_scaled = centres_scaled + pixel_scales = pixel_scales sign = np.array([-1, 1]) return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 diff --git a/autoarray/inversion/inversion/abstract.py b/autoarray/inversion/inversion/abstract.py index 757a24ef0..fe9071277 100644 --- a/autoarray/inversion/inversion/abstract.py +++ b/autoarray/inversion/inversion/abstract.py @@ -1,22 +1,20 @@ import copy - +import jax.numpy as jnp +from jax.scipy.linalg import block_diag import numpy as np -from scipy.linalg import block_diag -from scipy.sparse import csc_matrix -from scipy.sparse.linalg import splu + from typing import Dict, List, Optional, Type, Union from autoconf import cached_property -from autoarray.numba_util import profile_func from autoarray.dataset.imaging.dataset import Imaging from autoarray.dataset.interferometer.dataset import Interferometer -from autoarray.inversion.inversion.mapper_valued import MapperValued from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.regularization.abstract import AbstractRegularization from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.visibilities import Visibilities @@ -31,7 +29,7 @@ def __init__( dataset: Union[Imaging, Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, + preloads: Preloads = None, ): """ An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions @@ -69,28 +67,15 @@ def __init__( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ - # try: - # import numba - # except ModuleNotFoundError: - # raise exc.InversionException( - # "Inversion functionality (linear light profiles, pixelized reconstructions) is " - # "disabled if numba is not installed.\n\n" - # "This is because the run-times without numba are too slow.\n\n" - # "Please install numba, which is described at the following web page:\n\n" - # "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" - # ) - self.dataset = dataset self.linear_obj_list = linear_obj_list self.settings = settings - self.run_time_dict = run_time_dict + self.preloads = preloads or Preloads() @property def data(self): @@ -161,17 +146,9 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]: ------- A list of the index range of the parameters of each linear object in the inversion of the input cls type. """ - index_list = [] - - pixel_count = 0 - - for linear_obj in self.linear_obj_list: - if isinstance(linear_obj, cls): - index_list.append([pixel_count, pixel_count + linear_obj.params]) - - pixel_count += linear_obj.params - - return index_list + return inversion_util.param_range_list_from( + cls=cls, linear_obj_list=self.linear_obj_list + ) def cls_list_from(self, cls: Type, cls_filtered: Optional[Type] = None) -> List: """ @@ -272,12 +249,27 @@ def no_regularization_index_list(self) -> List[int]: return no_regularization_index_list + @property + def mapper_indices(self) -> np.ndarray: + + if self.preloads.mapper_indices is not None: + return self.preloads.mapper_indices + + mapper_indices = [] + + param_range_list = self.param_range_list_from(cls=AbstractMapper) + + for param_range in param_range_list: + + mapper_indices += range(param_range[0], param_range[1]) + + return np.array(mapper_indices) + @property def mask(self) -> Array2D: return self.data.mask @cached_property - @profile_func def mapping_matrix(self) -> np.ndarray: """ The `mapping_matrix` of a linear object describes the mappings between the observed data's data-points / pixels @@ -293,7 +285,7 @@ def mapping_matrix(self) -> np.ndarray: If there are multiple linear objects, the mapping matrices are stacked such that their simultaneous linear equations are solved simultaneously. This property returns the stacked mapping matrix. """ - return np.hstack( + return jnp.hstack( [linear_obj.mapping_matrix for linear_obj in self.linear_obj_list] ) @@ -302,7 +294,6 @@ def operated_mapping_matrix_list(self) -> np.ndarray: raise NotImplementedError @cached_property - @profile_func def operated_mapping_matrix(self) -> np.ndarray: """ The `operated_mapping_matrix` of a linear object describes the mappings between the observed data's values and @@ -313,20 +304,17 @@ def operated_mapping_matrix(self) -> np.ndarray: If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous linear equations are solved simultaneously. """ - return np.hstack(self.operated_mapping_matrix_list) + return jnp.hstack(self.operated_mapping_matrix_list) @cached_property - @profile_func def data_vector(self) -> np.ndarray: raise NotImplementedError @cached_property - @profile_func def curvature_matrix(self) -> np.ndarray: raise NotImplementedError @cached_property - @profile_func def regularization_matrix(self) -> Optional[np.ndarray]: """ The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the @@ -348,7 +336,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]: ) @cached_property - @profile_func def regularization_matrix_reduced(self) -> Optional[np.ndarray]: """ The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the @@ -363,22 +350,16 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]: regularization it is bypassed. """ - regularization_matrix = self.regularization_matrix - if self.all_linear_obj_have_regularization: - return regularization_matrix + return self.regularization_matrix - regularization_matrix = np.delete( - regularization_matrix, self.no_regularization_index_list, 0 - ) - regularization_matrix = np.delete( - regularization_matrix, self.no_regularization_index_list, 1 - ) + # ids of values which are on edge so zero-d and not solved for. + ids_to_keep = self.mapper_indices - return regularization_matrix + # Zero rows and columns in the matrix we want to ignore + return self.regularization_matrix[ids_to_keep][:, ids_to_keep] @cached_property - @profile_func def curvature_reg_matrix(self) -> np.ndarray: """ The linear system of equations solves for F + regularization_coefficient*H, which is computed below. @@ -391,59 +372,33 @@ def curvature_reg_matrix(self) -> np.ndarray: if not self.has(cls=AbstractRegularization): return self.curvature_matrix - if len(self.regularization_list) == 1: - curvature_matrix = self.curvature_matrix - curvature_matrix += self.regularization_matrix - - del self.__dict__["curvature_matrix"] - - return curvature_matrix - - return np.add(self.curvature_matrix, self.regularization_matrix) + return jnp.add(self.curvature_matrix, self.regularization_matrix) @cached_property - @profile_func - def curvature_reg_matrix_reduced(self) -> np.ndarray: + def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]: """ - The linear system of equations solves for F + regularization_coefficient*H, which is computed below. + The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the + linear algebra system we solve for using D and F above and is given by + equation (12) in https://arxiv.org/pdf/astro-ph/0302587.pdf. + + A complete description of regularization is given in the `regularization.py` and `regularization_util.py` + modules. - This is the curvature reg matrix for only the mappers, which is necessary for computing the log det - term without the linear light profiles included. + For multiple mappers, the regularization matrix is computed as the block diagonal of each individual mapper. + The scipy function `block_diag` has an overhead associated with it and if there is only one mapper and + regularization it is bypassed. """ + if self.all_linear_obj_have_regularization: return self.curvature_reg_matrix - curvature_reg_matrix = self.curvature_reg_matrix - - curvature_reg_matrix = np.delete( - curvature_reg_matrix, self.no_regularization_index_list, 0 - ) - curvature_reg_matrix = np.delete( - curvature_reg_matrix, self.no_regularization_index_list, 1 - ) - - return curvature_reg_matrix + # ids of values which are on edge so zero-d and not solved for. + ids_to_keep = self.mapper_indices - @property - def mapper_zero_pixel_list(self) -> np.ndarray: - mapper_zero_pixel_list = [] - param_range_list = self.param_range_list_from(cls=LinearObj) - for param_range, linear_obj in zip(param_range_list, self.linear_obj_list): - if isinstance(linear_obj, AbstractMapper): - mapping_matrix_for_image_pixels_source_zero = linear_obj.mapping_matrix[ - self.settings.image_pixels_source_zero - ] - source_pixels_zero = ( - np.sum(mapping_matrix_for_image_pixels_source_zero != 0, axis=0) - != 0 - ) - mapper_zero_pixel_list.append( - np.where(source_pixels_zero == True)[0] + param_range[0] - ) - return mapper_zero_pixel_list + # Zero rows and columns in the matrix we want to ignore + return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep] @cached_property - @profile_func def reconstruction(self) -> np.ndarray: """ Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12) @@ -460,66 +415,49 @@ def reconstruction(self) -> np.ndarray: ZTx := np.dot(Z.T, x) """ if self.settings.use_positive_only_solver: - """ - For the new implementation, we now need to take out the cols and rows of - the curvature_reg_matrix that corresponds to the parameters we force to be 0. - Similar for the data vector. - - What we actually doing is that we have set the correspoding cols of the Z to be 0. - As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out. - And the data_vector = ZTx, so the corresponding row is also taken out. - """ - - if self.settings.force_edge_pixels_to_zeros: - if self.settings.force_edge_image_pixels_to_zeros: - ids_zeros = np.unique( - np.append( - self.mapper_edge_pixel_list, self.mapper_zero_pixel_list - ) - ) - else: - ids_zeros = self.mapper_edge_pixel_list - - values_to_solve = np.ones( - np.shape(self.curvature_reg_matrix)[0], dtype=bool - ) - values_to_solve[ids_zeros] = False - data_vector_input = self.data_vector[values_to_solve] + if self.preloads.source_pixel_zeroed_indices is not None: - curvature_reg_matrix_input = self.curvature_reg_matrix[ - values_to_solve, : - ][:, values_to_solve] + # ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads. + ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep - solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0]) + # Use advanced indexing to select rows/columns + data_vector = self.data_vector[ids_to_keep] + curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][ + :, ids_to_keep + ] - solutions[values_to_solve] = ( + # Perform reconstruction via fnnls + reconstruction_partial = ( inversion_util.reconstruction_positive_only_from( - data_vector=data_vector_input, - curvature_reg_matrix=curvature_reg_matrix_input, - settings=self.settings, + data_vector=data_vector, + curvature_reg_matrix=curvature_reg_matrix, ) ) - return solutions + + # Allocate full solution array + reconstruction = jnp.zeros(self.data_vector.shape[0]) + + # Scatter the partial solution back to the full shape + reconstruction = reconstruction.at[ids_to_keep].set( + reconstruction_partial + ) + + return reconstruction + else: - solutions = inversion_util.reconstruction_positive_only_from( + + return inversion_util.reconstruction_positive_only_from( data_vector=self.data_vector, curvature_reg_matrix=self.curvature_reg_matrix, - settings=self.settings, ) - return solutions - - mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper) - return inversion_util.reconstruction_positive_negative_from( data_vector=self.data_vector, curvature_reg_matrix=self.curvature_reg_matrix, - mapper_param_range_list=mapper_param_range_list, ) @cached_property - @profile_func def reconstruction_reduced(self) -> np.ndarray: """ Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12) @@ -531,7 +469,11 @@ def reconstruction_reduced(self) -> np.ndarray: if self.all_linear_obj_have_regularization: return self.reconstruction - return np.delete(self.reconstruction, self.no_regularization_index_list, axis=0) + # ids of values which are on edge so zero-d and not solved for. + ids_to_keep = self.mapper_indices + + # Zero rows and columns in the matrix we want to ignore + return self.reconstruction[ids_to_keep] @property def reconstruction_dict(self) -> Dict[LinearObj, np.ndarray]: @@ -576,7 +518,6 @@ def source_quantity_dict_from( return source_quantity_dict @property - @profile_func def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: raise NotImplementedError @@ -597,7 +538,6 @@ def mapped_reconstructed_image_dict(self) -> Dict[LinearObj, Array2D]: return self.mapped_reconstructed_data_dict @cached_property - @profile_func def mapped_reconstructed_data(self) -> Union[Array2D, Visibilities]: """ Using the reconstructed source pixel fluxes we map each source pixel flux back to the image plane and @@ -658,7 +598,6 @@ def data_subtracted_dict(self) -> Dict[LinearObj, Array2D]: return data_subtracted_dict @cached_property - @profile_func def regularization_term(self) -> float: """ Returns the regularization term of an inversion. This term represents the sum of the difference in flux @@ -677,13 +616,12 @@ def regularization_term(self) -> float: if not self.has(cls=AbstractRegularization): return 0.0 - return np.matmul( + return jnp.matmul( self.reconstruction_reduced.T, - np.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced), + jnp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced), ) @cached_property - @profile_func def log_det_curvature_reg_matrix_term(self) -> float: """ The log determinant of [F + reg_coeff*H] is used to determine the Bayesian evidence of the solution. @@ -695,13 +633,14 @@ def log_det_curvature_reg_matrix_term(self) -> float: try: return 2.0 * np.sum( - np.log(np.diag(np.linalg.cholesky(self.curvature_reg_matrix_reduced))) + jnp.log( + jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced)) + ) ) except np.linalg.LinAlgError as e: raise exc.InversionException() from e @cached_property - @profile_func def log_det_regularization_matrix_term(self) -> float: """ The Bayesian evidence of an inversion which quantifies its overall goodness-of-fit uses the log determinant @@ -715,28 +654,17 @@ def log_det_regularization_matrix_term(self) -> float: float The log determinant of the regularization matrix. """ - if not self.has(cls=AbstractRegularization): return 0.0 try: - lu = splu(csc_matrix(self.regularization_matrix_reduced)) - diagL = lu.L.diagonal() - diagU = lu.U.diagonal() - diagL = diagL.astype(np.complex128) - diagU = diagU.astype(np.complex128) - - return np.real(np.log(diagL).sum() + np.log(diagU).sum()) - - except RuntimeError: - try: - return 2.0 * np.sum( - np.log( - np.diag(np.linalg.cholesky(self.regularization_matrix_reduced)) - ) + return 2.0 * np.sum( + jnp.log( + jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced)) ) - except np.linalg.LinAlgError as e: - raise exc.InversionException() from e + ) + except np.linalg.LinAlgError as e: + raise exc.InversionException() from e @property def reconstruction_noise_map_with_covariance(self) -> np.ndarray: @@ -812,12 +740,10 @@ def regularization_weights_mapper_dict(self) -> Dict[LinearObj, np.ndarray]: return regularization_weights_dict @property - @profile_func def _data_vector_mapper(self) -> np.ndarray: raise NotImplementedError @property - @profile_func def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: raise NotImplementedError diff --git a/autoarray/inversion/inversion/dataset_interface.py b/autoarray/inversion/inversion/dataset_interface.py index 0e417e71a..cf5960bf8 100644 --- a/autoarray/inversion/inversion/dataset_interface.py +++ b/autoarray/inversion/inversion/dataset_interface.py @@ -35,6 +35,9 @@ def __init__( noise_map An array describing the RMS standard deviation error in each pixel used for computing quantities like the chi-squared in a fit (in PyAutoGalaxy and PyAutoLens the recommended units are electrons per second). + grids + The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting + light profiles and calculations associated with a pixelization. over_sampler Performs over-sampling whereby the masked image pixels are split into sub-pixels, which are all mapped via the mapper with sub-fractional values of flux. @@ -49,9 +52,6 @@ def __init__( w_tilde The w_tilde matrix used by the w-tilde formalism to construct the data vector and curvature matrix during an inversion efficiently.. - grids - The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting - light profiles and calculations associated with a pixelization. noise_covariance_matrix A noise-map covariance matrix representing the covariance between noise in every `data` value, which can be used via a bespoke fit to account for correlated noise in the data. diff --git a/autoarray/inversion/inversion/factory.py b/autoarray/inversion/inversion/factory.py index 327262786..b7c9016b1 100644 --- a/autoarray/inversion/inversion/factory.py +++ b/autoarray/inversion/inversion/factory.py @@ -9,14 +9,12 @@ from autoarray.inversion.inversion.interferometer.w_tilde import ( InversionInterferometerWTilde, ) -from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, -) from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D @@ -24,7 +22,7 @@ def inversion_from( dataset: Union[Imaging, Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, + preloads: Preloads = None, ): """ Factory which given an input dataset and list of linear objects, creates an `Inversion`. @@ -49,8 +47,6 @@ def inversion_from( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. Returns ------- @@ -61,14 +57,13 @@ def inversion_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, + preloads=preloads, ) return inversion_interferometer_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) @@ -76,7 +71,7 @@ def inversion_imaging_from( dataset, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, + preloads: Preloads = None, ): """ Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`. @@ -105,8 +100,6 @@ def inversion_imaging_from( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. Returns ------- @@ -131,14 +124,13 @@ def inversion_imaging_from( w_tilde=w_tilde, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) return InversionImagingMapping( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, + preloads=preloads, ) @@ -146,7 +138,6 @@ def inversion_interferometer_from( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, ): """ Factory which given an input `Interferometer` dataset and list of linear objects, creates @@ -178,8 +169,6 @@ def inversion_interferometer_from( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. Returns ------- @@ -202,7 +191,6 @@ def inversion_interferometer_from( w_tilde=w_tilde, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) else: @@ -210,13 +198,4 @@ def inversion_interferometer_from( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) - - else: - return InversionInterferometerMappingPyLops( - dataset=dataset, - linear_obj_list=linear_obj_list, - settings=settings, - run_time_dict=run_time_dict, - ) diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 5efc4d0a9..9167af6f9 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -1,10 +1,8 @@ import numpy as np -from typing import Dict, List, Optional, Union, Type +from typing import Dict, List, Union, Type from autoconf import cached_property -from autoarray.numba_util import profile_func - from autoarray.dataset.imaging.dataset import Imaging from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList @@ -12,6 +10,7 @@ from autoarray.inversion.inversion.abstract import AbstractInversion from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.inversion.inversion.imaging import inversion_imaging_util @@ -22,7 +21,7 @@ def __init__( dataset: Union[Imaging, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, + preloads: Preloads = None, ): """ An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions @@ -63,15 +62,13 @@ def __init__( input dataset's data and whose values are solved for via the inversion. settings Settings controlling how an inversion is fitted for example which linear algebra formalism is used. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, + preloads=preloads, ) @property @@ -96,7 +93,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: return [ ( self.psf.convolve_mapping_matrix( - mapping_matrix=linear_obj.mapping_matrix + mapping_matrix=linear_obj.mapping_matrix, mask=self.mask ) if linear_obj.operated_mapping_matrix_override is None else self.linear_func_operated_mapping_matrix_dict[linear_obj] @@ -116,7 +113,6 @@ def _updated_cls_key_dict_from(self, cls: Type, preload_dict: Dict) -> Dict: return cls_dict @cached_property - @profile_func def linear_func_operated_mapping_matrix_dict(self) -> Dict: """ The `operated_mapping_matrix` of a linear object describes the mappings between the observed data's values and @@ -139,7 +135,8 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: operated_mapping_matrix = linear_func.operated_mapping_matrix_override else: operated_mapping_matrix = self.psf.convolve_mapping_matrix( - mapping_matrix=linear_func.mapping_matrix + mapping_matrix=linear_func.mapping_matrix, + mask=self.mask, ) linear_func_operated_mapping_matrix_dict[linear_func] = ( @@ -191,9 +188,8 @@ def data_linear_func_matrix_dict(self): data_linear_func_matrix = ( inversion_imaging_util.data_linear_func_matrix_from( curvature_weights_matrix=curvature_weights, - image_frame_1d_lengths=self.convolver.image_frame_1d_lengths, - image_frame_1d_indexes=self.convolver.image_frame_1d_indexes, - image_frame_1d_kernels=self.convolver.image_frame_1d_kernels, + kernel_native=self.psf.native, + mask=self.mask, ) ) @@ -202,7 +198,6 @@ def data_linear_func_matrix_dict(self): return data_linear_func_matrix_dict @cached_property - @profile_func def mapper_operated_mapping_matrix_dict(self) -> Dict: """ The `operated_mapping_matrix` of a `Mapper` object describes the mappings between the observed data's values @@ -221,7 +216,8 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: for mapper in self.cls_list_from(cls=AbstractMapper): operated_mapping_matrix = self.psf.convolve_mapping_matrix( - mapping_matrix=mapper.mapping_matrix + mapping_matrix=mapper.mapping_matrix, + mask=self.mask, ) mapper_operated_mapping_matrix_dict[mapper] = operated_mapping_matrix diff --git a/autoarray/inversion/inversion/imaging/inversion_imaging_numba_util.py b/autoarray/inversion/inversion/imaging/inversion_imaging_numba_util.py new file mode 100644 index 000000000..0eb95d26d --- /dev/null +++ b/autoarray/inversion/inversion/imaging/inversion_imaging_numba_util.py @@ -0,0 +1,859 @@ +from typing import List, Optional, Tuple + +from autoarray import numba_util + +import numpy as np + + +@numba_util.jit() +def w_tilde_data_imaging_from( + image_native: np.ndarray, + noise_map_native: np.ndarray, + kernel_native: np.ndarray, + native_index_for_slim_index, +) -> np.ndarray: + """ + The matrix w_tilde is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF convolution of + every pair of image pixels given the noise map. This can be used to efficiently compute the curvature matrix via + the mappings between image and source pixels, in a way that omits having to perform the PSF convolution on every + individual source pixel. This provides a significant speed up for inversions of imaging datasets. + + When w_tilde is used to perform an inversion, the mapping matrices are not computed, meaning that they cannot be + used to compute the data vector. This method creates the vector `w_tilde_data` which allows for the data + vector to be computed efficiently without the mapping matrix. + + The matrix w_tilde_data is dimensions [image_pixels] and encodes the PSF convolution with the `weight_map`, + where the weights are the image-pixel values divided by the noise-map values squared: + + weight = image / noise**2.0 + + Parameters + ---------- + image_native + The two dimensional masked image of values which `w_tilde_data` is computed from. + noise_map_native + The two dimensional masked noise-map of values which `w_tilde_data` is computed from. + kernel_native + The two dimensional PSF kernel that `w_tilde_data` encodes the convolution of. + native_index_for_slim_index + An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. + + Returns + ------- + ndarray + A matrix that encodes the PSF convolution values between the imaging divided by the noise map**2 that enables + efficient calculation of the data vector. + """ + + kernel_shift_y = -(kernel_native.shape[1] // 2) + kernel_shift_x = -(kernel_native.shape[0] // 2) + + image_pixels = len(native_index_for_slim_index) + + w_tilde_data = np.zeros((image_pixels,)) + + weight_map_native = image_native / noise_map_native**2.0 + + for ip0 in range(image_pixels): + ip0_y, ip0_x = native_index_for_slim_index[ip0] + + value = 0.0 + + for k0_y in range(kernel_native.shape[0]): + for k0_x in range(kernel_native.shape[1]): + weight_value = weight_map_native[ + ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x + ] + + if not np.isnan(weight_value): + value += kernel_native[k0_y, k0_x] * weight_value + + w_tilde_data[ip0] = value + + return w_tilde_data + + +@numba_util.jit() +def w_tilde_curvature_imaging_from( + noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index +) -> np.ndarray: + """ + The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF + convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the + curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the + PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging + datasets. + + The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, + making it impossible to store in memory and its use in linear algebra calculations extremely. The method + `w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is + advised `w_tilde` and this method are only used for testing. + + Parameters + ---------- + noise_map_native + The two dimensional masked noise-map of values which w_tilde is computed from. + kernel_native + The two dimensional PSF kernel that w_tilde encodes the convolution of. + native_index_for_slim_index + An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. + + Returns + ------- + ndarray + A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of + the curvature matrix. + """ + image_pixels = len(native_index_for_slim_index) + + w_tilde_curvature = np.zeros((image_pixels, image_pixels)) + + for ip0 in range(w_tilde_curvature.shape[0]): + ip0_y, ip0_x = native_index_for_slim_index[ip0] + + for ip1 in range(ip0, w_tilde_curvature.shape[1]): + ip1_y, ip1_x = native_index_for_slim_index[ip1] + + w_tilde_curvature[ip0, ip1] += w_tilde_curvature_value_from( + value_native=noise_map_native, + kernel_native=kernel_native, + ip0_y=ip0_y, + ip0_x=ip0_x, + ip1_y=ip1_y, + ip1_x=ip1_x, + ) + + for ip0 in range(w_tilde_curvature.shape[0]): + for ip1 in range(ip0, w_tilde_curvature.shape[1]): + w_tilde_curvature[ip1, ip0] = w_tilde_curvature[ip0, ip1] + + return w_tilde_curvature + + +@numba_util.jit() +def w_tilde_curvature_preload_imaging_from( + noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF + convolution of every pair of image pixels on the noise map. This can be used to efficiently compute the + curvature matrix via the mappings between image and source pixels, in a way that omits having to repeat the PSF + convolution on every individual source pixel. This provides a significant speed up for inversions of imaging + datasets. + + The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, + making it impossible to store in memory and its use in linear algebra calculations slow. This methods creates + a sparse matrix that can compute the matrix `w_tilde_curvature` efficiently, albeit the linear algebra calculations + in PyAutoArray bypass this matrix entirely to go straight to the curvature matrix. + + for dataset data, w_tilde is a sparse matrix, whereby non-zero entries are only contained for pairs of image pixels + where the two pixels overlap due to the kernel size. For example, if the kernel size is (11, 11) and two image + pixels are separated by more than 20 pixels, the kernel will never convolve flux between the two pixels. Two image + pixels will only share a convolution if they are within `kernel_overlap_size = 2 * kernel_shape - 1` pixels within + one another. + + Thus, a `w_tilde_curvature_preload` matrix of dimensions [image_pixels, kernel_overlap_size ** 2] can be computed + which significantly reduces the memory consumption by removing the sparsity. Because the dimensions of the second + axes is no longer `image_pixels`, a second matrix `w_tilde_indexes` must also be computed containing the slim image + pixel indexes of every entry of `w_tilde_preload`. + + In order for the preload to store half the number of values, owing to the symmetry of the `w_tilde_curvature` + matrix, the image pixel pairs corresponding to the same image pixel are divided by two. This ensures that when the + curvature matrix is computed these pixels are not double-counted. + + The values stored in `w_tilde_curvature_preload` represent the convolution of overlapping noise-maps given the + PSF kernel. It is common for many values to be neglibly small. Removing these values can speed up the inversion + and reduce memory at the expense of a numerically irrelevent change of solution. + + This matrix can then be used to compute the `curvature_matrix` in a memory efficient way that exploits the sparsity + of the linear algebra. + + Parameters + ---------- + noise_map_native + The two dimensional masked noise-map of values which `w_tilde_curvature` is computed from. + signal_to_noise_map_native + The two dimensional masked signal-to-noise-map from which the threshold discarding low S/N image pixel + pairs is used. + kernel_native + The two dimensional PSF kernel that `w_tilde_curvature` encodes the convolution of. + native_index_for_slim_index + An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. + + Returns + ------- + ndarray + A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of + the curvature matrix, where the dimensions are reduced to save memory. + """ + + image_pixels = len(native_index_for_slim_index) + + kernel_overlap_size = (2 * kernel_native.shape[0] - 1) * ( + 2 * kernel_native.shape[1] - 1 + ) + + curvature_preload_tmp = np.zeros((image_pixels, kernel_overlap_size)) + curvature_indexes_tmp = np.zeros((image_pixels, kernel_overlap_size)) + curvature_lengths = np.zeros(image_pixels) + + for ip0 in range(image_pixels): + ip0_y, ip0_x = native_index_for_slim_index[ip0] + + kernel_index = 0 + + for ip1 in range(ip0, curvature_preload_tmp.shape[0]): + ip1_y, ip1_x = native_index_for_slim_index[ip1] + + noise_value = w_tilde_curvature_value_from( + value_native=noise_map_native, + kernel_native=kernel_native, + ip0_y=ip0_y, + ip0_x=ip0_x, + ip1_y=ip1_y, + ip1_x=ip1_x, + ) + + if ip0 == ip1: + noise_value /= 2.0 + + if noise_value > 0.0: + curvature_preload_tmp[ip0, kernel_index] = noise_value + curvature_indexes_tmp[ip0, kernel_index] = ip1 + kernel_index += 1 + + curvature_lengths[ip0] = kernel_index + + curvature_total_pairs = int(np.sum(curvature_lengths)) + + curvature_preload = np.zeros((curvature_total_pairs)) + curvature_indexes = np.zeros((curvature_total_pairs)) + + index = 0 + + for i in range(image_pixels): + for data_index in range(int(curvature_lengths[i])): + curvature_preload[index] = curvature_preload_tmp[i, data_index] + curvature_indexes[index] = curvature_indexes_tmp[i, data_index] + + index += 1 + + return (curvature_preload, curvature_indexes, curvature_lengths) + + +@numba_util.jit() +def w_tilde_curvature_value_from( + value_native: np.ndarray, + kernel_native: np.ndarray, + ip0_y, + ip0_x, + ip1_y, + ip1_x, + renormalize=False, +) -> float: + """ + Compute the value of an entry of the `w_tilde_curvature` matrix, where this entry encodes the PSF convolution of + the noise-map between two image pixels. + + The calculation is performed by over-laying the PSF kernel over two noise-map pixels in 2D. For all pixels where + the two overlaid PSF kernels overlap, the following calculation is performed for every noise map value: + + `value = kernel_value_0 * kernel_value_1 * (1.0 / noise_value) ** 2.0` + + This calculation infers the fraction of flux that every PSF convolution will move between each pair of noise-map + pixels and can therefore be used to efficiently calculate the curvature_matrix that is used in the linear algebra + calculation of an inversion. + + The sum of all values where kernel pixels overlap is returned to give the `w_tilde` value. + + Parameters + ---------- + value_native + A two dimensional masked array of values (e.g. a noise-map, signal to noise map) which the w_tilde curvature + values are computed from. + kernel_native + The two dimensional PSF kernel that w_tilde encodes the convolution of. + ip0_y + The y index of the first image pixel in the image pixel pair. + ip0_x + The x index of the first image pixel in the image pixel pair. + ip1_y + The y index of the second image pixel in the image pixel pair. + ip1_x + The x index of the second image pixel in the image pixel pair. + + Returns + ------- + float + The w_tilde value that encodes the value of PSF convolution between a pair of image pixels. + + """ + + curvature_value = 0.0 + + kernel_shift_y = -(kernel_native.shape[1] // 2) + kernel_shift_x = -(kernel_native.shape[0] // 2) + + ip_y_offset = ip0_y - ip1_y + ip_x_offset = ip0_x - ip1_x + + if ( + ip_y_offset < 2 * kernel_shift_y + or ip_y_offset > -2 * kernel_shift_y + or ip_x_offset < 2 * kernel_shift_x + or ip_x_offset > -2 * kernel_shift_x + ): + return curvature_value + + kernel_pixels = kernel_native.shape[0] * kernel_native.shape[1] + kernel_count = 0 + + for k0_y in range(kernel_native.shape[0]): + for k0_x in range(kernel_native.shape[1]): + value = value_native[ + ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x + ] + + if value > 0.0: + k1_y = k0_y + ip_y_offset + k1_x = k0_x + ip_x_offset + + if ( + k1_y >= 0 + and k1_x >= 0 + and k1_y < kernel_native.shape[0] + and k1_x < kernel_native.shape[1] + ): + kernel_count += 1 + + kernel_value_0 = kernel_native[k0_y, k0_x] + kernel_value_1 = kernel_native[k1_y, k1_x] + + curvature_value += ( + kernel_value_0 * kernel_value_1 * (1.0 / value) ** 2.0 + ) + + if renormalize: + if kernel_count > 0: + curvature_value *= kernel_pixels / kernel_count + + return curvature_value + + +@numba_util.jit() +def data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix: np.ndarray, image: np.ndarray, noise_map: np.ndarray +) -> np.ndarray: + """ + Returns the data vector `D` from a blurred mapping matrix `f` and the 1D image `d` and 1D noise-map $\sigma$` + (see Warren & Dye 2003). + + Parameters + ---------- + blurred_mapping_matrix + The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. + image + Flattened 1D array of the observed image the inversion is fitting. + noise_map + Flattened 1D array of the noise-map used by the inversion during the fit. + """ + + data_shape = blurred_mapping_matrix.shape + + data_vector = np.zeros(data_shape[1]) + + for data_index in range(data_shape[0]): + for pix_index in range(data_shape[1]): + data_vector[pix_index] += ( + image[data_index] + * blurred_mapping_matrix[data_index, pix_index] + / (noise_map[data_index] ** 2.0) + ) + + return data_vector + + +@numba_util.jit() +def data_vector_via_w_tilde_data_imaging_from( + w_tilde_data: np.ndarray, + data_to_pix_unique: np.ndarray, + data_weights: np.ndarray, + pix_lengths: np.ndarray, + pix_pixels: int, +) -> np.ndarray: + """ + Returns the data vector `D` from the `w_tilde_data` matrix (see `w_tilde_data_imaging_from`), which encodes the + the 1D image `d` and 1D noise-map values `\sigma` (see Warren & Dye 2003). + + This uses the array `data_to_pix_unique`, which describes the unique mappings of every set of image sub-pixels to + pixelization pixels and `data_weights`, which describes how many sub-pixels uniquely map to each pixelization + pixels (see `data_slim_to_pixelization_unique_from`). + + Parameters + ---------- + w_tilde_data + A matrix that encodes the PSF convolution values between the imaging divided by the noise map**2 that enables + efficient calculation of the data vector. + data_to_pix_unique + An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of + pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). + data_weights + For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping + based on the number of sub-pixels that map to pixelization pixel. + pix_lengths + A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over + `data_to_pix_unique` and `data_weights`. + pix_pixels + The total number of pixels in the pixelization that reconstructs the data. + """ + + data_pixels = w_tilde_data.shape[0] + + data_vector = np.zeros(pix_pixels) + + for data_0 in range(data_pixels): + for pix_0_index in range(pix_lengths[data_0]): + data_0_weight = data_weights[data_0, pix_0_index] + pix_0 = data_to_pix_unique[data_0, pix_0_index] + + data_vector[pix_0] += data_0_weight * w_tilde_data[data_0] + + return data_vector + + +@numba_util.jit() +def curvature_matrix_with_added_to_diag_from( + curvature_matrix: np.ndarray, + value: float, + no_regularization_index_list: Optional[List] = None, +) -> np.ndarray: + """ + It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion + via `np.linalg.solve` to fail and raise a `LinAlgError`. + + In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix` + makes it positive definite, such that the inversion is performed without raising an error. + + This function adds this numerical value to the diagonal of the curvature matrix. + + Parameters + ---------- + curvature_matrix + The curvature matrix which is being constructed in order to solve a linear system of equations. + """ + + for i in no_regularization_index_list: + curvature_matrix[i, i] += value + + return curvature_matrix + + +@numba_util.jit() +def curvature_matrix_mirrored_from( + curvature_matrix: np.ndarray, +) -> np.ndarray: + curvature_matrix_mirrored = np.zeros( + (curvature_matrix.shape[0], curvature_matrix.shape[1]) + ) + + for i in range(curvature_matrix.shape[0]): + for j in range(curvature_matrix.shape[1]): + if curvature_matrix[i, j] != 0: + curvature_matrix_mirrored[i, j] = curvature_matrix[i, j] + curvature_matrix_mirrored[j, i] = curvature_matrix[i, j] + if curvature_matrix[j, i] != 0: + curvature_matrix_mirrored[i, j] = curvature_matrix[j, i] + curvature_matrix_mirrored[j, i] = curvature_matrix[j, i] + + return curvature_matrix_mirrored + + +@numba_util.jit() +def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( + curvature_preload: np.ndarray, + curvature_indexes: np.ndarray, + curvature_lengths: np.ndarray, + data_to_pix_unique: np.ndarray, + data_weights: np.ndarray, + pix_lengths: np.ndarray, + pix_pixels: int, +) -> np.ndarray: + """ + Returns the curvature matrix `F` (see Warren & Dye 2003) by computing it using `w_tilde_preload` + (see `w_tilde_preload_interferometer_from`) for an imaging inversion. + + To compute the curvature matrix via w_tilde the following matrix multiplication is normally performed: + + curvature_matrix = mapping_matrix.T * w_tilde * mapping matrix + + This function speeds this calculation up in two ways: + + 1) Instead of using `w_tilde` (dimensions [image_pixels, image_pixels] it uses `w_tilde_preload` (dimensions + [image_pixels, kernel_overlap]). The massive reduction in the size of this matrix in memory allows for much fast + computation. + + 2) It omits the `mapping_matrix` and instead uses directly the 1D vector that maps every image pixel to a source + pixel `native_index_for_slim_index`. This exploits the sparsity in the `mapping_matrix` to directly + compute the `curvature_matrix` (e.g. it condenses the triple matrix multiplication into a double for loop!). + + Parameters + ---------- + curvature_preload + A matrix that precomputes the values for fast computation of the curvature matrix in a memory efficient way. + curvature_indexes + The image-pixel indexes of the values stored in the w tilde preload matrix, which are used to compute + the weights of the data values when computing the curvature matrix. + curvature_lengths + The number of image pixels in every row of `w_tilde_curvature`, which is iterated over when computing the + curvature matrix. + data_to_pix_unique + An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of + pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). + data_weights + For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping + based on the number of sub-pixels that map to pixelization pixel. + pix_lengths + A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over + `data_to_pix_unique` and `data_weights`. + pix_pixels + The total number of pixels in the pixelization that reconstructs the data. + + Returns + ------- + ndarray + The curvature matrix `F` (see Warren & Dye 2003). + """ + + data_pixels = curvature_lengths.shape[0] + + curvature_matrix = np.zeros((pix_pixels, pix_pixels)) + + curvature_index = 0 + + for data_0 in range(data_pixels): + for data_1_index in range(curvature_lengths[data_0]): + data_1 = curvature_indexes[curvature_index] + w_tilde_value = curvature_preload[curvature_index] + + for pix_0_index in range(pix_lengths[data_0]): + data_0_weight = data_weights[data_0, pix_0_index] + pix_0 = data_to_pix_unique[data_0, pix_0_index] + + for pix_1_index in range(pix_lengths[data_1]): + data_1_weight = data_weights[data_1, pix_1_index] + pix_1 = data_to_pix_unique[data_1, pix_1_index] + + curvature_matrix[pix_0, pix_1] += ( + data_0_weight * data_1_weight * w_tilde_value + ) + + curvature_index += 1 + + for i in range(pix_pixels): + for j in range(i, pix_pixels): + curvature_matrix[i, j] += curvature_matrix[j, i] + + for i in range(pix_pixels): + for j in range(i, pix_pixels): + curvature_matrix[j, i] = curvature_matrix[i, j] + + return curvature_matrix + + +@numba_util.jit() +def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( + curvature_preload: np.ndarray, + curvature_indexes: np.ndarray, + curvature_lengths: np.ndarray, + data_to_pix_unique_0: np.ndarray, + data_weights_0: np.ndarray, + pix_lengths_0: np.ndarray, + pix_pixels_0: int, + data_to_pix_unique_1: np.ndarray, + data_weights_1: np.ndarray, + pix_lengths_1: np.ndarray, + pix_pixels_1: int, +) -> np.ndarray: + """ + Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) by computing them + using `w_tilde_preload` (see `w_tilde_preload_interferometer_from`) for an imaging inversion. + + When there is more than one mapper in the inversion, its `mapping_matrix` is extended to have dimensions + [data_pixels, sum(source_pixels_in_each_mapper)]. The curvature matrix therefore will have dimensions + [sum(source_pixels_in_each_mapper), sum(source_pixels_in_each_mapper)]. + + To compute the curvature matrix via w_tilde the following matrix multiplication is normally performed: + + curvature_matrix = mapping_matrix.T * w_tilde * mapping matrix + + When the `mapping_matrix` consists of multiple mappers from different planes, this means that shared data mappings + between source-pixels in different mappers must be accounted for when computing the `curvature_matrix`. These + appear as off-diagonal terms in the overall curvature matrix. + + This function evaluates these off-diagonal terms, by using the w-tilde curvature preloads and the unique + data-to-pixelization mappings of each mapper. It behaves analogous to the + function `curvature_matrix_via_w_tilde_curvature_preload_imaging_from`. + + Parameters + ---------- + curvature_preload + A matrix that precomputes the values for fast computation of the curvature matrix in a memory efficient way. + curvature_indexes + The image-pixel indexes of the values stored in the w tilde preload matrix, which are used to compute + the weights of the data values when computing the curvature matrix. + curvature_lengths + The number of image pixels in every row of `w_tilde_curvature`, which is iterated over when computing the + curvature matrix. + data_to_pix_unique + An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of + pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). + data_weights + For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping + based on the number of sub-pixels that map to pixelization pixel. + pix_lengths + A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over + `data_to_pix_unique` and `data_weights`. + pix_pixels + The total number of pixels in the pixelization that reconstructs the data. + + Returns + ------- + ndarray + The curvature matrix `F` (see Warren & Dye 2003). + """ + + data_pixels = curvature_lengths.shape[0] + + curvature_matrix = np.zeros((pix_pixels_0, pix_pixels_1)) + + curvature_index = 0 + + for data_0 in range(data_pixels): + for data_1_index in range(curvature_lengths[data_0]): + data_1 = curvature_indexes[curvature_index] + w_tilde_value = curvature_preload[curvature_index] + + for pix_0_index in range(pix_lengths_0[data_0]): + data_0_weight = data_weights_0[data_0, pix_0_index] + pix_0 = data_to_pix_unique_0[data_0, pix_0_index] + + for pix_1_index in range(pix_lengths_1[data_1]): + data_1_weight = data_weights_1[data_1, pix_1_index] + pix_1 = data_to_pix_unique_1[data_1, pix_1_index] + + curvature_matrix[pix_0, pix_1] += ( + data_0_weight * data_1_weight * w_tilde_value + ) + + curvature_index += 1 + + return curvature_matrix + + +@numba_util.jit() +def curvature_matrix_off_diags_via_data_linear_func_matrix_from( + data_linear_func_matrix: np.ndarray, + data_to_pix_unique: np.ndarray, + data_weights: np.ndarray, + pix_lengths: np.ndarray, + pix_pixels: int, +): + """ + Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) between a mapper object + and a linear func object, using the preloaded `data_linear_func_matrix` of the values of the linear functions. + + + If a linear function in an inversion is fixed, its values can be evaluated and preloaded beforehand. For every + data pixel, the PSF convolution with this preloaded linear function can also be preloaded, in a matrix of + shape [data_pixels, 1]. + + When mapper objects and linear functions are used simultaneously in an inversion, this preloaded matrix + significantly speed up the computation of their off-diagonal terms in the curvature matrix. + + This function performs this efficient calcluation via the preloaded `data_linear_func_matrix`. + + Parameters + ---------- + data_linear_func_matrix + A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum of + the values of a linear object function convolved with the PSF kernel at the data pixel. + data_to_pix_unique + The indexes of all pixels that each data pixel maps to (see the `Mapper` object). + data_weights + The weights of all pixels that each data pixel maps to (see the `Mapper` object). + pix_lengths + The number of pixelization pixels that each data pixel maps to (see the `Mapper` object). + pix_pixels + The number of pixelization pixels in the pixelization (see the `Mapper` object). + """ + + linear_func_pixels = data_linear_func_matrix.shape[1] + + off_diag = np.zeros((pix_pixels, linear_func_pixels)) + + data_pixels = data_weights.shape[0] + + for data_0 in range(data_pixels): + for pix_0_index in range(pix_lengths[data_0]): + data_0_weight = data_weights[data_0, pix_0_index] + pix_0 = data_to_pix_unique[data_0, pix_0_index] + + for linear_index in range(linear_func_pixels): + off_diag[pix_0, linear_index] += ( + data_linear_func_matrix[data_0, linear_index] * data_0_weight + ) + + return off_diag + + +@numba_util.jit() +def convolve_with_kernel_native(curvature_native, psf_kernel): + """ + Convolve each function slice of curvature_native with psf_kernel using direct sliding window. + + Parameters + ---------- + curvature_native : ndarray (ny, nx, n_funcs) + Curvature weights expanded to the native grid, 0 in masked regions. + psf_kernel : ndarray (ky, kx) + The PSF kernel. + + Returns + ------- + blurred_native : ndarray (ny, nx, n_funcs) + The curvature weights convolved with the PSF. + """ + ny, nx, n_funcs = curvature_native.shape + ky, kx = psf_kernel.shape + cy, cx = ky // 2, kx // 2 # kernel center + + blurred_native = np.zeros_like(curvature_native) + + for f in range(n_funcs): # parallelize over functions + for y in range(ny): + for x in range(nx): + acc = 0.0 + for dy in range(ky): + for dx in range(kx): + yy = y + dy - cy + xx = x + dx - cx + if 0 <= yy < ny and 0 <= xx < nx: + acc += psf_kernel[dy, dx] * curvature_native[yy, xx, f] + blurred_native[y, x, f] = acc + return blurred_native + + +@numba_util.jit() +def curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( + data_to_pix_unique: np.ndarray, + data_weights: np.ndarray, + pix_lengths: np.ndarray, + pix_pixels: int, + curvature_weights: np.ndarray, # shape (n_unmasked, n_funcs) + mask: np.ndarray, # shape (ny, nx), bool + psf_kernel: np.ndarray, # shape (ky, kx) +) -> np.ndarray: + """ + Returns the off-diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) + between a mapper object and a linear func object, using the unique mappings between + data pixels and pixelization pixels. + + This version applies the PSF directly as a 2D convolution kernel. The curvature + weights of the linear function object (values of the linear function divided by the + noise-map squared) are expanded into the native 2D image grid, convolved with the PSF + kernel, and then remapped back to the 1D slim representation. + + For each unique mapping between a data pixel and a pixelization pixel, the convolved + curvature weights at that data pixel are multiplied by the mapping weights and + accumulated into the off-diagonal block of the curvature matrix. This accounts for + sub-pixel mappings between data pixels and pixelization pixels. + + Parameters + ---------- + data_to_pix_unique + An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) + to its unique set of pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). + data_weights + For every unique mapping between a set of data sub-pixels and a pixelization pixel, + the weight of this mapping based on the number of sub-pixels that map to the pixelization pixel. + pix_lengths + A 1D array describing how many unique pixels each data pixel maps to. Used to iterate over + `data_to_pix_unique` and `data_weights`. + pix_pixels + The total number of pixels in the pixelization that reconstructs the data. + curvature_weights + The operated values of the linear function divided by the noise-map squared, with shape + [n_unmasked_data_pixels, n_linear_func_pixels]. + mask + A 2D boolean mask of shape (ny, nx) indicating which pixels are in the data region. + psf_kernel + The PSF kernel in its native 2D form, centered (odd dimensions recommended). + + Returns + ------- + ndarray + The off-diagonal block of the curvature matrix `F` (see Warren & Dye 2003), + with shape [pix_pixels, n_linear_func_pixels]. + """ + data_pixels = data_weights.shape[0] + n_funcs = curvature_weights.shape[1] + ny, nx = mask.shape + + # Expand curvature weights into native grid + curvature_native = np.zeros((ny, nx, n_funcs)) + unmasked_coords = np.argwhere(~mask) + for i, (y, x) in enumerate(unmasked_coords): + for f in range(n_funcs): + curvature_native[y, x, f] = curvature_weights[i, f] + + # Convolve in native space + blurred_native = convolve_with_kernel_native(curvature_native, psf_kernel) + + # Map back to slim representation + blurred_slim = np.zeros((data_pixels, n_funcs)) + for i, (y, x) in enumerate(unmasked_coords): + for f in range(n_funcs): + blurred_slim[i, f] = blurred_native[y, x, f] + + # Accumulate into off_diag + off_diag = np.zeros((pix_pixels, n_funcs)) + for data_0 in range(data_pixels): + for pix_0_index in range(pix_lengths[data_0]): + data_0_weight = data_weights[data_0, pix_0_index] + pix_0 = data_to_pix_unique[data_0, pix_0_index] + for f in range(n_funcs): + off_diag[pix_0, f] += data_0_weight * blurred_slim[data_0, f] + + return off_diag + + +@numba_util.jit() +def mapped_reconstructed_data_via_image_to_pix_unique_from( + data_to_pix_unique: np.ndarray, + data_weights: np.ndarray, + pix_lengths: np.ndarray, + reconstruction: np.ndarray, +) -> np.ndarray: + """ + Returns the reconstructed data vector from the blurred mapping matrix `f` and solution vector *S*. + + Parameters + ---------- + mapping_matrix + The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. + + """ + + data_pixels = data_to_pix_unique.shape[0] + + mapped_reconstructed_data = np.zeros(data_pixels) + + for data_0 in range(data_pixels): + for pix_0 in range(pix_lengths[data_0]): + pix_for_data = data_to_pix_unique[data_0, pix_0] + + mapped_reconstructed_data[data_0] += ( + data_weights[data_0, pix_0] * reconstruction[pix_for_data] + ) + + return mapped_reconstructed_data diff --git a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py index 1be13b005..fe82ce398 100644 --- a/autoarray/inversion/inversion/imaging/inversion_imaging_util.py +++ b/autoarray/inversion/inversion/imaging/inversion_imaging_util.py @@ -1,14 +1,64 @@ +import jax.numpy as jnp + import numpy as np -from scipy.linalg import cho_solve -from typing import List, Optional, Tuple +from scipy.signal import fftconvolve + + +def psf_operator_matrix_dense_from( + kernel_native: np.ndarray, + native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels + native_shape: tuple[int, int], + correlate: bool = True, +) -> np.ndarray: + """ + Construct a dense PSF operator W (N_pix x N_pix) that maps masked image pixels to masked image pixels. + + Parameters + ---------- + kernel_native : (Ky, Kx) PSF kernel. + native_index_for_slim_index : (N_pix, 2) array of int + Native (y, x) coords for each masked pixel. + native_shape : (Ny, Nx) + Native 2D image shape. + correlate : bool, default True + If True, use correlation convention (no kernel flip). + If False, use convolution convention (flip kernel). + + Returns + ------- + W : ndarray, shape (N_pix, N_pix) + Dense PSF operator. + """ + Ky, Kx = kernel_native.shape + ph, pw = Ky // 2, Kx // 2 + Ny, Nx = native_shape + N_pix = native_index_for_slim_index.shape[0] + + ker = kernel_native if correlate else kernel_native[::-1, ::-1] -from autoarray.inversion.inversion.settings import SettingsInversion + # Padded index grid: -1 everywhere, slim index where masked + index_padded = -np.ones((Ny + 2 * ph, Nx + 2 * pw), dtype=np.int64) + for p, (y, x) in enumerate(native_index_for_slim_index): + index_padded[y + ph, x + pw] = p -from autoarray import numba_util -from autoarray import exc + # Neighborhood offsets + dy = np.arange(Ky) - ph + dx = np.arange(Kx) - pw + + W = np.zeros((N_pix, N_pix), dtype=float) + + for i, (y, x) in enumerate(native_index_for_slim_index): + yp = y + ph + xp = x + pw + for j, dy_ in enumerate(dy): + for k, dx_ in enumerate(dx): + neigh = index_padded[yp + dy_, xp + dx_] + if neigh >= 0: + W[i, neigh] += ker[j, k] + + return W -@numba_util.jit() def w_tilde_data_imaging_from( image_native: np.ndarray, noise_map_native: np.ndarray, @@ -48,350 +98,37 @@ def w_tilde_data_imaging_from( efficient calculation of the data vector. """ - kernel_shift_y = -(kernel_native.shape[1] // 2) - kernel_shift_x = -(kernel_native.shape[0] // 2) - - image_pixels = len(native_index_for_slim_index) - - w_tilde_data = np.zeros((image_pixels,)) - - weight_map_native = image_native / noise_map_native**2.0 - - for ip0 in range(image_pixels): - ip0_y, ip0_x = native_index_for_slim_index[ip0] - - value = 0.0 - - for k0_y in range(kernel_native.shape[0]): - for k0_x in range(kernel_native.shape[1]): - weight_value = weight_map_native[ - ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x - ] - - if not np.isnan(weight_value): - value += kernel_native[k0_y, k0_x] * weight_value - - w_tilde_data[ip0] = value - - return w_tilde_data - - -@numba_util.jit() -def w_tilde_curvature_imaging_from( - noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index -) -> np.ndarray: - """ - The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF - convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the - curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the - PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging - datasets. - - The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, - making it impossible to store in memory and its use in linear algebra calculations extremely. The method - `w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is - advised `w_tilde` and this method are only used for testing. - - Parameters - ---------- - noise_map_native - The two dimensional masked noise-map of values which w_tilde is computed from. - kernel_native - The two dimensional PSF kernel that w_tilde encodes the convolution of. - native_index_for_slim_index - An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. - - Returns - ------- - ndarray - A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of - the curvature matrix. - """ - image_pixels = len(native_index_for_slim_index) - - w_tilde_curvature = np.zeros((image_pixels, image_pixels)) - - for ip0 in range(w_tilde_curvature.shape[0]): - ip0_y, ip0_x = native_index_for_slim_index[ip0] - - for ip1 in range(ip0, w_tilde_curvature.shape[1]): - ip1_y, ip1_x = native_index_for_slim_index[ip1] - - w_tilde_curvature[ip0, ip1] += w_tilde_curvature_value_from( - value_native=noise_map_native, - kernel_native=kernel_native, - ip0_y=ip0_y, - ip0_x=ip0_x, - ip1_y=ip1_y, - ip1_x=ip1_x, - ) - - for ip0 in range(w_tilde_curvature.shape[0]): - for ip1 in range(ip0, w_tilde_curvature.shape[1]): - w_tilde_curvature[ip1, ip0] = w_tilde_curvature[ip0, ip1] - - return w_tilde_curvature - - -@numba_util.jit() -def w_tilde_curvature_preload_imaging_from( - noise_map_native: np.ndarray, kernel_native: np.ndarray, native_index_for_slim_index -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF - convolution of every pair of image pixels on the noise map. This can be used to efficiently compute the - curvature matrix via the mappings between image and source pixels, in a way that omits having to repeat the PSF - convolution on every individual source pixel. This provides a significant speed up for inversions of imaging - datasets. - - The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, - making it impossible to store in memory and its use in linear algebra calculations slow. This methods creates - a sparse matrix that can compute the matrix `w_tilde_curvature` efficiently, albeit the linear algebra calculations - in PyAutoArray bypass this matrix entirely to go straight to the curvature matrix. - - for dataset data, w_tilde is a sparse matrix, whereby non-zero entries are only contained for pairs of image pixels - where the two pixels overlap due to the kernel size. For example, if the kernel size is (11, 11) and two image - pixels are separated by more than 20 pixels, the kernel will never convolve flux between the two pixels. Two image - pixels will only share a convolution if they are within `kernel_overlap_size = 2 * kernel_shape - 1` pixels within - one another. - - Thus, a `w_tilde_curvature_preload` matrix of dimensions [image_pixels, kernel_overlap_size ** 2] can be computed - which significantly reduces the memory consumption by removing the sparsity. Because the dimensions of the second - axes is no longer `image_pixels`, a second matrix `w_tilde_indexes` must also be computed containing the slim image - pixel indexes of every entry of `w_tilde_preload`. - - In order for the preload to store half the number of values, owing to the symmetry of the `w_tilde_curvature` - matrix, the image pixel pairs corresponding to the same image pixel are divided by two. This ensures that when the - curvature matrix is computed these pixels are not double-counted. - - The values stored in `w_tilde_curvature_preload` represent the convolution of overlapping noise-maps given the - PSF kernel. It is common for many values to be neglibly small. Removing these values can speed up the inversion - and reduce memory at the expense of a numerically irrelevent change of solution. - - This matrix can then be used to compute the `curvature_matrix` in a memory efficient way that exploits the sparsity - of the linear algebra. - - Parameters - ---------- - noise_map_native - The two dimensional masked noise-map of values which `w_tilde_curvature` is computed from. - signal_to_noise_map_native - The two dimensional masked signal-to-noise-map from which the threshold discarding low S/N image pixel - pairs is used. - kernel_native - The two dimensional PSF kernel that `w_tilde_curvature` encodes the convolution of. - native_index_for_slim_index - An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array. - - Returns - ------- - ndarray - A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of - the curvature matrix, where the dimensions are reduced to save memory. - """ - - image_pixels = len(native_index_for_slim_index) - - kernel_overlap_size = (2 * kernel_native.shape[0] - 1) * ( - 2 * kernel_native.shape[1] - 1 + # 1) weight map = image / noise^2 (safe where noise==0) + weight_map = jnp.where( + noise_map_native > 0.0, image_native / (noise_map_native**2), 0.0 ) - curvature_preload_tmp = np.zeros((image_pixels, kernel_overlap_size)) - curvature_indexes_tmp = np.zeros((image_pixels, kernel_overlap_size)) - curvature_lengths = np.zeros(image_pixels) - - for ip0 in range(image_pixels): - ip0_y, ip0_x = native_index_for_slim_index[ip0] - - kernel_index = 0 - - for ip1 in range(ip0, curvature_preload_tmp.shape[0]): - ip1_y, ip1_x = native_index_for_slim_index[ip1] - - noise_value = w_tilde_curvature_value_from( - value_native=noise_map_native, - kernel_native=kernel_native, - ip0_y=ip0_y, - ip0_x=ip0_x, - ip1_y=ip1_y, - ip1_x=ip1_x, - ) - - if ip0 == ip1: - noise_value /= 2.0 - - if noise_value > 0.0: - curvature_preload_tmp[ip0, kernel_index] = noise_value - curvature_indexes_tmp[ip0, kernel_index] = ip1 - kernel_index += 1 - - curvature_lengths[ip0] = kernel_index - - curvature_total_pairs = int(np.sum(curvature_lengths)) - - curvature_preload = np.zeros((curvature_total_pairs)) - curvature_indexes = np.zeros((curvature_total_pairs)) - - index = 0 - - for i in range(image_pixels): - for data_index in range(int(curvature_lengths[i])): - curvature_preload[index] = curvature_preload_tmp[i, data_index] - curvature_indexes[index] = curvature_indexes_tmp[i, data_index] - - index += 1 - - return (curvature_preload, curvature_indexes, curvature_lengths) - - -@numba_util.jit() -def w_tilde_curvature_value_from( - value_native: np.ndarray, - kernel_native: np.ndarray, - ip0_y, - ip0_x, - ip1_y, - ip1_x, - renormalize=False, -) -> float: - """ - Compute the value of an entry of the `w_tilde_curvature` matrix, where this entry encodes the PSF convolution of - the noise-map between two image pixels. - - The calculation is performed by over-laying the PSF kernel over two noise-map pixels in 2D. For all pixels where - the two overlaid PSF kernels overlap, the following calculation is performed for every noise map value: - - `value = kernel_value_0 * kernel_value_1 * (1.0 / noise_value) ** 2.0` - - This calculation infers the fraction of flux that every PSF convolution will move between each pair of noise-map - pixels and can therefore be used to efficiently calculate the curvature_matrix that is used in the linear algebra - calculation of an inversion. - - The sum of all values where kernel pixels overlap is returned to give the `w_tilde` value. - - Parameters - ---------- - value_native - A two dimensional masked array of values (e.g. a noise-map, signal to noise map) which the w_tilde curvature - values are computed from. - kernel_native - The two dimensional PSF kernel that w_tilde encodes the convolution of. - ip0_y - The y index of the first image pixel in the image pixel pair. - ip0_x - The x index of the first image pixel in the image pixel pair. - ip1_y - The y index of the second image pixel in the image pixel pair. - ip1_x - The x index of the second image pixel in the image pixel pair. - - Returns - ------- - float - The w_tilde value that encodes the value of PSF convolution between a pair of image pixels. - - """ - - curvature_value = 0.0 - - kernel_shift_y = -(kernel_native.shape[1] // 2) - kernel_shift_x = -(kernel_native.shape[0] // 2) - - ip_y_offset = ip0_y - ip1_y - ip_x_offset = ip0_x - ip1_x - - if ( - ip_y_offset < 2 * kernel_shift_y - or ip_y_offset > -2 * kernel_shift_y - or ip_x_offset < 2 * kernel_shift_x - or ip_x_offset > -2 * kernel_shift_x - ): - return curvature_value - - kernel_pixels = kernel_native.shape[0] * kernel_native.shape[1] - kernel_count = 0 - - for k0_y in range(kernel_native.shape[0]): - for k0_x in range(kernel_native.shape[1]): - value = value_native[ - ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x - ] - - if value > 0.0: - k1_y = k0_y + ip_y_offset - k1_x = k0_x + ip_x_offset - - if ( - k1_y >= 0 - and k1_x >= 0 - and k1_y < kernel_native.shape[0] - and k1_x < kernel_native.shape[1] - ): - kernel_count += 1 - - kernel_value_0 = kernel_native[k0_y, k0_x] - kernel_value_1 = kernel_native[k1_y, k1_x] - - curvature_value += ( - kernel_value_0 * kernel_value_1 * (1.0 / value) ** 2.0 - ) - - if renormalize: - if kernel_count > 0: - curvature_value *= kernel_pixels / kernel_count - - return curvature_value - - -@numba_util.jit() -def data_vector_via_w_tilde_data_imaging_from( - w_tilde_data: np.ndarray, - data_to_pix_unique: np.ndarray, - data_weights: np.ndarray, - pix_lengths: np.ndarray, - pix_pixels: int, -) -> np.ndarray: - """ - Returns the data vector `D` from the `w_tilde_data` matrix (see `w_tilde_data_imaging_from`), which encodes the - the 1D image `d` and 1D noise-map values `\sigma` (see Warren & Dye 2003). + Ky, Kx = kernel_native.shape + ph, pw = Ky // 2, Kx // 2 - This uses the array `data_to_pix_unique`, which describes the unique mappings of every set of image sub-pixels to - pixelization pixels and `data_weights`, which describes how many sub-pixels uniquely map to each pixelization - pixels (see `data_slim_to_pixelization_unique_from`). - - Parameters - ---------- - w_tilde_data - A matrix that encodes the PSF convolution values between the imaging divided by the noise map**2 that enables - efficient calculation of the data vector. - data_to_pix_unique - An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of - pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). - data_weights - For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping - based on the number of sub-pixels that map to pixelization pixel. - pix_lengths - A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over - `data_to_pix_unique` and `data_weights`. - pix_pixels - The total number of pixels in the pixelization that reconstructs the data. - """ - - data_pixels = w_tilde_data.shape[0] + # 2) pad so neighbourhood gathers never go OOB + padded = jnp.pad( + weight_map, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0 + ) - data_vector = np.zeros(pix_pixels) + # 3) build broadcasted neighbourhood indices for all requested pixels + # shift pixel coords into the padded frame + ys = native_index_for_slim_index[:, 0] + ph # (N,) + xs = native_index_for_slim_index[:, 1] + pw # (N,) - for data_0 in range(data_pixels): - for pix_0_index in range(pix_lengths[data_0]): - data_0_weight = data_weights[data_0, pix_0_index] - pix_0 = data_to_pix_unique[data_0, pix_0_index] + # kernel-relative offsets + dy = jnp.arange(Ky) - ph # (Ky,) + dx = jnp.arange(Kx) - pw # (Kx,) - data_vector[pix_0] += data_0_weight * w_tilde_data[data_0] + # broadcast to (N, Ky, Kx) + Y = ys[:, None, None] + dy[None, :, None] + X = xs[:, None, None] + dx[None, None, :] - return data_vector + # 4) gather patches and correlate (no kernel flip) + patches = padded[Y, X] # (N, Ky, Kx) + return jnp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,) -@numba_util.jit() def data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix: np.ndarray, image: np.ndarray, noise_map: np.ndarray ) -> np.ndarray: @@ -408,210 +145,11 @@ def data_vector_via_blurred_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ + return (image / noise_map**2.0) @ blurred_mapping_matrix - data_shape = blurred_mapping_matrix.shape - - data_vector = np.zeros(data_shape[1]) - - for data_index in range(data_shape[0]): - for pix_index in range(data_shape[1]): - data_vector[pix_index] += ( - image[data_index] - * blurred_mapping_matrix[data_index, pix_index] - / (noise_map[data_index] ** 2.0) - ) - - return data_vector - - -@numba_util.jit() -def curvature_matrix_via_w_tilde_curvature_preload_imaging_from( - curvature_preload: np.ndarray, - curvature_indexes: np.ndarray, - curvature_lengths: np.ndarray, - data_to_pix_unique: np.ndarray, - data_weights: np.ndarray, - pix_lengths: np.ndarray, - pix_pixels: int, -) -> np.ndarray: - """ - Returns the curvature matrix `F` (see Warren & Dye 2003) by computing it using `w_tilde_preload` - (see `w_tilde_preload_interferometer_from`) for an imaging inversion. - - To compute the curvature matrix via w_tilde the following matrix multiplication is normally performed: - - curvature_matrix = mapping_matrix.T * w_tilde * mapping matrix - - This function speeds this calculation up in two ways: - - 1) Instead of using `w_tilde` (dimensions [image_pixels, image_pixels] it uses `w_tilde_preload` (dimensions - [image_pixels, kernel_overlap]). The massive reduction in the size of this matrix in memory allows for much fast - computation. - - 2) It omits the `mapping_matrix` and instead uses directly the 1D vector that maps every image pixel to a source - pixel `native_index_for_slim_index`. This exploits the sparsity in the `mapping_matrix` to directly - compute the `curvature_matrix` (e.g. it condenses the triple matrix multiplication into a double for loop!). - - Parameters - ---------- - curvature_preload - A matrix that precomputes the values for fast computation of the curvature matrix in a memory efficient way. - curvature_indexes - The image-pixel indexes of the values stored in the w tilde preload matrix, which are used to compute - the weights of the data values when computing the curvature matrix. - curvature_lengths - The number of image pixels in every row of `w_tilde_curvature`, which is iterated over when computing the - curvature matrix. - data_to_pix_unique - An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of - pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). - data_weights - For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping - based on the number of sub-pixels that map to pixelization pixel. - pix_lengths - A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over - `data_to_pix_unique` and `data_weights`. - pix_pixels - The total number of pixels in the pixelization that reconstructs the data. - - Returns - ------- - ndarray - The curvature matrix `F` (see Warren & Dye 2003). - """ - - data_pixels = curvature_lengths.shape[0] - - curvature_matrix = np.zeros((pix_pixels, pix_pixels)) - - curvature_index = 0 - - for data_0 in range(data_pixels): - for data_1_index in range(curvature_lengths[data_0]): - data_1 = curvature_indexes[curvature_index] - w_tilde_value = curvature_preload[curvature_index] - - for pix_0_index in range(pix_lengths[data_0]): - data_0_weight = data_weights[data_0, pix_0_index] - pix_0 = data_to_pix_unique[data_0, pix_0_index] - - for pix_1_index in range(pix_lengths[data_1]): - data_1_weight = data_weights[data_1, pix_1_index] - pix_1 = data_to_pix_unique[data_1, pix_1_index] - - curvature_matrix[pix_0, pix_1] += ( - data_0_weight * data_1_weight * w_tilde_value - ) - - curvature_index += 1 - - for i in range(pix_pixels): - for j in range(i, pix_pixels): - curvature_matrix[i, j] += curvature_matrix[j, i] - - for i in range(pix_pixels): - for j in range(i, pix_pixels): - curvature_matrix[j, i] = curvature_matrix[i, j] - - return curvature_matrix - - -@numba_util.jit() -def curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( - curvature_preload: np.ndarray, - curvature_indexes: np.ndarray, - curvature_lengths: np.ndarray, - data_to_pix_unique_0: np.ndarray, - data_weights_0: np.ndarray, - pix_lengths_0: np.ndarray, - pix_pixels_0: int, - data_to_pix_unique_1: np.ndarray, - data_weights_1: np.ndarray, - pix_lengths_1: np.ndarray, - pix_pixels_1: int, -) -> np.ndarray: - """ - Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) by computing them - using `w_tilde_preload` (see `w_tilde_preload_interferometer_from`) for an imaging inversion. - - When there is more than one mapper in the inversion, its `mapping_matrix` is extended to have dimensions - [data_pixels, sum(source_pixels_in_each_mapper)]. The curvature matrix therefore will have dimensions - [sum(source_pixels_in_each_mapper), sum(source_pixels_in_each_mapper)]. - - To compute the curvature matrix via w_tilde the following matrix multiplication is normally performed: - - curvature_matrix = mapping_matrix.T * w_tilde * mapping matrix - - When the `mapping_matrix` consists of multiple mappers from different planes, this means that shared data mappings - between source-pixels in different mappers must be accounted for when computing the `curvature_matrix`. These - appear as off-diagonal terms in the overall curvature matrix. - - This function evaluates these off-diagonal terms, by using the w-tilde curvature preloads and the unique - data-to-pixelization mappings of each mapper. It behaves analogous to the - function `curvature_matrix_via_w_tilde_curvature_preload_imaging_from`. - - Parameters - ---------- - curvature_preload - A matrix that precomputes the values for fast computation of the curvature matrix in a memory efficient way. - curvature_indexes - The image-pixel indexes of the values stored in the w tilde preload matrix, which are used to compute - the weights of the data values when computing the curvature matrix. - curvature_lengths - The number of image pixels in every row of `w_tilde_curvature`, which is iterated over when computing the - curvature matrix. - data_to_pix_unique - An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of - pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). - data_weights - For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping - based on the number of sub-pixels that map to pixelization pixel. - pix_lengths - A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over - `data_to_pix_unique` and `data_weights`. - pix_pixels - The total number of pixels in the pixelization that reconstructs the data. - - Returns - ------- - ndarray - The curvature matrix `F` (see Warren & Dye 2003). - """ - data_pixels = curvature_lengths.shape[0] - - curvature_matrix = np.zeros((pix_pixels_0, pix_pixels_1)) - - curvature_index = 0 - - for data_0 in range(data_pixels): - for data_1_index in range(curvature_lengths[data_0]): - data_1 = curvature_indexes[curvature_index] - w_tilde_value = curvature_preload[curvature_index] - - for pix_0_index in range(pix_lengths_0[data_0]): - data_0_weight = data_weights_0[data_0, pix_0_index] - pix_0 = data_to_pix_unique_0[data_0, pix_0_index] - - for pix_1_index in range(pix_lengths_1[data_1]): - data_1_weight = data_weights_1[data_1, pix_1_index] - pix_1 = data_to_pix_unique_1[data_1, pix_1_index] - - curvature_matrix[pix_0, pix_1] += ( - data_0_weight * data_1_weight * w_tilde_value - ) - - curvature_index += 1 - - return curvature_matrix - - -@numba_util.jit() def data_linear_func_matrix_from( - curvature_weights_matrix: np.ndarray, - image_frame_1d_lengths: np.ndarray, - image_frame_1d_indexes: np.ndarray, - image_frame_1d_kernels: np.ndarray, + curvature_weights_matrix: np.ndarray, kernel_native, mask ) -> np.ndarray: """ Returns a matrix that for each data pixel, maps it to the sum of the values of a linear object function convolved @@ -640,12 +178,8 @@ def data_linear_func_matrix_from( curvature_weights_matrix The operated values of each linear function divided by the noise-map squared, in a matrix of shape [data_pixels, total_fixed_linear_functions]. - image_frame_indexes - The indexes of all masked pixels that the PSF blurs light into (see the `Convolver` object). - image_frame_kernels - The kernel values of all masked pixels that the PSF blurs light into (see the `Convolver` object). - image_frame_length - The number of masked pixels it will blur light into (unmasked pixels are excluded, see the `Convolver` object). + kernel_native + The 2D PSf kernel. Returns ------- @@ -653,149 +187,21 @@ def data_linear_func_matrix_from( A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum of the values of a linear object function convolved with the PSF kernel at the data pixel. """ - data_pixels = curvature_weights_matrix.shape[0] - linear_func_pixels = curvature_weights_matrix.shape[1] - - data_linear_func_matrix_dict = np.zeros(shape=(data_pixels, linear_func_pixels)) - - for data_0 in range(data_pixels): - for psf_index in range(image_frame_1d_lengths[data_0]): - data_index = image_frame_1d_indexes[data_0, psf_index] - kernel_value = image_frame_1d_kernels[data_0, psf_index] - - for linear_index in range(linear_func_pixels): - data_linear_func_matrix_dict[data_0, linear_index] += ( - kernel_value * curvature_weights_matrix[data_index, linear_index] - ) - - return data_linear_func_matrix_dict - - -@numba_util.jit() -def curvature_matrix_off_diags_via_data_linear_func_matrix_from( - data_linear_func_matrix: np.ndarray, - data_to_pix_unique: np.ndarray, - data_weights: np.ndarray, - pix_lengths: np.ndarray, - pix_pixels: int, -): - """ - Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) between a mapper object - and a linear func object, using the preloaded `data_linear_func_matrix` of the values of the linear functions. - - - If a linear function in an inversion is fixed, its values can be evaluated and preloaded beforehand. For every - data pixel, the PSF convolution with this preloaded linear function can also be preloaded, in a matrix of - shape [data_pixels, 1]. - - When mapper objects and linear functions are used simultaneously in an inversion, this preloaded matrix - significantly speed up the computation of their off-diagonal terms in the curvature matrix. - - This function performs this efficient calcluation via the preloaded `data_linear_func_matrix`. - - Parameters - ---------- - data_linear_func_matrix - A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum of - the values of a linear object function convolved with the PSF kernel at the data pixel. - data_to_pix_unique - The indexes of all pixels that each data pixel maps to (see the `Mapper` object). - data_weights - The weights of all pixels that each data pixel maps to (see the `Mapper` object). - pix_lengths - The number of pixelization pixels that each data pixel maps to (see the `Mapper` object). - pix_pixels - The number of pixelization pixels in the pixelization (see the `Mapper` object). - """ - - linear_func_pixels = data_linear_func_matrix.shape[1] - - off_diag = np.zeros((pix_pixels, linear_func_pixels)) - - data_pixels = data_weights.shape[0] - - for data_0 in range(data_pixels): - for pix_0_index in range(pix_lengths[data_0]): - data_0_weight = data_weights[data_0, pix_0_index] - pix_0 = data_to_pix_unique[data_0, pix_0_index] - - for linear_index in range(linear_func_pixels): - off_diag[pix_0, linear_index] += ( - data_linear_func_matrix[data_0, linear_index] * data_0_weight - ) - - return off_diag - - -@numba_util.jit() -def curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( - data_to_pix_unique: np.ndarray, - data_weights: np.ndarray, - pix_lengths: np.ndarray, - pix_pixels: int, - curvature_weights: np.ndarray, - image_frame_1d_lengths: np.ndarray, - image_frame_1d_indexes: np.ndarray, - image_frame_1d_kernels: np.ndarray, -) -> np.ndarray: - """ - Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) between a mapper object - and a linear func object, using the unique mappings between data pixels and pixelization pixels. - - This takes as input the curvature weights of the linear function object, which are the values of the linear - function convolved with the PSF and divided by the noise-map squared. - - For each unique mapping between a data pixel and a pixelization pixel, the pixels which that pixel convolves - light into are computed, multiplied by their corresponding curvature weights and summed. This process also - accounts the sub-pixel mapping of each data pixel to the pixelization pixel - - This is done for every unique mapping of a data pixel to a pixelization pixel, giving the off-diagonal terms in - the curvature matrix. - - Parameters - ---------- - data_to_pix_unique - An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D) to its unique set of - pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`). - data_weights - For every unique mapping between a set of data sub-pixels and a pixelization pixel, the weight of these mapping - based on the number of sub-pixels that map to pixelization pixel. - pix_lengths - A 1D array describing how many unique pixels each data pixel maps too, which is used to iterate over - `data_to_pix_unique` and `data_weights`. - pix_pixels - The total number of pixels in the pixelization that reconstructs the data. - curvature_weights - The operated values of the linear func divided by the noise-map squared. - image_frame_indexes - The indexes of all masked pixels that the PSF blurs light into (see the `Convolver` object). - image_frame_kernels - The kernel values of all masked pixels that the PSF blurs light into (see the `Convolver` object). - image_frame_length - The number of masked pixels it will blur light into (unmasked pixels are excluded, see the `Convolver` object). - - Returns - ------- - ndarray - The curvature matrix `F` (see Warren & Dye 2003). - """ - - data_pixels = data_weights.shape[0] - linear_func_pixels = curvature_weights.shape[1] - off_diag = np.zeros((pix_pixels, linear_func_pixels)) + ny, nx = mask.shape_native + n_unmasked, n_funcs = curvature_weights_matrix.shape - for data_0 in range(data_pixels): - for pix_0_index in range(pix_lengths[data_0]): - data_0_weight = data_weights[data_0, pix_0_index] - pix_0 = data_to_pix_unique[data_0, pix_0_index] + # Expand masked -> native grid + native = np.zeros((ny, nx, n_funcs)) + native[~mask] = curvature_weights_matrix # put values into unmasked positions - for psf_index in range(image_frame_1d_lengths[data_0]): - data_index = image_frame_1d_indexes[data_0, psf_index] - kernel_value = image_frame_1d_kernels[data_0, psf_index] + # Convolve each function with PSF kernel + from scipy.signal import fftconvolve - off_diag[pix_0, :] += ( - data_0_weight * curvature_weights[data_index, :] * kernel_value - ) + blurred_list = [] + for i in range(n_funcs): + blurred = fftconvolve(native[..., i], kernel_native, mode="same") + # Re-mask: only keep unmasked pixels + blurred_list.append(blurred[~mask]) - return off_diag + return np.stack(blurred_list, axis=1) # shape (n_unmasked, n_funcs) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 03d73ff63..698750a22 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -1,17 +1,15 @@ -import copy import numpy as np from typing import Dict, List, Optional, Union from autoconf import cached_property -from autoarray.numba_util import profile_func - from autoarray.dataset.imaging.dataset import Imaging from autoarray.inversion.inversion.dataset_interface import DatasetInterface from autoarray.inversion.inversion.imaging.abstract import AbstractInversionImaging from autoarray.inversion.linear_obj.linear_obj import LinearObj from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.inversion.settings import SettingsInversion +from autoarray.preloads import Preloads from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.inversion.inversion import inversion_util @@ -24,7 +22,7 @@ def __init__( dataset: Union[Imaging, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, + preloads: Preloads = None, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -44,19 +42,16 @@ def __init__( linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, + preloads=preloads, ) @property - @profile_func def _data_vector_mapper(self) -> np.ndarray: """ Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous @@ -79,7 +74,7 @@ def _data_vector_mapper(self) -> np.ndarray: param_range = mapper_param_range_list[i] operated_mapping_matrix = self.psf.convolve_mapping_matrix( - mapping_matrix=mapper.mapping_matrix + mapping_matrix=mapper.mapping_matrix, mask=self.mask ) data_vector_mapper = ( @@ -95,7 +90,6 @@ def _data_vector_mapper(self) -> np.ndarray: return data_vector @cached_property - @profile_func def data_vector(self) -> np.ndarray: """ The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed @@ -109,15 +103,13 @@ def data_vector(self) -> np.ndarray: The calculation is described in more detail in `inversion_util.data_vector_via_blurred_mapping_matrix_from`. """ - return inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix=self.operated_mapping_matrix, - image=np.array(self.data), - noise_map=np.array(self.noise_map), + image=self.data.array, + noise_map=self.noise_map.array, ) @property - @profile_func def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: """ Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data @@ -141,7 +133,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_param_range_i = mapper_param_range_list[i] operated_mapping_matrix = self.psf.convolve_mapping_matrix( - mapping_matrix=mapper_i.mapping_matrix + mapping_matrix=mapper_i.mapping_matrix, mask=self.mask ) diag = inversion_util.curvature_matrix_via_mapping_matrix_from( @@ -164,7 +156,6 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: return curvature_matrix @cached_property - @profile_func def curvature_matrix(self): """ The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to @@ -192,7 +183,6 @@ def curvature_matrix(self): ) @property - @profile_func def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: """ When constructing the simultaneous linear equations (via vectors and matrices) the quantities of each individual diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 199ba66b2..1e67270a0 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -1,11 +1,9 @@ -import copy +import jax.numpy as jnp import numpy as np from typing import Dict, List, Optional, Union from autoconf import cached_property -from autoarray.numba_util import profile_func - from autoarray.dataset.imaging.dataset import Imaging from autoarray.dataset.imaging.w_tilde import WTildeImaging from autoarray.inversion.inversion.dataset_interface import DatasetInterface @@ -16,8 +14,8 @@ from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.inversion.inversion import inversion_util -from autoarray.inversion.inversion.imaging import inversion_imaging_util +from autoarray import exc +from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util class InversionImagingWTilde(AbstractInversionImaging): @@ -27,7 +25,6 @@ def __init__( w_tilde: WTildeImaging, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -50,15 +47,23 @@ def __init__( linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ + try: + import numba + except ModuleNotFoundError: + raise exc.InversionException( + "Inversion functionality (linear light profiles, pixelized reconstructions) is " + "disabled if numba is not installed.\n\n" + "This is because the run-times without numba are too slow.\n\n" + "Please install numba, which is described at the following web page:\n\n" + "https://pyautolens.readthedocs.io/en/latest/installation/overview.html" + ) + super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) if self.settings.use_w_tilde: @@ -68,17 +73,16 @@ def __init__( self.w_tilde = None @cached_property - @profile_func def w_tilde_data(self): - return inversion_imaging_util.w_tilde_data_imaging_from( - image_native=np.array(self.data.native), - noise_map_native=np.array(self.noise_map.native), - kernel_native=np.array(self.psf.kernel.native), + + return inversion_imaging_numba_util.w_tilde_data_imaging_from( + image_native=np.array(self.data.native.array), + noise_map_native=self.noise_map.native.array, + kernel_native=self.psf.native.array, native_index_for_slim_index=self.data.mask.derive_indexes.native_for_slim, ) @property - @profile_func def _data_vector_mapper(self) -> np.ndarray: """ Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous @@ -98,11 +102,13 @@ def _data_vector_mapper(self) -> np.ndarray: for mapper_index, mapper in enumerate(mapper_list): data_vector_mapper = ( - inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( + inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from( w_tilde_data=self.w_tilde_data, - data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique, - data_weights=mapper.unique_mappings.data_weights, - pix_lengths=mapper.unique_mappings.pix_lengths, + data_to_pix_unique=np.array( + mapper.unique_mappings.data_to_pix_unique + ), + data_weights=np.array(mapper.unique_mappings.data_weights), + pix_lengths=np.array(mapper.unique_mappings.pix_lengths), pix_pixels=mapper.params, ) ) @@ -113,7 +119,6 @@ def _data_vector_mapper(self) -> np.ndarray: return data_vector @cached_property - @profile_func def data_vector(self) -> np.ndarray: """ Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations @@ -134,7 +139,6 @@ def data_vector(self) -> np.ndarray: return self._data_vector_multi_mapper @property - @profile_func def _data_vector_x1_mapper(self) -> np.ndarray: """ Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations @@ -143,10 +147,9 @@ def _data_vector_x1_mapper(self) -> np.ndarray: This method computes the `data_vector` whenthere is a single mapper object in the `Inversion`, which circumvents `np.concatenate` for speed up. """ - linear_obj = self.linear_obj_list[0] - return inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( + return inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from( w_tilde_data=self.w_tilde_data, data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, @@ -155,7 +158,6 @@ def _data_vector_x1_mapper(self) -> np.ndarray: ) @property - @profile_func def _data_vector_multi_mapper(self) -> np.ndarray: """ Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations @@ -167,7 +169,7 @@ def _data_vector_multi_mapper(self) -> np.ndarray: return np.concatenate( [ - inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from( + inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from( w_tilde_data=self.w_tilde_data, data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, @@ -179,7 +181,6 @@ def _data_vector_multi_mapper(self) -> np.ndarray: ) @property - @profile_func def _data_vector_func_list_and_mapper(self) -> np.ndarray: """ Returns the `data_vector`, a 1D vector whose values are solved for by the simultaneous linear equations @@ -192,7 +193,7 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray: separation of functions enables the `data_vector` to be preloaded in certain circumstances. """ - data_vector = self._data_vector_mapper + data_vector = np.array(self._data_vector_mapper) linear_func_param_range = self.param_range_list_from( cls=AbstractLinearObjFuncList @@ -205,10 +206,10 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray: linear_func ] - diag = inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=operated_mapping_matrix, - image=np.array(self.data), - noise_map=np.array(self.noise_map), + diag = inversion_imaging_numba_util.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=np.array(operated_mapping_matrix), + image=self.data.array, + noise_map=self.noise_map.array, ) param_range = linear_func_param_range[linear_func_index] @@ -218,7 +219,6 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray: return data_vector @cached_property - @profile_func def curvature_matrix(self) -> np.ndarray: """ Returns the `curvature_matrix`, a 2D matrix which uses the mappings between the data and the linear objects to @@ -247,21 +247,22 @@ def curvature_matrix(self) -> np.ndarray: else: curvature_matrix = self._curvature_matrix_multi_mapper - curvature_matrix = inversion_util.curvature_matrix_mirrored_from( + curvature_matrix = inversion_imaging_numba_util.curvature_matrix_mirrored_from( curvature_matrix=curvature_matrix ) if len(self.no_regularization_index_list) > 0: - curvature_matrix = inversion_util.curvature_matrix_with_added_to_diag_from( - curvature_matrix=curvature_matrix, - value=self.settings.no_regularization_add_to_curvature_diag_value, - no_regularization_index_list=self.no_regularization_index_list, + curvature_matrix = ( + inversion_imaging_numba_util.curvature_matrix_with_added_to_diag_from( + curvature_matrix=curvature_matrix, + value=self.settings.no_regularization_add_to_curvature_diag_value, + no_regularization_index_list=self.no_regularization_index_list, + ) ) return curvature_matrix @property - @profile_func def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: """ Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data @@ -284,13 +285,15 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_i = mapper_list[i] mapper_param_range_i = mapper_param_range_list[i] - diag = inversion_imaging_util.curvature_matrix_via_w_tilde_curvature_preload_imaging_from( + diag = inversion_imaging_numba_util.curvature_matrix_via_w_tilde_curvature_preload_imaging_from( curvature_preload=self.w_tilde.curvature_preload, curvature_indexes=self.w_tilde.indexes, curvature_lengths=self.w_tilde.lengths, - data_to_pix_unique=mapper_i.unique_mappings.data_to_pix_unique, - data_weights=mapper_i.unique_mappings.data_weights, - pix_lengths=mapper_i.unique_mappings.pix_lengths, + data_to_pix_unique=np.array( + mapper_i.unique_mappings.data_to_pix_unique + ), + data_weights=np.array(mapper_i.unique_mappings.data_weights), + pix_lengths=np.array(mapper_i.unique_mappings.pix_lengths), pix_pixels=mapper_i.params, ) @@ -304,7 +307,6 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: return curvature_matrix - @profile_func def _curvature_matrix_off_diag_from( self, mapper_0: AbstractMapper, mapper_1: AbstractMapper ) -> np.ndarray: @@ -318,7 +320,7 @@ def _curvature_matrix_off_diag_from( This function computes the off-diagonal terms of F using the w_tilde formalism. """ - curvature_matrix_off_diag_0 = inversion_imaging_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( + curvature_matrix_off_diag_0 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( curvature_preload=self.w_tilde.curvature_preload, curvature_indexes=self.w_tilde.indexes, curvature_lengths=self.w_tilde.lengths, @@ -332,7 +334,7 @@ def _curvature_matrix_off_diag_from( pix_pixels_1=mapper_1.params, ) - curvature_matrix_off_diag_1 = inversion_imaging_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( + curvature_matrix_off_diag_1 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from( curvature_preload=self.w_tilde.curvature_preload, curvature_indexes=self.w_tilde.indexes, curvature_lengths=self.w_tilde.lengths, @@ -349,7 +351,6 @@ def _curvature_matrix_off_diag_from( return curvature_matrix_off_diag_0 + curvature_matrix_off_diag_1.T @property - @profile_func def _curvature_matrix_x1_mapper(self) -> np.ndarray: """ Returns the `curvature_matrix`, a 2D matrix which uses the mappings between the data and the linear objects to @@ -361,7 +362,6 @@ def _curvature_matrix_x1_mapper(self) -> np.ndarray: return self._curvature_matrix_mapper_diag @property - @profile_func def _curvature_matrix_multi_mapper(self) -> np.ndarray: """ Returns the `curvature_matrix`, a 2D matrix which uses the mappings between the data and the linear objects to @@ -399,7 +399,6 @@ def _curvature_matrix_multi_mapper(self) -> np.ndarray: return curvature_matrix @property - @profile_func def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: """ The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to @@ -433,15 +432,14 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: / self.noise_map[:, None] ** 2 ) - off_diag = inversion_imaging_util.curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( + off_diag = inversion_imaging_numba_util.curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from( data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique, data_weights=mapper.unique_mappings.data_weights, pix_lengths=mapper.unique_mappings.pix_lengths, pix_pixels=mapper.params, - curvature_weights=curvature_weights, - image_frame_1d_lengths=self.convolver.image_frame_1d_lengths, - image_frame_1d_indexes=self.convolver.image_frame_1d_indexes, - image_frame_1d_kernels=self.convolver.image_frame_1d_kernels, + curvature_weights=np.array(curvature_weights), + mask=self.mask.array, + psf_kernel=self.psf.native.array, ) curvature_matrix[ @@ -478,7 +476,6 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray: return curvature_matrix @property - @profile_func def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: """ When constructing the simultaneous linear equations (via vectors and matrices) the quantities of each individual @@ -514,22 +511,23 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: reconstruction = reconstruction_dict[linear_obj] if isinstance(linear_obj, AbstractMapper): - mapped_reconstructed_image = inversion_util.mapped_reconstructed_data_via_image_to_pix_unique_from( + mapped_reconstructed_image = inversion_imaging_numba_util.mapped_reconstructed_data_via_image_to_pix_unique_from( data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique, data_weights=linear_obj.unique_mappings.data_weights, pix_lengths=linear_obj.unique_mappings.pix_lengths, reconstruction=reconstruction, ) - mapped_reconstructed_image = Array2D( - values=mapped_reconstructed_image, mask=self.mask - ) - mapped_reconstructed_image = self.psf.convolve_image_no_blurring( image=mapped_reconstructed_image, mask=self.mask + ).array + + mapped_reconstructed_image = Array2D( + values=mapped_reconstructed_image, mask=self.mask ) else: + operated_mapping_matrix = self.linear_func_operated_mapping_matrix_dict[ linear_obj ] diff --git a/autoarray/inversion/inversion/interferometer/abstract.py b/autoarray/inversion/inversion/interferometer/abstract.py index 47e1c84bf..09e2a01e7 100644 --- a/autoarray/inversion/inversion/interferometer/abstract.py +++ b/autoarray/inversion/inversion/interferometer/abstract.py @@ -11,8 +11,6 @@ from autoarray.inversion.inversion import inversion_util -from autoarray.numba_util import profile_func - class AbstractInversionInterferometer(AbstractInversion): def __init__( @@ -20,7 +18,6 @@ def __init__( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -41,15 +38,12 @@ def __init__( linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) @property @@ -78,7 +72,6 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: ] @property - @profile_func def mapped_reconstructed_image_dict( self, ) -> Dict[LinearObj, Array2D]: diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index 29580c06d..120f1c31b 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -1,4 +1,3 @@ -from astropy.io import fits import logging import numpy as np import time @@ -387,7 +386,6 @@ def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index): return w_tilde_via_preload -@numba_util.jit() def data_vector_via_transformed_mapping_matrix_from( transformed_mapping_matrix: np.ndarray, visibilities: np.ndarray, @@ -406,31 +404,24 @@ def data_vector_via_transformed_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ + # Extract components + vis_real = visibilities.real + vis_imag = visibilities.imag + f_real = transformed_mapping_matrix.real + f_imag = transformed_mapping_matrix.imag + noise_real = noise_map.real + noise_imag = noise_map.imag - data_vector = np.zeros(transformed_mapping_matrix.shape[1]) - - visibilities_real = visibilities.real - visibilities_imag = visibilities.imag - transformed_mapping_matrix_real = transformed_mapping_matrix.real - transformed_mapping_matrix_imag = transformed_mapping_matrix.imag - noise_map_real = noise_map.real - noise_map_imag = noise_map.imag - - for vis_1d_index in range(transformed_mapping_matrix.shape[0]): - for pix_1d_index in range(transformed_mapping_matrix.shape[1]): - real_value = ( - visibilities_real[vis_1d_index] - * transformed_mapping_matrix_real[vis_1d_index, pix_1d_index] - / (noise_map_real[vis_1d_index] ** 2.0) - ) - imag_value = ( - visibilities_imag[vis_1d_index] - * transformed_mapping_matrix_imag[vis_1d_index, pix_1d_index] - / (noise_map_imag[vis_1d_index] ** 2.0) - ) - data_vector[pix_1d_index] += real_value + imag_value + # Square noise components + inv_var_real = 1.0 / (noise_real**2) + inv_var_imag = 1.0 / (noise_imag**2) + + # Real and imaginary contributions + weighted_real = (vis_real * inv_var_real)[:, None] * f_real + weighted_imag = (vis_imag * inv_var_imag)[:, None] * f_imag - return data_vector + # Sum over visibilities + return np.sum(weighted_real + weighted_imag, axis=0) @numba_util.jit() @@ -512,7 +503,6 @@ def curvature_matrix_via_w_tilde_curvature_preload_interferometer_from( return curvature_matrix -@numba_util.jit() def mapped_reconstructed_visibilities_from( transformed_mapping_matrix: np.ndarray, reconstruction: np.ndarray ) -> np.ndarray: @@ -525,20 +515,7 @@ def mapped_reconstructed_visibilities_from( The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. """ - mapped_reconstructed_visibilities = (0.0 + 0.0j) * np.zeros( - transformed_mapping_matrix.shape[0] - ) - - transformed_mapping_matrix_real = transformed_mapping_matrix.real - transformed_mapping_matrix_imag = transformed_mapping_matrix.imag - - for i in range(transformed_mapping_matrix.shape[0]): - for j in range(reconstruction.shape[0]): - mapped_reconstructed_visibilities[i] += ( - reconstruction[j] * transformed_mapping_matrix_real[i, j] - ) + 1.0j * (reconstruction[j] * transformed_mapping_matrix_imag[i, j]) - - return mapped_reconstructed_visibilities + return transformed_mapping_matrix @ reconstruction """ @@ -813,6 +790,9 @@ def w_tilde_curvature_preload_interferometer_in_stages_with_chunks_from( check=True, directory=None, ) -> np.ndarray: + + from astropy.io import fits + if directory is None: raise NotImplementedError() @@ -1853,3 +1833,42 @@ def curvature_matrix_via_w_tilde_curvature_preload_interferometer_para_from( ) print("finished 3rd loop.") return curvature_matrix + + +@numba_util.jit() +def sub_slim_indexes_for_pix_index( + pix_indexes_for_sub_slim_index: np.ndarray, + pix_weights_for_sub_slim_index: np.ndarray, + pix_pixels: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) + + for pix_indexes in pix_indexes_for_sub_slim_index: + for pix_index in pix_indexes: + sub_slim_sizes_for_pix_index[pix_index] += 1 + + max_pix_size = np.max(sub_slim_sizes_for_pix_index) + + sub_slim_indexes_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) + sub_slim_weights_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) + sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) + + for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index): + pix_weights = pix_weights_for_sub_slim_index[slim_index] + + for pix_index, pix_weight in zip(pix_indexes, pix_weights): + sub_slim_indexes_for_pix_index[ + pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) + ] = slim_index + + sub_slim_weights_for_pix_index[ + pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) + ] = pix_weight + + sub_slim_sizes_for_pix_index[pix_index] += 1 + + return ( + sub_slim_indexes_for_pix_index, + sub_slim_sizes_for_pix_index, + sub_slim_weights_for_pix_index, + ) diff --git a/autoarray/inversion/inversion/interferometer/lop.py b/autoarray/inversion/inversion/interferometer/lop.py deleted file mode 100644 index fdd6b8adf..000000000 --- a/autoarray/inversion/inversion/interferometer/lop.py +++ /dev/null @@ -1,146 +0,0 @@ -from scipy import sparse - -import numpy as np -from typing import Dict - -from autoconf import cached_property - -from autoarray.inversion.inversion.interferometer.abstract import ( - AbstractInversionInterferometer, -) -from autoarray.inversion.linear_obj.linear_obj import LinearObj -from autoarray.structures.visibilities import Visibilities - -from autoarray.numba_util import profile_func - - -class InversionInterferometerMappingPyLops(AbstractInversionInterferometer): - """ - Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations - to be solved (see `inversion.inversion.abstract.AbstractInversion` for a full description). - - A linear object describes the mappings between values in observed `data` and the linear object's model via its - `mapping_matrix`. This class constructs linear equations for `Interferometer` objects, where the data is an - an array of visibilities and the mappings include a non-uniform fast Fourier transform operation described by - the interferometer dataset's transformer. - - This class uses the mapping formalism, which constructs the simultaneous linear equations using the - `mapping_matrix` of every linear object. This is performed using the library PyLops, which uses linear - operators to avoid these matrices being created explicitly in memory, making the calculation more efficient. - """ - - @cached_property - @profile_func - def reconstruction(self): - """ - Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12) - of https://arxiv.org/pdf/astro-ph/0302587.pdf - - S is the vector of reconstructed inversion values. - """ - - import pylops - - Aop = pylops.MatrixMult( - sparse.bsr_matrix(self.linear_obj_list[0].mapping_matrix) - ) - - Fop = self.transformer - - Op = Fop * Aop - - MOp = pylops.MatrixMult(sparse.bsr_matrix(self.preconditioner_matrix_inverse)) - - try: - return pylops.NormalEquationsInversion( - Op=Op, - Regs=None, - epsNRs=[1.0], - data=self.data.ordered_1d, - Weight=pylops.Diagonal(diag=self.noise_map.weight_list_ordered_1d), - NRegs=[ - pylops.MatrixMult(sparse.bsr_matrix(self.regularization_matrix)) - ], - M=MOp, - tol=self.settings.tolerance, - atol=self.settings.tolerance, - **dict(maxiter=self.settings.maxiter), - ) - except AttributeError: - return pylops.normal_equations_inversion( - Op=Op, - Regs=None, - epsNRs=[1.0], - y=self.data.ordered_1d, - Weight=pylops.Diagonal(diag=self.noise_map.weight_list_ordered_1d), - NRegs=[ - pylops.MatrixMult(sparse.bsr_matrix(self.regularization_matrix)) - ], - M=MOp, - tol=self.settings.tolerance, - atol=self.settings.tolerance, - **dict(maxiter=self.settings.maxiter), - )[0] - - @property - @profile_func - def mapped_reconstructed_data_dict( - self, - ) -> Dict[LinearObj, Visibilities]: - """ - When constructing the simultaneous linear equations (via vectors and matrices) the quantities of each individual - linear object (e.g. their `mapping_matrix`) are combined into single ndarrays. This does not track which - quantities belong to which linear objects, therefore the linear equation's solutions (which are returned as - ndarrays) do not contain information on which linear object(s) they correspond to. - - For example, consider if two `Mapper` objects with 50 and 100 source pixels are used in an `Inversion`. - The `reconstruction` (which contains the solved for source pixels values) is an ndarray of shape [150], but - the ndarray itself does not track which values belong to which `Mapper`. - - This function converts an ndarray of a `reconstruction` to a dictionary of ndarrays containing each linear - object's reconstructed images, where the keys are the instances of each mapper in the inversion. - - The PyLops calculation bypasses the calculation of the `mapping_matrix` and it therefore cannot be used to map - the reconstruction's values to the image-plane. Instead, the unique data-to-pixelization mappings are used, - including the 2D non-uniform fast Fourier transform operation after mapping is complete. - - Parameters - ---------- - reconstruction - The reconstruction (in the source frame) whose values are mapped to a dictionary of values for each - individual mapper (in the image-plane). - """ - - mapped_reconstructed_image_dict = self.mapped_reconstructed_image_dict - - return { - linear_obj: self.transformer.visibilities_from(image=image) - for linear_obj, image in mapped_reconstructed_image_dict.items() - } - - @cached_property - @profile_func - def preconditioner_matrix(self): - curvature_matrix_approx = np.multiply( - np.sum(self.noise_map.weight_list_ordered_1d), - self.linear_obj_list[0].mapping_matrix.T - @ self.linear_obj_list[0].mapping_matrix, - ) - - return np.add(curvature_matrix_approx, self.regularization_matrix) - - @cached_property - @profile_func - def preconditioner_matrix_inverse(self): - return np.linalg.inv(self.preconditioner_matrix) - - @cached_property - @profile_func - def log_det_curvature_reg_matrix_term(self): - return 2.0 * np.sum( - np.log(np.diag(np.linalg.cholesky(self.preconditioner_matrix))) - ) - - @property - def reconstruction_noise_map(self): - return None diff --git a/autoarray/inversion/inversion/interferometer/mapping.py b/autoarray/inversion/inversion/interferometer/mapping.py index 9cde492e9..2a4e4f316 100644 --- a/autoarray/inversion/inversion/interferometer/mapping.py +++ b/autoarray/inversion/inversion/interferometer/mapping.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import Dict, List, Optional, Union @@ -15,8 +16,6 @@ from autoarray.inversion.inversion.interferometer import inversion_interferometer_util from autoarray.inversion.inversion import inversion_util -from autoarray.numba_util import profile_func - class InversionInterferometerMapping(AbstractInversionInterferometer): def __init__( @@ -24,7 +23,6 @@ def __init__( dataset: Union[Interferometer, DatasetInterface], linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -48,19 +46,15 @@ def __init__( linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) @cached_property - @profile_func def data_vector(self) -> np.ndarray: """ The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed @@ -77,12 +71,11 @@ def data_vector(self) -> np.ndarray: return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from( transformed_mapping_matrix=self.operated_mapping_matrix, - visibilities=np.array(self.data), + visibilities=self.data, noise_map=np.array(self.noise_map), ) @cached_property - @profile_func def curvature_matrix(self) -> np.ndarray: """ The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to @@ -106,19 +99,18 @@ def curvature_matrix(self) -> np.ndarray: noise_map=self.noise_map.imag, ) - curvature_matrix = np.add(real_curvature_matrix, imag_curvature_matrix) + curvature_matrix = jnp.add(real_curvature_matrix, imag_curvature_matrix) if len(self.no_regularization_index_list) > 0: curvature_matrix = inversion_util.curvature_matrix_with_added_to_diag_from( curvature_matrix=curvature_matrix, - no_regularization_index_list=self.no_regularization_index_list, value=self.settings.no_regularization_add_to_curvature_diag_value, + no_regularization_index_list=self.no_regularization_index_list, ) return curvature_matrix @property - @profile_func def mapped_reconstructed_data_dict( self, ) -> Dict[LinearObj, Visibilities]: diff --git a/autoarray/inversion/inversion/interferometer/w_tilde.py b/autoarray/inversion/inversion/interferometer/w_tilde.py index 18d7e1cec..8a3656fa2 100644 --- a/autoarray/inversion/inversion/interferometer/w_tilde.py +++ b/autoarray/inversion/inversion/interferometer/w_tilde.py @@ -17,8 +17,6 @@ from autoarray.inversion.inversion import inversion_util from autoarray.inversion.inversion.interferometer import inversion_interferometer_util -from autoarray.numba_util import profile_func - class InversionInterferometerWTilde(AbstractInversionInterferometer): def __init__( @@ -27,7 +25,6 @@ def __init__( w_tilde: WTildeInterferometer, linear_obj_list: List[LinearObj], settings: SettingsInversion = SettingsInversion(), - run_time_dict: Optional[Dict] = None, ): """ Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations @@ -54,8 +51,6 @@ def __init__( linear_obj_list The linear objects used to reconstruct the data's observed values. If multiple linear objects are passed the simultaneous linear equations are combined and solved simultaneously. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ self.w_tilde = w_tilde @@ -65,13 +60,11 @@ def __init__( dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, - run_time_dict=run_time_dict, ) self.settings = settings @cached_property - @profile_func def data_vector(self) -> np.ndarray: """ The `data_vector` is a 1D vector whose values are solved for by the simultaneous linear equations constructed @@ -85,12 +78,9 @@ def data_vector(self) -> np.ndarray: The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`. """ - return np.dot( - self.linear_obj_list[0].mapping_matrix.T, self.w_tilde.dirty_image - ) + return np.dot(self.mapping_matrix.T, self.w_tilde.dirty_image) @cached_property - @profile_func def curvature_matrix(self) -> np.ndarray: """ The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to @@ -106,7 +96,6 @@ def curvature_matrix(self) -> np.ndarray: return self.curvature_matrix_diag @property - @profile_func def curvature_matrix_diag(self) -> np.ndarray: """ The `curvature_matrix` is a 2D matrix which uses the mappings between the data and the linear objects to @@ -131,7 +120,9 @@ def curvature_matrix_diag(self) -> np.ndarray: pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, pix_size_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, - native_index_for_slim_index=self.transformer.real_space_mask.derive_indexes.native_for_slim, + native_index_for_slim_index=np.array( + self.transformer.real_space_mask.derive_indexes.native_for_slim + ).astype("int"), pix_pixels=self.linear_obj_list[0].params, ) @@ -139,11 +130,17 @@ def curvature_matrix_diag(self) -> np.ndarray: sub_slim_indexes_for_pix_index, sub_slim_sizes_for_pix_index, sub_slim_weights_for_pix_index, - ) = mapper.sub_slim_indexes_for_pix_index_arr + ) = inversion_interferometer_util.sub_slim_indexes_for_pix_index( + pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, + pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, + pix_pixels=mapper.pixels, + ) return inversion_interferometer_util.curvature_matrix_via_w_tilde_curvature_preload_interferometer_from_2( curvature_preload=self.w_tilde.curvature_preload, - native_index_for_slim_index=self.transformer.real_space_mask.derive_indexes.native_for_slim, + native_index_for_slim_index=np.array( + self.transformer.real_space_mask.derive_indexes.native_for_slim + ).astype("int"), pix_pixels=self.linear_obj_list[0].params, sub_slim_indexes_for_pix_index=sub_slim_indexes_for_pix_index.astype("int"), sub_slim_sizes_for_pix_index=sub_slim_sizes_for_pix_index.astype("int"), @@ -151,7 +148,6 @@ def curvature_matrix_diag(self) -> np.ndarray: ) @property - @profile_func def mapped_reconstructed_data_dict( self, ) -> Dict[LinearObj, Visibilities]: diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 178ce08ac..95e216c9e 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -1,14 +1,13 @@ +import jax.numpy as jnp +import jax.lax as lax import numpy as np -from typing import List, Optional, Tuple - -from autoconf import conf +from typing import List, Optional, Type from autoarray.inversion.inversion.settings import SettingsInversion from autoarray import numba_util from autoarray import exc -from autoarray.util.fnnls import fnnls_cholesky def curvature_matrix_via_w_tilde_from( @@ -35,11 +34,9 @@ def curvature_matrix_via_w_tilde_from( ndarray The curvature matrix `F` (see Warren & Dye 2003). """ - - return np.dot(mapping_matrix.T, np.dot(w_tilde, mapping_matrix)) + return jnp.dot(mapping_matrix.T, jnp.dot(w_tilde, mapping_matrix)) -@numba_util.jit() def curvature_matrix_with_added_to_diag_from( curvature_matrix: np.ndarray, value: float, @@ -59,31 +56,22 @@ def curvature_matrix_with_added_to_diag_from( curvature_matrix The curvature matrix which is being constructed in order to solve a linear system of equations. """ + return curvature_matrix.at[ + no_regularization_index_list, no_regularization_index_list + ].add(value) - for i in no_regularization_index_list: - curvature_matrix[i, i] += value - - return curvature_matrix - -@numba_util.jit() def curvature_matrix_mirrored_from( curvature_matrix: np.ndarray, ) -> np.ndarray: - curvature_matrix_mirrored = np.zeros( - (curvature_matrix.shape[0], curvature_matrix.shape[1]) - ) + # Copy the original matrix and its transpose + m1 = curvature_matrix + m2 = curvature_matrix.T - for i in range(curvature_matrix.shape[0]): - for j in range(curvature_matrix.shape[1]): - if curvature_matrix[i, j] != 0: - curvature_matrix_mirrored[i, j] = curvature_matrix[i, j] - curvature_matrix_mirrored[j, i] = curvature_matrix[i, j] - if curvature_matrix[j, i] != 0: - curvature_matrix_mirrored[i, j] = curvature_matrix[j, i] - curvature_matrix_mirrored[j, i] = curvature_matrix[j, i] + # For each entry, prefer the non-zero value from either the matrix or its transpose + mirrored = jnp.where(m1 != 0, m1, m2) - return curvature_matrix_mirrored + return mirrored def curvature_matrix_via_mapping_matrix_from( @@ -106,7 +94,7 @@ def curvature_matrix_via_mapping_matrix_from( Flattened 1D array of the noise-map used by the inversion during the fit. """ array = mapping_matrix / noise_map[:, None] - curvature_matrix = np.dot(array.T, array) + curvature_matrix = jnp.dot(array.T, array) if add_to_curvature_diag and len(no_regularization_index_list) > 0: curvature_matrix = curvature_matrix_with_added_to_diag_from( @@ -118,12 +106,8 @@ def curvature_matrix_via_mapping_matrix_from( return curvature_matrix -@numba_util.jit() -def mapped_reconstructed_data_via_image_to_pix_unique_from( - data_to_pix_unique: np.ndarray, - data_weights: np.ndarray, - pix_lengths: np.ndarray, - reconstruction: np.ndarray, +def mapped_reconstructed_data_via_mapping_matrix_from( + mapping_matrix: np.ndarray, reconstruction: np.ndarray ) -> np.ndarray: """ Returns the reconstructed data vector from the blurred mapping matrix `f` and solution vector *S*. @@ -134,48 +118,40 @@ def mapped_reconstructed_data_via_image_to_pix_unique_from( The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. """ - - data_pixels = data_to_pix_unique.shape[0] - - mapped_reconstructed_data = np.zeros(data_pixels) - - for data_0 in range(data_pixels): - for pix_0 in range(pix_lengths[data_0]): - pix_for_data = data_to_pix_unique[data_0, pix_0] - - mapped_reconstructed_data[data_0] += ( - data_weights[data_0, pix_0] * reconstruction[pix_for_data] - ) - - return mapped_reconstructed_data + return jnp.dot(mapping_matrix, reconstruction) -@numba_util.jit() -def mapped_reconstructed_data_via_mapping_matrix_from( - mapping_matrix: np.ndarray, reconstruction: np.ndarray +def mapped_reconstructed_data_via_w_tilde_from( + w_tilde: np.ndarray, mapping_matrix: np.ndarray, reconstruction: np.ndarray ) -> np.ndarray: """ - Returns the reconstructed data vector from the blurred mapping matrix `f` and solution vector *S*. + Returns the reconstructed data vector from the unblurred mapping matrix `M`, + the reconstruction vector `s`, and the PSF convolution operator `w_tilde`. + + Equivalent to: + reconstructed = (W @ M) @ s + = W @ (M @ s) Parameters ---------- + w_tilde + Array of shape [image_pixels, image_pixels], the PSF convolution operator. mapping_matrix - The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. + Array of shape [image_pixels, source_pixels], unblurred mapping matrix. + reconstruction + Array of shape [source_pixels], solution vector. + Returns + ------- + ndarray + The reconstructed data vector of shape [image_pixels]. """ - mapped_reconstructed_data = np.zeros(mapping_matrix.shape[0]) - for i in range(mapping_matrix.shape[0]): - for j in range(reconstruction.shape[0]): - mapped_reconstructed_data[i] += reconstruction[j] * mapping_matrix[i, j] - - return mapped_reconstructed_data + return w_tilde @ (mapping_matrix @ reconstruction) def reconstruction_positive_negative_from( data_vector: np.ndarray, curvature_reg_matrix: np.ndarray, - mapper_param_range_list, - force_check_reconstruction: bool = False, ): """ Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12) @@ -212,29 +188,12 @@ def reconstruction_positive_negative_from( curvature_reg_matrix The curvature_matrix plus regularization matrix, overwriting the curvature_matrix in memory. """ - try: - reconstruction = np.linalg.solve(curvature_reg_matrix, data_vector) - except np.linalg.LinAlgError as e: - raise exc.InversionException() from e - - if ( - conf.instance["general"]["inversion"]["check_reconstruction"] - or force_check_reconstruction - ): - for mapper_param_range in mapper_param_range_list: - if np.allclose( - a=reconstruction[mapper_param_range[0] : mapper_param_range[1]], - b=reconstruction[mapper_param_range[0]], - ): - raise exc.InversionException() - - return reconstruction + return jnp.linalg.solve(curvature_reg_matrix, data_vector) def reconstruction_positive_only_from( data_vector: np.ndarray, curvature_reg_matrix: np.ndarray, - settings: SettingsInversion = SettingsInversion(), ): """ Solve the linear system Eq.(2) (in terms of minimizing the quadratic value) of @@ -278,27 +237,9 @@ def reconstruction_positive_only_from( ------- Non-negative S that minimizes the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf. """ + import jaxnnls - if len(data_vector): - try: - if settings.positive_only_uses_p_initial: - P_initial = np.linalg.solve(curvature_reg_matrix, data_vector) > 0 - else: - P_initial = np.zeros(0, dtype=int) - - reconstruction = fnnls_cholesky( - curvature_reg_matrix, - (data_vector).T, - P_initial=P_initial, - ) - - except (RuntimeError, np.linalg.LinAlgError, ValueError) as e: - raise exc.InversionException() from e - - else: - raise exc.InversionException() - - return reconstruction + return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector) def preconditioner_matrix_via_mapping_matrix_from( @@ -328,3 +269,48 @@ def preconditioner_matrix_via_mapping_matrix_from( return ( preconditioner_noise_normalization * curvature_matrix ) + regularization_matrix + + +def param_range_list_from(cls: Type, linear_obj_list) -> List[List[int]]: + """ + Each linear object in the `Inversion` has N parameters, and these parameters correspond to a certain range + of indexing values in the matrices used to perform the inversion. + + This function returns the `param_range_list` of an input type of linear object, which gives the indexing range + of each linear object of the input type. + + For example, if an `Inversion` has: + + - A `LinearFuncList` linear object with 3 `params`. + - A `Mapper` with 100 `params`. + - A `Mapper` with 200 `params`. + + The corresponding matrices of this inversion (e.g. the `curvature_matrix`) have `shape=(303, 303)` where: + + - The `LinearFuncList` values are in the entries `[0:3]`. + - The first `Mapper` values are in the entries `[3:103]`. + - The second `Mapper` values are in the entries `[103:303] + + For this example, `param_range_list_from(cls=AbstractMapper)` therefore returns the + list `[[3, 103], [103, 303]]`. + + Parameters + ---------- + cls + The type of class that the list of their parameter range index values are returned for. + + Returns + ------- + A list of the index range of the parameters of each linear object in the inversion of the input cls type. + """ + index_list = [] + + pixel_count = 0 + + for linear_obj in linear_obj_list: + if isinstance(linear_obj, cls): + index_list.append([pixel_count, pixel_count + linear_obj.params]) + + pixel_count += linear_obj.params + + return index_list diff --git a/autoarray/inversion/inversion/settings.py b/autoarray/inversion/inversion/settings.py index 2c0eba077..3deab4a6e 100644 --- a/autoarray/inversion/inversion/settings.py +++ b/autoarray/inversion/inversion/settings.py @@ -10,13 +10,11 @@ class SettingsInversion: def __init__( self, - use_w_tilde: bool = True, + use_w_tilde: bool = False, use_positive_only_solver: Optional[bool] = None, positive_only_uses_p_initial: Optional[bool] = None, use_border_relocator: Optional[bool] = None, force_edge_pixels_to_zeros: bool = True, - force_edge_image_pixels_to_zeros: bool = False, - image_pixels_source_zero=None, no_regularization_add_to_curvature_diag_value: float = None, use_w_tilde_numpy: bool = False, use_source_loop: bool = False, @@ -84,8 +82,6 @@ def __init__( self._use_border_relocator = use_border_relocator self.use_linear_operators = use_linear_operators self.force_edge_pixels_to_zeros = force_edge_pixels_to_zeros - self.force_edge_image_pixels_to_zeros = force_edge_image_pixels_to_zeros - self.image_pixels_source_zero = image_pixels_source_zero self._no_regularization_add_to_curvature_diag_value = ( no_regularization_add_to_curvature_diag_value ) diff --git a/autoarray/inversion/linear_obj/func_list.py b/autoarray/inversion/linear_obj/func_list.py index ce72bdcbe..9f30f3c60 100644 --- a/autoarray/inversion/linear_obj/func_list.py +++ b/autoarray/inversion/linear_obj/func_list.py @@ -9,15 +9,12 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization from autoarray.type import Grid1D2DLike -from autoarray.numba_util import profile_func - class AbstractLinearObjFuncList(LinearObj): def __init__( self, grid: Grid1D2DLike, regularization: Optional[AbstractRegularization], - run_time_dict: Optional[Dict] = None, ): """ A linear object which reconstructs a dataset based on mapping between the data points of that dataset and @@ -42,11 +39,9 @@ def __init__( is evaluated. regularization The regularization scheme which may be applied to this linear object in order to smooth its solution. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ - super().__init__(regularization=regularization, run_time_dict=run_time_dict) + super().__init__(regularization=regularization) self.grid = grid @@ -83,7 +78,6 @@ def neighbors(self) -> Neighbors: ) @cached_property - @profile_func def unique_mappings(self) -> UniqueMappings: """ Returns the unique mappings of every unmasked data pixel's (e.g. `grid_slim`) sub-pixels (e.g. `grid_sub_slim`) diff --git a/autoarray/inversion/linear_obj/linear_obj.py b/autoarray/inversion/linear_obj/linear_obj.py index 1e0bacc85..402658039 100644 --- a/autoarray/inversion/linear_obj/linear_obj.py +++ b/autoarray/inversion/linear_obj/linear_obj.py @@ -6,14 +6,11 @@ from autoarray.inversion.linear_obj.neighbors import Neighbors from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.numba_util import profile_func - class LinearObj: def __init__( self, regularization: Optional[AbstractRegularization], - run_time_dict: Optional[Dict] = None, ): """ A linear object which reconstructs a dataset based on mapping between the data points of that dataset and @@ -33,11 +30,8 @@ def __init__( ---------- regularization The regularization scheme which may be applied to this linear object in order to smooth its solution. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ self.regularization = regularization - self.run_time_dict = run_time_dict @property def params(self) -> int: @@ -75,7 +69,6 @@ def neighbors(self) -> Neighbors: raise NotImplementedError @cached_property - @profile_func def unique_mappings(self): """ An object describing the unique mappings between data points / pixels in the data and the parameters of the diff --git a/autoarray/inversion/mock/mock_inversion.py b/autoarray/inversion/mock/mock_inversion.py index 1fb125861..053cb4a0e 100644 --- a/autoarray/inversion/mock/mock_inversion.py +++ b/autoarray/inversion/mock/mock_inversion.py @@ -29,13 +29,15 @@ def __init__( regularization_term=None, log_det_curvature_reg_matrix_term=None, log_det_regularization_matrix_term=None, - settings: SettingsInversion = SettingsInversion(), + settings: SettingsInversion = None, ): dataset = DatasetInterface( data=data, noise_map=noise_map, ) + settings = settings or SettingsInversion() + super().__init__( dataset=dataset, linear_obj_list=linear_obj_list or [], diff --git a/autoarray/inversion/mock/mock_inversion_imaging.py b/autoarray/inversion/mock/mock_inversion_imaging.py index de7f2baa1..673f33566 100644 --- a/autoarray/inversion/mock/mock_inversion_imaging.py +++ b/autoarray/inversion/mock/mock_inversion_imaging.py @@ -10,6 +10,7 @@ class MockInversionImaging(InversionImagingMapping): def __init__( self, + mask=None, data=None, noise_map=None, psf=None, @@ -17,8 +18,11 @@ def __init__( operated_mapping_matrix=None, linear_func_operated_mapping_matrix_dict=None, data_linear_func_matrix_dict=None, - settings: SettingsInversion = SettingsInversion(), + settings: SettingsInversion = None, ): + + settings = settings or SettingsInversion() + dataset = DatasetInterface( data=data, noise_map=noise_map, @@ -31,6 +35,7 @@ def __init__( settings=settings, ) + self._mask = mask self._operated_mapping_matrix = operated_mapping_matrix self._linear_func_operated_mapping_matrix_dict = ( @@ -38,6 +43,13 @@ def __init__( ) self._data_linear_func_matrix_dict = data_linear_func_matrix_dict + @property + def mask(self) -> np.ndarray: + if self._mask is None: + return super().mask + + return self._mask + @property def operated_mapping_matrix(self) -> np.ndarray: if self._operated_mapping_matrix is None: @@ -74,7 +86,7 @@ def __init__( w_tilde=None, linear_obj_list=None, curvature_matrix_mapper_diag=None, - settings: SettingsInversion = SettingsInversion(), + settings: SettingsInversion = None, ): dataset = DatasetInterface( data=data, @@ -82,6 +94,8 @@ def __init__( psf=psf, ) + settings = settings or SettingsInversion() + super().__init__( dataset=dataset, w_tilde=w_tilde or MockWTildeImaging(), diff --git a/autoarray/inversion/mock/mock_inversion_interferometer.py b/autoarray/inversion/mock/mock_inversion_interferometer.py index 58de71520..a25ec03f9 100644 --- a/autoarray/inversion/mock/mock_inversion_interferometer.py +++ b/autoarray/inversion/mock/mock_inversion_interferometer.py @@ -15,7 +15,7 @@ def __init__( transformer=None, linear_obj_list=None, operated_mapping_matrix=None, - settings: SettingsInversion = SettingsInversion(), + settings: SettingsInversion = None, ): dataset = DatasetInterface( data=data, @@ -23,6 +23,8 @@ def __init__( transformer=transformer, ) + settings = settings or SettingsInversion() + super().__init__( dataset=dataset, linear_obj_list=linear_obj_list, diff --git a/autoarray/inversion/mock/mock_mesh.py b/autoarray/inversion/mock/mock_mesh.py index def02657a..721c6a990 100644 --- a/autoarray/inversion/mock/mock_mesh.py +++ b/autoarray/inversion/mock/mock_mesh.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Dict, Optional +from typing import Optional from autoarray.mask.mask_2d import Mask2D from autoarray.inversion.pixelization.mesh.abstract import AbstractMesh @@ -23,7 +23,6 @@ def mapper_grids_from( source_plane_mesh_grid: Optional[Abstract2DMesh] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: Optional[np.ndarray] = None, - run_time_dict: Optional[Dict] = None, ) -> MapperGrids: return MapperGrids( mask=mask, @@ -32,7 +31,6 @@ def mapper_grids_from( source_plane_mesh_grid=source_plane_mesh_grid, image_plane_mesh_grid=self.image_plane_mesh_grid, adapt_data=adapt_data, - run_time_dict=run_time_dict, ) def image_plane_mesh_grid_from( diff --git a/autoarray/inversion/mock/mock_pixelization.py b/autoarray/inversion/mock/mock_pixelization.py index 72ced89e5..a71abebf7 100644 --- a/autoarray/inversion/mock/mock_pixelization.py +++ b/autoarray/inversion/mock/mock_pixelization.py @@ -1,5 +1,3 @@ -import numpy as np - from autoarray.mask.mask_2d import Mask2D from autoarray.inversion.pixelization.pixelization import Pixelization @@ -29,7 +27,6 @@ def mapper_grids_from( image_plane_mesh_grid=None, adapt_data=None, settings=None, - run_time_dict=None, ): return self.mapper diff --git a/autoarray/inversion/pixelization/border_relocator.py b/autoarray/inversion/pixelization/border_relocator.py index 737444b53..e46cbef1e 100644 --- a/autoarray/inversion/pixelization/border_relocator.py +++ b/autoarray/inversion/pixelization/border_relocator.py @@ -1,8 +1,7 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np -from typing import Union - -from autoconf import cached_property +from typing import Tuple, Union from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D @@ -54,7 +53,7 @@ def sub_slim_indexes_for_slim_index_via_mask_2d_from( slim_index_for_sub_slim_indexes = ( over_sample_util.slim_index_for_sub_slim_index_via_mask_2d_from( - mask_2d=mask_2d, sub_size=np.array(sub_size) + mask_2d=mask_2d, sub_size=sub_size ).astype("int") ) @@ -64,6 +63,43 @@ def sub_slim_indexes_for_slim_index_via_mask_2d_from( return sub_slim_indexes_for_slim_index +def furthest_grid_2d_slim_index_from( + grid_2d_slim: np.ndarray, slim_indexes: np.ndarray, coordinate: Tuple[float, float] +) -> int: + """ + Returns the index in `slim_indexes` corresponding to the 2D point in `grid_2d_slim` + that is furthest from a given coordinate, measured by squared Euclidean distance. + + Parameters + ---------- + grid_2d_slim + A 2D array of shape (N, 2), where each row is a (y, x) coordinate. + slim_indexes + An array of indices into `grid_2d_slim` specifying which coordinates to consider. + coordinate + The (y, x) coordinate from which distances are calculated. + + Returns + ------- + int + The slim index of the point in `grid_2d_slim[slim_indexes]` that is furthest from `coordinate`. + """ + subgrid = grid_2d_slim[slim_indexes] + dy = subgrid[:, 0] - coordinate[0] + dx = subgrid[:, 1] - coordinate[1] + squared_distances = dx**2 + dy**2 + + max_dist = np.max(squared_distances) + + # Find all indices with max distance + max_positions = np.where(squared_distances == max_dist)[0] + + # Choose the last one (to match original loop behavior) + max_index = max_positions[-1] + + return slim_indexes[max_index] + + def sub_border_pixel_slim_indexes_from( mask_2d: np.ndarray, sub_size: Array2D ) -> np.ndarray: @@ -107,7 +143,7 @@ def sub_border_pixel_slim_indexes_from( sub_grid_2d_slim = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=mask_2d, pixel_scales=(1.0, 1.0), - sub_size=np.array(sub_size), + sub_size=sub_size, origin=(0.0, 0.0), ) mask_centre = grid_2d_util.grid_2d_centre_from(grid_2d_slim=sub_grid_2d_slim) @@ -117,129 +153,176 @@ def sub_border_pixel_slim_indexes_from( int(border_pixel) ] - sub_border_pixels[border_1d_index] = ( - grid_2d_util.furthest_grid_2d_slim_index_from( - grid_2d_slim=sub_grid_2d_slim, - slim_indexes=sub_border_pixels_of_border_pixel, - coordinate=mask_centre, - ) + sub_border_pixels[border_1d_index] = furthest_grid_2d_slim_index_from( + grid_2d_slim=sub_grid_2d_slim, + slim_indexes=sub_border_pixels_of_border_pixel, + coordinate=mask_centre, ) return sub_border_pixels -class BorderRelocator: - def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): - self.mask = mask +def sub_border_slim_from(mask, sub_size): + """ + Returns the subgridded 1D ``slim`` indexes of border pixels in the ``Mask2D``, representing all unmasked + sub-pixels (given by ``False``) which neighbor any masked value (give by ``True``) and which are on the + extreme exterior of the mask. - self.sub_size = over_sample_util.over_sample_size_convert_to_array_2d_from( - over_sample_size=sub_size, mask=mask + The indexes are the sub-gridded extension of the ``border_slim`` which is illustrated above. + + This quantity is too complicated to write-out in a docstring, and it is recommended you print it in + Python code to understand it if anything is unclear. + + Examples + -------- + + .. code-block:: python + + import autoarray as aa + + mask_2d = aa.Mask2D( + mask=[[True, True, True, True, True, True, True, True, True], + [True, False, False, False, False, False, False, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, False, True, False, True, False, True], + [True, False, True, False, False, False, True, False, True], + [True, False, True, True, True, True, True, False, True], + [True, False, False, False, False, False, False, False, True], + [True, True, True, True, True, True, True, True, True]] + pixel_scales=1.0, ) - @cached_property - def border_slim(self): - """ - Returns the 1D ``slim`` indexes of border pixels in the ``Mask2D``, representing all unmasked - sub-pixels (given by ``False``) which neighbor any masked value (give by ``True``) and which are on the - extreme exterior of the mask. + derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) - The indexes are the extended below to form the ``sub_border_slim`` which is illustrated above. + print(derive_indexes_2d.sub_border_slim) + """ + return sub_border_pixel_slim_indexes_from( + mask_2d=mask, sub_size=sub_size.astype("int") + ).astype("int") - This quantity is too complicated to write-out in a docstring, and it is recommended you print it in - Python code to understand it if anything is unclear. - Examples - -------- +def relocated_grid_from(grid, border_grid): + """ + Relocate the coordinates of a grid to its border if they are outside the border, where the border is + defined as all pixels at the edge of the grid's mask (see *mask._border_1d_indexes*). - .. code-block:: python + This is performed as follows: - import autoarray as aa + 1: Use the mean value of the grid's y and x coordinates to determine the origin of the grid. + 2: Compute the radial distance of every grid coordinate from the origin. + 3: For every coordinate, find its nearest pixel in the border. + 4: Determine if it is outside the border, by comparing its radial distance from the origin to its paired + border pixel's radial distance. + 5: If its radial distance is larger, use the ratio of radial distances to move the coordinate to the + border (if its inside the border, do nothing). - mask_2d = aa.Mask2D( - mask=[[True, True, True, True, True, True, True, True, True], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True]] - pixel_scales=1.0, - ) + The method can be used on uniform or irregular grids, however for irregular grids the border of the + 'image-plane' mask is used to define border pixels. - derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) + Parameters + ---------- + grid + The grid (uniform or irregular) whose pixels are to be relocated to the border edge if outside it. + border_grid : Grid2D + The grid of border (y,x) coordinates. + """ - print(derive_indexes_2d.border_slim) - """ - return self.mask.derive_indexes.border_slim + # Compute origin (center) of the border grid + border_origin = jnp.mean(border_grid, axis=0) - @cached_property - def sub_border_slim(self) -> np.ndarray: - """ - Returns the subgridded 1D ``slim`` indexes of border pixels in the ``Mask2D``, representing all unmasked - sub-pixels (given by ``False``) which neighbor any masked value (give by ``True``) and which are on the - extreme exterior of the mask. + # Radii from origin + grid_radii = jnp.linalg.norm(grid - border_origin, axis=1) # (N,) + border_radii = jnp.linalg.norm(border_grid - border_origin, axis=1) # (M,) + border_min_radius = jnp.min(border_radii) - The indexes are the sub-gridded extension of the ``border_slim`` which is illustrated above. + # Determine which points are outside + outside_mask = grid_radii > border_min_radius # (N,) - This quantity is too complicated to write-out in a docstring, and it is recommended you print it in - Python code to understand it if anything is unclear. + # To compute nearest border point for each grid point, we must do it for all and then mask later + # Compute all distances: (N, M) + diffs = grid[:, None, :] - border_grid[None, :, :] # (N, M, 2) + dists_squared = jnp.sum(diffs**2, axis=2) # (N, M) + closest_indices = jnp.argmin(dists_squared, axis=1) # (N,) - Examples - -------- + # Get border radius for closest border point to each grid point + matched_border_radii = border_radii[closest_indices] # (N,) - .. code-block:: python + # Ratio of border to grid radius + move_factors = matched_border_radii / grid_radii # (N,) - import autoarray as aa + # Only move if: + # - the point is outside the border + # - the matched border point is closer to the origin (i.e. move_factor < 1) + apply_move = jnp.logical_and(outside_mask, move_factors < 1.0) # (N,) - mask_2d = aa.Mask2D( - mask=[[True, True, True, True, True, True, True, True, True], - [True, False, False, False, False, False, False, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, False, True, False, True, False, True], - [True, False, True, False, False, False, True, False, True], - [True, False, True, True, True, True, True, False, True], - [True, False, False, False, False, False, False, False, True], - [True, True, True, True, True, True, True, True, True]] - pixel_scales=1.0, - ) + # Compute moved positions (for all points, but will select with mask) + direction_vectors = grid - border_origin # (N, 2) + moved_grid = move_factors[:, None] * direction_vectors + border_origin # (N, 2) - derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) + # Select which grid points to move + relocated_grid = jnp.where(apply_move[:, None], moved_grid, grid) # (N, 2) - print(derive_indexes_2d.sub_border_slim) - """ - return sub_border_pixel_slim_indexes_from( - mask_2d=np.array(self.mask), sub_size=np.array(self.sub_size).astype("int") - ).astype("int") + return relocated_grid - @cached_property - def border_grid(self) -> np.ndarray: - """ - The (y,x) grid of all sub-pixels which are at the border of the mask. - This is NOT all sub-pixels which are in mask pixels at the mask's border, but specifically the sub-pixels - within these border pixels which are at the extreme edge of the border. +class BorderRelocator: + def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): """ - return self.mask.derive_grid.border + Relocates source plane coordinates that trace outside the mask’s border in the source-plane back onto the + border. - @cached_property - def sub_border_grid(self) -> np.ndarray: - """ - The (y,x) grid of all sub-pixels which are at the border of the mask. + Given an input mask and (optionally) a per‐pixel sub‐sampling size, this class computes: + + 1. `border_grid`: the (y,x) coordinates of every border pixel of the mask. + 2. `sub_border_grid`: an over‐sampled border grid if sub‐sampling is requested. + 3. `relocated_grid(grid)`: for any arbitrary grid of points (uniform or irregular), returns a new grid + where any point whose radius from the mask center exceeds the minimum radius of the border is + moved radially inward until it lies exactly on its nearest border pixel. + + In practice this ensures that “outlier” rays or source‐plane pixels don’t fall outside the allowed + mask region when performing pixelization–based inversions or lens‐plane mappings. + + See Figure 2 of https://arxiv.org/abs/1708.07377 for a description of why this functionality is required. - This is NOT all sub-pixels which are in mask pixels at the mask's border, but specifically the sub-pixels - within these border pixels which are at the extreme edge of the border. + Attributes + ---------- + mask : Mask2D + The input mask whose border defines the permissible region. + sub_size : Array2D + Per‐pixel sub‐sampling size (can be constant or spatially varying). + border_slim : np.ndarray + 1D indexes of the mask’s border pixels in the slimmed representation. + sub_border_slim : np.ndarray + 1D indexes of the over‐sampled (sub) border pixels. + border_grid : np.ndarray + Array of (y,x) coordinates for each border pixel. + sub_border_grid : np.ndarray + Array of (y,x) coordinates for each over‐sampled border pixel. """ + self.mask = mask + + self.sub_size = over_sample_util.over_sample_size_convert_to_array_2d_from( + over_sample_size=sub_size, mask=mask + ) + + self.border_slim = self.mask.derive_indexes.border_slim + self.sub_border_slim = sub_border_slim_from( + mask=self.mask, sub_size=self.sub_size + ) + try: + self.border_grid = self.mask.derive_grid.border + except TypeError: + self.border_grid = None + sub_grid = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( - mask_2d=np.array(self.mask), + mask_2d=self.mask, pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.sub_size).astype("int"), + sub_size=self.sub_size.astype("int"), origin=self.mask.origin, ) - return sub_grid[self.sub_border_slim] + self.sub_border_grid = sub_grid[self.sub_border_slim] def relocated_grid_from(self, grid: Grid2D) -> Grid2D: """ @@ -268,14 +351,14 @@ def relocated_grid_from(self, grid: Grid2D) -> Grid2D: if len(self.sub_border_grid) == 0: return grid - values = grid_2d_util.relocated_grid_via_jit_from( - grid=np.array(grid), - border_grid=np.array(grid[self.border_slim]), + values = relocated_grid_from( + grid=grid.array, + border_grid=grid.array[self.border_slim], ) - over_sampled = grid_2d_util.relocated_grid_via_jit_from( - grid=np.array(grid.over_sampled), - border_grid=np.array(grid.over_sampled[self.sub_border_slim]), + over_sampled = relocated_grid_from( + grid=grid.over_sampled.array, + border_grid=grid.over_sampled.array[self.sub_border_slim], ) return Grid2D( @@ -302,8 +385,8 @@ def relocated_mesh_grid_from( return mesh_grid return Grid2DIrregular( - values=grid_2d_util.relocated_grid_via_jit_from( - grid=np.array(mesh_grid), - border_grid=np.array(grid[self.sub_border_slim]), + values=relocated_grid_from( + grid=mesh_grid.array, + border_grid=grid[self.sub_border_slim], ), ) diff --git a/autoarray/inversion/pixelization/image_mesh/hilbert.py b/autoarray/inversion/pixelization/image_mesh/hilbert.py index 964457ded..119db098d 100644 --- a/autoarray/inversion/pixelization/image_mesh/hilbert.py +++ b/autoarray/inversion/pixelization/image_mesh/hilbert.py @@ -1,15 +1,13 @@ from __future__ import annotations import numpy as np -from scipy.interpolate import interp1d, griddata + from typing import Optional -from autoarray.mask.mask_2d import Mask2D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.mask.mask_2d import Mask2D from autoarray.inversion.pixelization.image_mesh.abstract_weighted import ( AbstractImageMeshWeighted, ) -from autoarray.operators.over_sampling.over_sampler import OverSampler from autoarray.inversion.inversion.settings import SettingsInversion from autoarray.structures.grids.irregular_2d import Grid2DIrregular @@ -125,6 +123,8 @@ def image_and_grid_from(image, mask, mask_radius, pixel_scales, hilbert_length): image associated to that grid. """ + from scipy.interpolate import griddata + # For multi wavelength fits the input image may be a different resolution than the mask. try: @@ -164,6 +164,7 @@ def inverse_transform_sampling_interpolated(probabilities, n_samples, gridx, gri probabilities: 1D normalized cumulative probablity curve. n_samples: the number of points to draw. """ + from scipy.interpolate import interp1d cdf = np.cumsum(probabilities) npixels = len(probabilities) diff --git a/autoarray/inversion/pixelization/image_mesh/kmeans.py b/autoarray/inversion/pixelization/image_mesh/kmeans.py index a7aa96536..55ecb99fc 100644 --- a/autoarray/inversion/pixelization/image_mesh/kmeans.py +++ b/autoarray/inversion/pixelization/image_mesh/kmeans.py @@ -1,5 +1,4 @@ import numpy as np -from sklearn.cluster import KMeans as ScipyKMeans from typing import Optional import sys import warnings @@ -97,6 +96,8 @@ def image_plane_mesh_grid_from( weight_map = self.weight_map_from(adapt_data=adapt_data) + from sklearn.cluster import KMeans as ScipyKMeans + kmeans = ScipyKMeans( n_clusters=int(self.pixels), random_state=1, @@ -104,7 +105,7 @@ def image_plane_mesh_grid_from( max_iter=5, ) - grid = mask.derive_grid.unmasked + grid = mask.derive_grid.unmasked.array try: kmeans = kmeans.fit(X=grid, sample_weight=weight_map) diff --git a/autoarray/inversion/pixelization/image_mesh/overlay.py b/autoarray/inversion/pixelization/image_mesh/overlay.py index de130ee6e..654cf8ec7 100644 --- a/autoarray/inversion/pixelization/image_mesh/overlay.py +++ b/autoarray/inversion/pixelization/image_mesh/overlay.py @@ -220,11 +220,13 @@ def image_plane_mesh_grid_from( origin=origin, ) - overlaid_centres = np.array(geometry_util.grid_pixel_centres_2d_slim_from( - grid_scaled_2d_slim=unmasked_overlay_grid, - shape_native=mask.shape_native, - pixel_scales=mask.pixel_scales, - )).astype("int") + overlaid_centres = np.array( + geometry_util.grid_pixel_centres_2d_slim_from( + grid_scaled_2d_slim=unmasked_overlay_grid, + shape_native=mask.shape_native, + pixel_scales=mask.pixel_scales, + ) + ).astype("int") total_pixels = total_pixels_2d_from( mask_2d=mask.array, diff --git a/autoarray/inversion/pixelization/mappers/abstract.py b/autoarray/inversion/pixelization/mappers/abstract.py index fd0ec1038..480650ee0 100644 --- a/autoarray/inversion/pixelization/mappers/abstract.py +++ b/autoarray/inversion/pixelization/mappers/abstract.py @@ -15,9 +15,8 @@ from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.mesh.abstract_2d import Abstract2DMesh - -from autoarray.numba_util import profile_func from autoarray.inversion.pixelization.mappers import mapper_util +from autoarray.inversion.pixelization.mappers import mapper_numba_util class AbstractMapper(LinearObj): @@ -26,7 +25,6 @@ def __init__( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: BorderRelocator, - run_time_dict: Optional[Dict] = None, ): """ To understand a `Mapper` one must be familiar `Mesh` objects and the `mesh` and `pixelization` packages, where @@ -83,11 +81,9 @@ def __init__( border_relocator The border relocator, which relocates coordinates outside the border of the source-plane data grid to its edge. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ - super().__init__(regularization=regularization, run_time_dict=run_time_dict) + super().__init__(regularization=regularization) self.border_relocator = border_relocator self.mapper_grids = mapper_grids @@ -212,31 +208,7 @@ def sub_slim_indexes_for_pix_index(self) -> List[List]: return sub_slim_indexes_for_pix_index - @property - @profile_func - def sub_slim_indexes_for_pix_index_arr( - self, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Returns the index mappings between each of the pixelization's pixels and the masked data's sub-pixels. - - Given that even pixelization pixel maps to multiple data sub-pixels, index mappings are returned as a list of - lists where the first entries are the pixelization index and second entries store the data sub-pixel indexes. - - For example, if `sub_slim_indexes_for_pix_index[2][4] = 10`, the pixelization pixel with index 2 - (e.g. `mesh_grid[2,:]`) has a mapping to a data sub-pixel with index 10 (e.g. `grid_slim[10, :]). - - This is effectively a reversal of the array `pix_indexes_for_sub_slim_index`. - """ - - return mapper_util.sub_slim_indexes_for_pix_index( - pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=self.pix_weights_for_sub_slim_index, - pix_pixels=self.pixels, - ) - @cached_property - @profile_func def unique_mappings(self) -> UniqueMappings: """ Returns the unique mappings of every unmasked data pixel's (e.g. `grid_slim`) sub-pixels (e.g. `grid_sub_slim`) @@ -254,11 +226,15 @@ def unique_mappings(self) -> UniqueMappings: data_to_pix_unique, data_weights, pix_lengths, - ) = mapper_util.data_slim_to_pixelization_unique_from( + ) = mapper_numba_util.data_slim_to_pixelization_unique_from( data_pixels=self.over_sampler.mask.pixels_in_mask, - pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=self.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + self.pix_indexes_for_sub_slim_index + ), + pix_sizes_for_sub_slim_index=np.array(self.pix_sizes_for_sub_slim_index), + pix_weights_for_sub_slim_index=np.array( + self.pix_weights_for_sub_slim_index + ), pix_pixels=self.params, sub_size=np.array(self.over_sampler.sub_size).astype("int"), ) @@ -270,7 +246,6 @@ def unique_mappings(self) -> UniqueMappings: ) @cached_property - @profile_func def mapping_matrix(self) -> np.ndarray: """ The `mapping_matrix` of a linear object describes the mappings between the observed data's data-points / pixels @@ -283,6 +258,7 @@ def mapping_matrix(self) -> np.ndarray: It is described in the following paper as matrix `f` https://arxiv.org/pdf/astro-ph/0302587.pdf and in more detail in the function `mapper_util.mapping_matrix_from()`. """ + return mapper_util.mapping_matrix_from( pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, @@ -290,7 +266,7 @@ def mapping_matrix(self) -> np.ndarray: pixels=self.pixels, total_mask_pixels=self.over_sampler.mask.pixels_in_mask, slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index, - sub_fraction=np.array(self.over_sampler.sub_fraction), + sub_fraction=self.over_sampler.sub_fraction.array, ) def pixel_signals_from(self, signal_scale: float) -> np.ndarray: @@ -313,10 +289,10 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray: pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index, slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim, - adapt_data=np.array(self.adapt_data), + adapt_data=self.adapt_data.array, ) - def pix_indexes_for_slim_indexes(self, pix_indexes: List) -> List[List]: + def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]: """ Returns the index mappings between every masked data-point (not subgridded) on the data and the mapper pixels / parameters that it maps too. @@ -324,7 +300,7 @@ def pix_indexes_for_slim_indexes(self, pix_indexes: List) -> List[List]: The `slim_index` refers to the masked data pixels (without subgridding) and `pix_indexes` the pixelization pixel indexes, for example: - - `pix_indexes_for_slim_indexes[0] = [2, 3]`: The data's first (index 0) pixel maps to the + - `slim_indexes_for_pix_indexes[0] = [2, 3]`: The data's first (index 0) pixel maps to the pixelization's third (index 2) and fourth (index 3) pixels. Parameters @@ -363,8 +339,12 @@ def data_weight_total_for_pix_from(self) -> np.ndarray: """ return mapper_util.data_weight_total_for_pix_from( - pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=self.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + self.pix_indexes_for_sub_slim_index + ), + pix_weights_for_sub_slim_index=np.array( + self.pix_weights_for_sub_slim_index + ), pixels=self.pixels, ) @@ -387,8 +367,8 @@ def mapped_to_source_from(self, array: Array2D) -> np.ndarray: source domain in order to compute their average values. """ return mapper_util.mapped_to_source_via_mapping_matrix_from( - mapping_matrix=self.mapping_matrix, - array_slim=np.array(array.slim), + mapping_matrix=np.array(self.mapping_matrix), + array_slim=array.slim, ) def extent_from( @@ -397,6 +377,9 @@ def extent_from( zoom_to_brightest: bool = True, zoom_percent: Optional[float] = None, ) -> Tuple[float, float, float, float]: + + from autoarray.geometry import geometry_util + if zoom_to_brightest and values is not None: if zoom_percent is None: zoom_percent = conf.instance["visualize"]["general"]["zoom"][ @@ -408,8 +391,6 @@ def extent_from( true_indices = np.argwhere(fractional_bool) true_grid = self.source_plane_mesh_grid[true_indices] - from autoarray.geometry import geometry_util - try: return geometry_util.extent_symmetric_from( extent=( @@ -420,9 +401,13 @@ def extent_from( ) ) except ValueError: - return self.source_plane_mesh_grid.geometry.extent + return geometry_util.extent_symmetric_from( + extent=self.source_plane_mesh_grid.geometry.extent + ) - return self.source_plane_mesh_grid.geometry.extent + return geometry_util.extent_symmetric_from( + extent=self.source_plane_mesh_grid.geometry.extent + ) def interpolated_array_from( self, diff --git a/autoarray/inversion/pixelization/mappers/delaunay.py b/autoarray/inversion/pixelization/mappers/delaunay.py index 737247b0b..1b4226771 100644 --- a/autoarray/inversion/pixelization/mappers/delaunay.py +++ b/autoarray/inversion/pixelization/mappers/delaunay.py @@ -5,8 +5,8 @@ from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights -from autoarray.numba_util import profile_func from autoarray.inversion.pixelization.mappers import mapper_util +from autoarray.inversion.pixelization.mappers import mapper_numba_util class MapperDelaunay(AbstractMapper): @@ -56,8 +56,6 @@ class MapperDelaunay(AbstractMapper): regularization The regularization scheme which may be applied to this linear object in order to smooth its solution, which for a mapper smooths neighboring pixels on the mesh. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ @property @@ -65,7 +63,6 @@ def delaunay(self): return self.source_plane_mesh_grid.delaunay @cached_property - @profile_func def pix_sub_weights(self) -> PixSubWeights: """ Computes the following three quantities describing the mappings between of every sub-pixel in the masked data @@ -107,7 +104,7 @@ def pix_sub_weights(self) -> PixSubWeights: The interpolation weights of these multiple mappings are stored in the array `pix_weights_for_sub_slim_index`. For the Delaunay pixelization these mappings are calculated using the Scipy spatial library - (see `mapper_util.pix_indexes_for_sub_slim_index_delaunay_from`). + (see `mapper_numba_util.pix_indexes_for_sub_slim_index_delaunay_from`). """ delaunay = self.delaunay @@ -116,17 +113,21 @@ def pix_sub_weights(self) -> PixSubWeights: ) pix_indexes_for_simplex_index = delaunay.simplices - mappings, sizes = mapper_util.pix_indexes_for_sub_slim_index_delaunay_from( - source_plane_data_grid=np.array(self.source_plane_data_grid.over_sampled), - simplex_index_for_sub_slim_index=simplex_index_for_sub_slim_index, - pix_indexes_for_simplex_index=pix_indexes_for_simplex_index, - delaunay_points=delaunay.points, + mappings, sizes = ( + mapper_numba_util.pix_indexes_for_sub_slim_index_delaunay_from( + source_plane_data_grid=np.array( + self.source_plane_data_grid.over_sampled + ), + simplex_index_for_sub_slim_index=simplex_index_for_sub_slim_index, + pix_indexes_for_simplex_index=pix_indexes_for_simplex_index, + delaunay_points=delaunay.points, + ) ) mappings = mappings.astype("int") sizes = sizes.astype("int") - weights = mapper_util.pixel_weights_delaunay_from( + weights = mapper_numba_util.pixel_weights_delaunay_from( source_plane_data_grid=np.array(self.source_plane_data_grid.over_sampled), source_plane_mesh_grid=np.array(self.source_plane_mesh_grid), slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index, @@ -158,14 +159,14 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights: ( splitted_mappings, splitted_sizes, - ) = mapper_util.pix_indexes_for_sub_slim_index_delaunay_from( + ) = mapper_numba_util.pix_indexes_for_sub_slim_index_delaunay_from( source_plane_data_grid=self.source_plane_mesh_grid.split_cross, simplex_index_for_sub_slim_index=splitted_simplex_index_for_sub_slim_index, pix_indexes_for_simplex_index=pix_indexes_for_simplex_index, delaunay_points=delaunay.points, ) - splitted_weights = mapper_util.pixel_weights_delaunay_from( + splitted_weights = mapper_numba_util.pixel_weights_delaunay_from( source_plane_data_grid=self.source_plane_mesh_grid.split_cross, source_plane_mesh_grid=np.array(self.source_plane_mesh_grid), slim_index_for_sub_slim_index=self.source_plane_mesh_grid.split_cross, diff --git a/autoarray/inversion/pixelization/mappers/factory.py b/autoarray/inversion/pixelization/mappers/factory.py index 310133202..689f35011 100644 --- a/autoarray/inversion/pixelization/mappers/factory.py +++ b/autoarray/inversion/pixelization/mappers/factory.py @@ -4,6 +4,7 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.inversion.regularization.abstract import AbstractRegularization from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular +from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay from autoarray.structures.mesh.voronoi_2d import Mesh2DVoronoi @@ -12,7 +13,6 @@ def mapper_from( mapper_grids: MapperGrids, regularization: Optional[AbstractRegularization], border_relocator: Optional[BorderRelocator] = None, - run_time_dict: Optional[Dict] = None, ): """ Factory which given input `MapperGrids` and `Regularization` objects creates a `Mapper`. @@ -32,8 +32,6 @@ def mapper_from( regularization The regularization scheme which may be applied to this linear object in order to smooth its solution, which for a mapper smooths neighboring pixels on the mesh. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. Returns ------- @@ -42,27 +40,33 @@ def mapper_from( from autoarray.inversion.pixelization.mappers.rectangular import ( MapperRectangular, ) + from autoarray.inversion.pixelization.mappers.rectangular_uniform import ( + MapperRectangularUniform, + ) from autoarray.inversion.pixelization.mappers.delaunay import MapperDelaunay from autoarray.inversion.pixelization.mappers.voronoi import MapperVoronoi - if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): + if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangularUniform): + return MapperRectangularUniform( + mapper_grids=mapper_grids, + border_relocator=border_relocator, + regularization=regularization, + ) + elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular): return MapperRectangular( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - run_time_dict=run_time_dict, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay): return MapperDelaunay( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - run_time_dict=run_time_dict, ) elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DVoronoi): return MapperVoronoi( mapper_grids=mapper_grids, border_relocator=border_relocator, regularization=regularization, - run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/pixelization/mappers/mapper_grids.py b/autoarray/inversion/pixelization/mappers/mapper_grids.py index 9a12e2f95..86074b043 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_grids.py +++ b/autoarray/inversion/pixelization/mappers/mapper_grids.py @@ -19,7 +19,6 @@ def __init__( source_plane_mesh_grid: Optional[Abstract2DMesh] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: Optional[np.ndarray] = None, - run_time_dict: Optional[Dict] = None, ): """ Groups the different grids used by `Mesh` objects, the `mesh` package and the `pixelization` package, which @@ -55,8 +54,6 @@ def __init__( adapt_data An image which is used to determine the `image_plane_mesh_grid` and therefore adapt the distribution of pixels of the Delaunay grid to the data it discretizes. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ self.mask = mask @@ -64,7 +61,6 @@ def __init__( self.source_plane_mesh_grid = source_plane_mesh_grid self.image_plane_mesh_grid = image_plane_mesh_grid self.adapt_data = adapt_data - self.run_time_dict = run_time_dict @property def image_plane_data_grid(self): @@ -72,6 +68,7 @@ def image_plane_data_grid(self): @property def mesh_pixels_per_image_pixels(self): + mesh_pixels_per_image_pixels = grid_2d_util.grid_pixels_in_mask_pixels_from( grid=np.array(self.image_plane_mesh_grid), shape_native=self.mask.shape_native, diff --git a/autoarray/inversion/pixelization/mappers/mapper_numba_util.py b/autoarray/inversion/pixelization/mappers/mapper_numba_util.py new file mode 100644 index 000000000..916c813dd --- /dev/null +++ b/autoarray/inversion/pixelization/mappers/mapper_numba_util.py @@ -0,0 +1,353 @@ +import numpy as np +from typing import Tuple + +from autoconf import conf + +from autoarray import numba_util +from autoarray.inversion.pixelization.mesh import mesh_numba_util + +from autoarray import exc + + +@numba_util.jit() +def data_slim_to_pixelization_unique_from( + data_pixels, + pix_indexes_for_sub_slim_index: np.ndarray, + pix_sizes_for_sub_slim_index: np.ndarray, + pix_weights_for_sub_slim_index, + pix_pixels: int, + sub_size: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Create an array describing the unique mappings between the sub-pixels of every slim data pixel and the pixelization + pixels, which is used to perform efficiently linear algebra calculations. + + For example, assuming `sub_size=2`: + + - If 3 sub-pixels in image pixel 0 map to pixelization pixel 2 then `data_pix_to_unique[0, 0] = 2`. + - If the fourth sub-pixel maps to pixelization pixel 4, then `data_to_pix_unique[0, 1] = 4`. + + The size of the second index depends on the number of unique sub-pixel to pixelization pixels mappings in a given + data pixel. In the example above, there were only two unique sets of mapping, but for high levels of sub-gridding + there could be many more unique mappings all of which must be stored. + + The array `data_to_pix_unique` does not describe how many sub-pixels uniquely map to each pixelization pixel for + a given data pixel. This information is contained in the array `data_weights`. For the example above, + where `sub_size=2` and therefore `sub_fraction=0.25`: + + - `data_weights[0, 0] = 0.75` (because 3 sub-pixels mapped to this pixelization pixel). + - `data_weights[0, 1] = 0.25` (because 1 sub-pixel mapped to this pixelization pixel). + + The `sub_fractions` are stored as opposed to the number of sub-pixels, because these values are used directly + when performing the linear algebra calculation. + + The array `pix_lengths` in a 1D array of dimensions [data_pixels] describing how many unique pixelization pixels + each data pixel's set of sub-pixels maps too. + + Parameters + ---------- + data_pixels + The total number of data pixels in the dataset. + pix_indexes_for_sub_slim_index + Maps an unmasked data sub pixel to its corresponding pixelization pixel. + sub_size + The size of the sub-grid defining the number of sub-pixels in every data pixel. + + Returns + ------- + ndarray + The unique mappings between the sub-pixels of every data pixel and the pixelization pixels, alongside arrays + that give the weights and total number of mappings. + """ + + sub_fraction = 1.0 / (sub_size**2.0) + + max_pix_mappings = int(np.max(pix_sizes_for_sub_slim_index)) + + # TODO : Work out if we can reduce size from np.max(sub_size) using sub_size of max_pix_mappings. + + data_to_pix_unique = -1 * np.ones( + (data_pixels, max_pix_mappings * np.max(sub_size) ** 2) + ) + data_weights = np.zeros((data_pixels, max_pix_mappings * np.max(sub_size) ** 2)) + pix_lengths = np.zeros(data_pixels) + pix_check = -1 * np.ones(shape=pix_pixels) + + ip_sub_start = 0 + + for ip in range(data_pixels): + pix_check[:] = -1 + + pix_size = 0 + + ip_sub_end = ip_sub_start + sub_size[ip] ** 2 + + for ip_sub in range(ip_sub_start, ip_sub_end): + for pix_interp_index in range(pix_sizes_for_sub_slim_index[ip_sub]): + pix = pix_indexes_for_sub_slim_index[ip_sub, pix_interp_index] + pixel_weight = pix_weights_for_sub_slim_index[ip_sub, pix_interp_index] + + if pix_check[pix] > -0.5: + data_weights[ip, int(pix_check[pix])] += ( + sub_fraction[ip] * pixel_weight + ) + + else: + data_to_pix_unique[ip, pix_size] = pix + data_weights[ip, pix_size] += sub_fraction[ip] * pixel_weight + pix_check[pix] = pix_size + pix_size += 1 + + ip_sub_start = ip_sub_end + + pix_lengths[ip] = pix_size + + return data_to_pix_unique, data_weights, pix_lengths + + +@numba_util.jit() +def pix_indexes_for_sub_slim_index_delaunay_from( + source_plane_data_grid, + simplex_index_for_sub_slim_index, + pix_indexes_for_simplex_index, + delaunay_points, +) -> Tuple[np.ndarray, np.ndarray]: + """ + The indexes mappings between the sub pixels and Voronoi mesh pixels. + For Delaunay tessellation, most sub pixels should have contribution of 3 pixelization pixels. However, + for those ones not belonging to any triangle, we link its value to its closest point. + + The returning result is a matrix of (len(sub_pixels, 3)) where the entries mark the relevant source pixel indexes. + A row like [A, -1, -1] means that sub pixel only links to source pixel A. + """ + + pix_indexes_for_sub_slim_index = -1 * np.ones( + shape=(source_plane_data_grid.shape[0], 3) + ) + + for i in range(len(source_plane_data_grid)): + simplex_index = simplex_index_for_sub_slim_index[i] + if simplex_index != -1: + pix_indexes_for_sub_slim_index[i] = pix_indexes_for_simplex_index[ + simplex_index_for_sub_slim_index[i] + ] + else: + pix_indexes_for_sub_slim_index[i][0] = np.argmin( + np.sum((delaunay_points - source_plane_data_grid[i]) ** 2.0, axis=1) + ) + + pix_indexes_for_sub_slim_index_sizes = np.sum( + pix_indexes_for_sub_slim_index >= 0, axis=1 + ) + + return pix_indexes_for_sub_slim_index, pix_indexes_for_sub_slim_index_sizes + + +@numba_util.jit() +def pixel_weights_delaunay_from( + source_plane_data_grid, + source_plane_mesh_grid, + slim_index_for_sub_slim_index: np.ndarray, + pix_indexes_for_sub_slim_index, +) -> np.ndarray: + """ + Returns the weights of the mappings between the masked sub-pixels and the Delaunay pixelization. + + Weights are determiend via a nearest neighbor interpolation scheme, whereby every data-sub pixel maps to three + Delaunay pixel vertexes (in the source frame). The weights of these 3 mappings depends on the distance of the + coordinate to each vertex, with the highest weight being its closest neighbor, + + Parameters + ---------- + source_plane_data_grid + A 2D grid of (y,x) coordinates associated with the unmasked 2D data after it has been transformed to the + `source` reference frame. + source_plane_mesh_grid + The 2D grid of (y,x) centres of every pixelization pixel in the `source` frame. + slim_index_for_sub_slim_index + The mappings between the data's sub slimmed indexes and the slimmed indexes on the non sub-sized indexes. + pix_indexes_for_sub_slim_index + The mappings from a data sub-pixel index to a pixelization pixel index. + """ + + pixel_weights = np.zeros(pix_indexes_for_sub_slim_index.shape) + + for sub_slim_index in range(slim_index_for_sub_slim_index.shape[0]): + pix_indexes = pix_indexes_for_sub_slim_index[sub_slim_index] + + if pix_indexes[1] != -1: + vertices_of_the_simplex = source_plane_mesh_grid[pix_indexes] + + sub_gird_coordinate_on_source_place = source_plane_data_grid[sub_slim_index] + + area_0 = mesh_numba_util.delaunay_triangle_area_from( + corner_0=vertices_of_the_simplex[1], + corner_1=vertices_of_the_simplex[2], + corner_2=sub_gird_coordinate_on_source_place, + ) + area_1 = mesh_numba_util.delaunay_triangle_area_from( + corner_0=vertices_of_the_simplex[0], + corner_1=vertices_of_the_simplex[2], + corner_2=sub_gird_coordinate_on_source_place, + ) + area_2 = mesh_numba_util.delaunay_triangle_area_from( + corner_0=vertices_of_the_simplex[0], + corner_1=vertices_of_the_simplex[1], + corner_2=sub_gird_coordinate_on_source_place, + ) + + norm = area_0 + area_1 + area_2 + + weight_abc = np.array([area_0, area_1, area_2]) / norm + + pixel_weights[sub_slim_index] = weight_abc + + else: + pixel_weights[sub_slim_index][0] = 1.0 + + return pixel_weights + + +@numba_util.jit() +def remove_bad_entries_voronoi_nn( + bad_indexes, + pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index, + grid, + mesh_grid, +): + """ + The nearest neighbor interpolation can return invalid or bad entries which are removed from the mapping arrays. The + current circumstances this arises are: + + 1) If a point is outside the whole Voronoi region, some weights have negative values. In this case, we reset its + neighbor to its closest neighbor. + + 2) The nearest neighbor interpolation code may not return even a single neighbor. We mark these as a bad grid by + settings their neighbors to the closest ones. + + Parameters + ---------- + bad_indexes + pix_weights_for_sub_slim_index + pix_indexes_for_sub_slim_index + grid + mesh_grid + + Returns + ------- + + """ + + for item in bad_indexes: + ind = item[0] + pix_indexes_for_sub_slim_index[ind] = -1 + pix_indexes_for_sub_slim_index[ind][0] = np.argmin( + np.sum((grid[ind] - mesh_grid) ** 2.0, axis=1) + ) + pix_weights_for_sub_slim_index[ind] = 0.0 + pix_weights_for_sub_slim_index[ind][0] = 1.0 + + return pix_weights_for_sub_slim_index, pix_indexes_for_sub_slim_index + + +def pix_size_weights_voronoi_nn_from( + grid: np.ndarray, mesh_grid: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Returns the mappings between a set of slimmed sub-grid pixels and pixelization pixels, using information on + how the pixels hosting each sub-pixel map to their closest pixelization pixel on the slim grid in the data-plane + and the pixelization's pixel centres. + + To determine the complete set of slim sub-pixel to pixelization pixel mappings, we must pair every sub-pixel to + its nearest pixel. Using a full nearest neighbor search to do this is slow, thus the pixel neighbors (derived via + the Voronoi grid) are used to localize each nearest neighbor search by using a graph search. + + Parameters + ---------- + grid + The grid of (y,x) scaled coordinates at the centre of every unmasked pixel, which has been traced to + to an irgrid via lens. + slim_index_for_sub_slim_index + The mappings between the data slimmed sub-pixels and their regular pixels. + mesh_grid + The (y,x) centre of every Voronoi pixel in arc-seconds. + neighbors + An array of length (voronoi_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (voronoi_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + """ + + try: + from autoarray.util.nn import nn_py + except ImportError as e: + raise ImportError( + "In order to use the Voronoi pixelization you must install the " + "Natural Neighbor Interpolation c package.\n\n" + "" + "See: https://github.com/Jammy2211/PyAutoArray/tree/main/autoarray/util/nn" + ) from e + + max_nneighbours = conf.instance["general"]["pixelization"][ + "voronoi_nn_max_interpolation_neighbors" + ] + + ( + pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index, + ) = nn_py.natural_interpolation_weights( + x_in=mesh_grid[:, 1], + y_in=mesh_grid[:, 0], + x_target=grid[:, 1], + y_target=grid[:, 0], + max_nneighbours=max_nneighbours, + ) + + bad_indexes = np.argwhere(np.sum(pix_weights_for_sub_slim_index < 0.0, axis=1) > 0) + + ( + pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index, + ) = remove_bad_entries_voronoi_nn( + bad_indexes=bad_indexes, + pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, + grid=np.array(grid), + mesh_grid=np.array(mesh_grid), + ) + + bad_indexes = np.argwhere(pix_indexes_for_sub_slim_index[:, 0] == -1) + + ( + pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index, + ) = remove_bad_entries_voronoi_nn( + bad_indexes=bad_indexes, + pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, + grid=np.array(grid), + mesh_grid=np.array(mesh_grid), + ) + + pix_indexes_for_sub_slim_index_sizes = np.sum( + pix_indexes_for_sub_slim_index != -1, axis=1 + ) + + if np.max(pix_indexes_for_sub_slim_index_sizes) > max_nneighbours: + raise exc.MeshException( + f""" + The number of Voronoi natural neighbours interpolations in one or more pixelization pixel's + exceeds the maximum allowed: max_nneighbors = {max_nneighbours}. + + To fix this, increase the value of `voronoi_nn_max_interpolation_neighbors` in the [pixelization] + section of the `general.ini` config file. + """ + ) + + return ( + pix_indexes_for_sub_slim_index, + pix_indexes_for_sub_slim_index_sizes, + pix_weights_for_sub_slim_index, + ) diff --git a/autoarray/inversion/pixelization/mappers/mapper_util.py b/autoarray/inversion/pixelization/mappers/mapper_util.py index d9792fe01..2fac0a81c 100644 --- a/autoarray/inversion/pixelization/mappers/mapper_util.py +++ b/autoarray/inversion/pixelization/mappers/mapper_util.py @@ -1,411 +1,288 @@ +from functools import partial +import jax +import jax.numpy as jnp import numpy as np -from scipy.spatial import cKDTree from typing import Tuple from autoconf import conf -from autoarray import numba_util from autoarray import exc -from autoarray.inversion.pixelization.mesh import mesh_util -@numba_util.jit() -def sub_slim_indexes_for_pix_index( - pix_indexes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index: np.ndarray, - pix_pixels: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) +def forward_interp(xp, yp, x): + return jax.vmap(jnp.interp, in_axes=(1, 1, None, None, None))(x, xp, yp, 0, 1).T - for pix_indexes in pix_indexes_for_sub_slim_index: - for pix_index in pix_indexes: - sub_slim_sizes_for_pix_index[pix_index] += 1 - max_pix_size = np.max(sub_slim_sizes_for_pix_index) +def reverse_interp(xp, yp, x): + return jax.vmap(jnp.interp, in_axes=(1, None, 1))(x, xp, yp).T - sub_slim_indexes_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) - sub_slim_weights_for_pix_index = -1 * np.ones(shape=(pix_pixels, int(max_pix_size))) - sub_slim_sizes_for_pix_index = np.zeros(pix_pixels) - for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index): - pix_weights = pix_weights_for_sub_slim_index[slim_index] +def create_transforms(traced_points): + # make functions that takes a set of traced points + # stored in a (N, 2) array and return functions that + # take in (N, 2) arrays and transform the values into + # the range (0, 1) and the inverse transform + N = traced_points.shape[0] # // 2 + t = jnp.arange(1, N + 1) / (N + 1) - for pix_index, pix_weight in zip(pix_indexes, pix_weights): - sub_slim_indexes_for_pix_index[ - pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) - ] = slim_index + sort_points = jnp.sort(traced_points, axis=0) # [::2] - sub_slim_weights_for_pix_index[ - pix_index, int(sub_slim_sizes_for_pix_index[pix_index]) - ] = pix_weight + transform = partial(forward_interp, sort_points, t) + inv_transform = partial(reverse_interp, t, sort_points) + return transform, inv_transform - sub_slim_sizes_for_pix_index[pix_index] += 1 - return ( - sub_slim_indexes_for_pix_index, - sub_slim_sizes_for_pix_index, - sub_slim_weights_for_pix_index, - ) +def adaptive_rectangular_transformed_grid_from(source_plane_data_grid, grid): + mu = source_plane_data_grid.mean(axis=0) + scale = source_plane_data_grid.std(axis=0).min() + source_grid_scaled = (source_plane_data_grid - mu) / scale + transform, inv_transform = create_transforms(source_grid_scaled) -@numba_util.jit() -def data_slim_to_pixelization_unique_from( - data_pixels, - pix_indexes_for_sub_slim_index: np.ndarray, - pix_sizes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index, - pix_pixels: int, - sub_size: np.ndarray, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Create an array describing the unique mappings between the sub-pixels of every slim data pixel and the pixelization - pixels, which is used to perform efficiently linear algebra calculations. + def inv_full(U): + return inv_transform(U) * scale + mu - For example, assuming `sub_size=2`: + return inv_full(grid) - - If 3 sub-pixels in image pixel 0 map to pixelization pixel 2 then `data_pix_to_unique[0, 0] = 2`. - - If the fourth sub-pixel maps to pixelization pixel 4, then `data_to_pix_unique[0, 1] = 4`. - The size of the second index depends on the number of unique sub-pixel to pixelization pixels mappings in a given - data pixel. In the example above, there were only two unique sets of mapping, but for high levels of sub-gridding - there could be many more unique mappings all of which must be stored. +def adaptive_rectangular_areas_from(source_grid_size, source_plane_data_grid): - The array `data_to_pix_unique` does not describe how many sub-pixels uniquely map to each pixelization pixel for - a given data pixel. This information is contained in the array `data_weights`. For the example above, - where `sub_size=2` and therefore `sub_fraction=0.25`: + pixel_edges_1d = jnp.linspace(0, 1, source_grid_size + 1) - - `data_weights[0, 0] = 0.75` (because 3 sub-pixels mapped to this pixelization pixel). - - `data_weights[0, 1] = 0.25` (because 1 sub-pixel mapped to this pixelization pixel). + mu = source_plane_data_grid.mean(axis=0) + scale = source_plane_data_grid.std(axis=0).min() + source_grid_scaled = (source_plane_data_grid - mu) / scale - The `sub_fractions` are stored as opposed to the number of sub-pixels, because these values are used directly - when performing the linear algebra calculation. + transform, inv_transform = create_transforms(source_grid_scaled) - The array `pix_lengths` in a 1D array of dimensions [data_pixels] describing how many unique pixelization pixels - each data pixel's set of sub-pixels maps too. + def inv_full(U): + return inv_transform(U) * scale + mu - Parameters - ---------- - data_pixels - The total number of data pixels in the dataset. - pix_indexes_for_sub_slim_index - Maps an unmasked data sub pixel to its corresponding pixelization pixel. - sub_size - The size of the sub-grid defining the number of sub-pixels in every data pixel. + pixel_edges = inv_full(jnp.stack([pixel_edges_1d, pixel_edges_1d]).T) + pixel_lengths = jnp.diff(pixel_edges, axis=0).squeeze() # shape (N_source, 2) - Returns - ------- - ndarray - The unique mappings between the sub-pixels of every data pixel and the pixelization pixels, alongside arrays - that give the weights and total number of mappings. - """ - - sub_fraction = 1.0 / (sub_size**2.0) + dy = pixel_lengths[:, 0] + dx = pixel_lengths[:, 1] - max_pix_mappings = int(np.max(pix_sizes_for_sub_slim_index)) - - # TODO : Work out if we can reduce size from np.max(sub_size) using sub_size of max_pix_mappings. - - data_to_pix_unique = -1 * np.ones( - (data_pixels, max_pix_mappings * np.max(sub_size) ** 2) - ) - data_weights = np.zeros((data_pixels, max_pix_mappings * np.max(sub_size) ** 2)) - pix_lengths = np.zeros(data_pixels) - pix_check = -1 * np.ones(shape=pix_pixels) + return jnp.outer(dy, dx).flatten() - ip_sub_start = 0 - for ip in range(data_pixels): - pix_check[:] = -1 - - pix_size = 0 - - ip_sub_end = ip_sub_start + sub_size[ip] ** 2 - - for ip_sub in range(ip_sub_start, ip_sub_end): - for pix_interp_index in range(pix_sizes_for_sub_slim_index[ip_sub]): - pix = pix_indexes_for_sub_slim_index[ip_sub, pix_interp_index] - pixel_weight = pix_weights_for_sub_slim_index[ip_sub, pix_interp_index] - - if pix_check[pix] > -0.5: - data_weights[ip, int(pix_check[pix])] += ( - sub_fraction[ip] * pixel_weight - ) - - else: - data_to_pix_unique[ip, pix_size] = pix - data_weights[ip, pix_size] += sub_fraction[ip] * pixel_weight - pix_check[pix] = pix_size - pix_size += 1 - - ip_sub_start = ip_sub_end - - pix_lengths[ip] = pix_size - - return data_to_pix_unique, data_weights, pix_lengths - - -@numba_util.jit() -def pix_indexes_for_sub_slim_index_delaunay_from( +def adaptive_rectangular_mappings_weights_via_interpolation_from( + source_grid_size: int, source_plane_data_grid, - simplex_index_for_sub_slim_index, - pix_indexes_for_simplex_index, - delaunay_points, -) -> Tuple[np.ndarray, np.ndarray]: + source_plane_data_grid_over_sampled, +): """ - The indexes mappings between the sub pixels and Voronoi mesh pixels. - For Delaunay tessellation, most sub pixels should have contribution of 3 pixelization pixels. However, - for those ones not belonging to any triangle, we link its value to its closest point. + Compute bilinear interpolation indices and weights for mapping an oversampled + source-plane grid onto a regular rectangular pixelization. + + This function takes a set of irregularly-sampled source-plane coordinates and + builds an adaptive mapping onto a `source_grid_size x source_grid_size` rectangular + pixelization using bilinear interpolation. The interpolation is expressed as: + + f(x, y) ≈ w_bl * f(ix_down, iy_down) + + w_br * f(ix_up, iy_down) + + w_tl * f(ix_down, iy_up) + + w_tr * f(ix_up, iy_up) + + where `(ix_down, ix_up, iy_down, iy_up)` are the integer grid coordinates + surrounding the continuous position `(x, y)`. + + Steps performed: + 1. Normalize the source-plane grid by subtracting its mean and dividing by + the minimum axis standard deviation (to balance scaling). + 2. Construct forward/inverse transforms which map the grid into the unit square [0,1]^2. + 3. Transform the oversampled source-plane grid into [0,1]^2, then scale it + to index space `[0, source_grid_size)`. + 4. Compute floor/ceil along x and y axes to find the enclosing rectangular cell. + 5. Build the four corner indices: bottom-left (bl), bottom-right (br), + top-left (tl), and top-right (tr). + 6. Flatten the 2D indices into 1D indices suitable for scatter operations, + with a flipped row-major convention: row = source_grid_size - i, col = j. + 7. Compute bilinear interpolation weights (`w_bl, w_br, w_tl, w_tr`). + 8. Return arrays of flattened indices and weights of shape `(N, 4)`, where + `N` is the number of oversampled coordinates. - The returning result is a matrix of (len(sub_pixels, 3)) where the entries mark the relevant source pixel indexes. - A row like [A, -1, -1] means that sub pixel only links to source pixel A. - """ + Parameters + ---------- + source_grid_size : int + The number of pixels along one dimension of the rectangular pixelization. + The grid is square: (source_grid_size x source_grid_size). + source_plane_data_grid : (M, 2) ndarray + The base source-plane coordinates, used to define normalization and transforms. + source_plane_data_grid_over_sampled : (N, 2) ndarray + Oversampled source-plane coordinates to be interpolated onto the rectangular grid. - pix_indexes_for_sub_slim_index = -1 * np.ones( - shape=(source_plane_data_grid.shape[0], 3) - ) + Returns + ------- + flat_indices : (N, 4) int ndarray + The flattened indices of the four neighboring pixel corners for each oversampled point. + Order: [bl, br, tl, tr]. + weights : (N, 4) float ndarray + The bilinear interpolation weights for each of the four neighboring pixels. + Order: [w_bl, w_br, w_tl, w_tr]. + """ - for i in range(len(source_plane_data_grid)): - simplex_index = simplex_index_for_sub_slim_index[i] - if simplex_index != -1: - pix_indexes_for_sub_slim_index[i] = pix_indexes_for_simplex_index[ - simplex_index_for_sub_slim_index[i] - ] - else: - pix_indexes_for_sub_slim_index[i][0] = np.argmin( - np.sum((delaunay_points - source_plane_data_grid[i]) ** 2.0, axis=1) - ) - - pix_indexes_for_sub_slim_index_sizes = np.sum( - pix_indexes_for_sub_slim_index >= 0, axis=1 + # --- Step 1. Normalize grid --- + mu = source_plane_data_grid.mean(axis=0) + scale = source_plane_data_grid.std(axis=0).min() + source_grid_scaled = (source_plane_data_grid - mu) / scale + + # --- Step 2. Build transforms --- + transform, inv_transform = create_transforms(source_grid_scaled) + + # --- Step 3. Transform oversampled grid into index space --- + grid_over_sampled_scaled = (source_plane_data_grid_over_sampled - mu) / scale + grid_over_sampled_transformed = transform(grid_over_sampled_scaled) + grid_over_index = (source_grid_size - 3) * grid_over_sampled_transformed + 1 + + # --- Step 4. Floor/ceil indices --- + ix_down = jnp.floor(grid_over_index[:, 0]) + ix_up = jnp.ceil(grid_over_index[:, 0]) + iy_down = jnp.floor(grid_over_index[:, 1]) + iy_up = jnp.ceil(grid_over_index[:, 1]) + + # --- Step 5. Four corners --- + idx_tl = jnp.stack([ix_up, iy_down], axis=1) + idx_tr = jnp.stack([ix_up, iy_up], axis=1) + idx_br = jnp.stack([ix_down, iy_up], axis=1) + idx_bl = jnp.stack([ix_down, iy_down], axis=1) + + # --- Step 6. Flatten indices --- + def flatten(idx, n): + row = n - idx[:, 0] + col = idx[:, 1] + return row * n + col + + flat_tl = flatten(idx_tl, source_grid_size) + flat_tr = flatten(idx_tr, source_grid_size) + flat_bl = flatten(idx_bl, source_grid_size) + flat_br = flatten(idx_br, source_grid_size) + + flat_indices = jnp.stack([flat_tl, flat_tr, flat_bl, flat_br], axis=1).astype( + "int64" ) - return pix_indexes_for_sub_slim_index, pix_indexes_for_sub_slim_index_sizes - - -def nearest_pixelization_index_for_slim_index_from_kdtree(grid, mesh_grid): - kdtree = cKDTree(mesh_grid) - - sparse_index_for_slim_index = [] + # --- Step 7. Bilinear interpolation weights --- + t_row = (grid_over_index[:, 0] - ix_down) / (ix_up - ix_down + 1e-12) + t_col = (grid_over_index[:, 1] - iy_down) / (iy_up - iy_down + 1e-12) - for i in range(grid.shape[0]): - input_point = [grid[i, [0]], grid[i, 1]] - index = kdtree.query(input_point)[1] - sparse_index_for_slim_index.append(index) + # Weights + w_tl = (1 - t_row) * (1 - t_col) + w_tr = (1 - t_row) * t_col + w_bl = t_row * (1 - t_col) + w_br = t_row * t_col + weights = jnp.stack([w_tl, w_tr, w_bl, w_br], axis=1) - return sparse_index_for_slim_index + return flat_indices, weights -@numba_util.jit() -def pixel_weights_delaunay_from( - source_plane_data_grid, - source_plane_mesh_grid, - slim_index_for_sub_slim_index: np.ndarray, - pix_indexes_for_sub_slim_index, -) -> np.ndarray: +def rectangular_mappings_weights_via_interpolation_from( + shape_native: Tuple[int, int], + source_plane_data_grid: jnp.ndarray, + source_plane_mesh_grid: jnp.ndarray, +): """ - Returns the weights of the mappings between the masked sub-pixels and the Delaunay pixelization. + Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid. + + Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function + determines for each irregular point: + - the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and + - the bilinear interpolation weights with respect to those pixels. - Weights are determiend via a nearest neighbor interpolation scheme, whereby every data-sub pixel maps to three - Delaunay pixel vertexes (in the source frame). The weights of these 3 mappings depends on the distance of the - coordinate to each vertex, with the highest weight being its closest neighbor, + The function supports JAX and is compatible with JIT compilation. Parameters ---------- + shape_native + The shape (Ny, Nx) of the original rectangular mesh grid before flattening. source_plane_data_grid - A 2D grid of (y,x) coordinates associated with the unmasked 2D data after it has been transformed to the - `source` reference frame. + The irregular grid of (y, x) points to interpolate. source_plane_mesh_grid - The 2D grid of (y,x) centres of every pixelization pixel in the `source` frame. - slim_index_for_sub_slim_index - The mappings between the data's sub slimmed indexes and the slimmed indexes on the non sub-sized indexes. - pix_indexes_for_sub_slim_index - The mappings from a data sub-pixel index to a pixelization pixel index. - """ - - pixel_weights = np.zeros(pix_indexes_for_sub_slim_index.shape) - - for sub_slim_index in range(slim_index_for_sub_slim_index.shape[0]): - pix_indexes = pix_indexes_for_sub_slim_index[sub_slim_index] - - if pix_indexes[1] != -1: - vertices_of_the_simplex = source_plane_mesh_grid[pix_indexes] + The flattened regular rectangular mesh grid of (y, x) coordinates. - sub_gird_coordinate_on_source_place = source_plane_data_grid[sub_slim_index] - - area_0 = mesh_util.delaunay_triangle_area_from( - corner_0=vertices_of_the_simplex[1], - corner_1=vertices_of_the_simplex[2], - corner_2=sub_gird_coordinate_on_source_place, - ) - area_1 = mesh_util.delaunay_triangle_area_from( - corner_0=vertices_of_the_simplex[0], - corner_1=vertices_of_the_simplex[2], - corner_2=sub_gird_coordinate_on_source_place, - ) - area_2 = mesh_util.delaunay_triangle_area_from( - corner_0=vertices_of_the_simplex[0], - corner_1=vertices_of_the_simplex[1], - corner_2=sub_gird_coordinate_on_source_place, - ) - - norm = area_0 + area_1 + area_2 - - weight_abc = np.array([area_0, area_1, area_2]) / norm - - pixel_weights[sub_slim_index] = weight_abc + Returns + ------- + mappings : jnp.ndarray of shape (N, 4) + Indices of the four nearest rectangular mesh pixels in the flattened mesh grid. + Order is: top-left, top-right, bottom-left, bottom-right. + weights : jnp.ndarray of shape (N, 4) + Bilinear interpolation weights corresponding to the four nearest mesh pixels. + + Notes + ----- + - Assumes the mesh grid is uniformly spaced. + - The weights sum to 1 for each irregular point. + - Uses bilinear interpolation in the (y, x) coordinate system. + """ + source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2) - else: - pixel_weights[sub_slim_index][0] = 1.0 + # Assume mesh is shaped (Ny, Nx, 2) + Ny, Nx = source_plane_mesh_grid.shape[:2] - return pixel_weights + # Get mesh spacings and lower corner + y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,) + x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,) + dy = y_coords[1] - y_coords[0] + dx = x_coords[1] - x_coords[0] -def pix_size_weights_voronoi_nn_from( - grid: np.ndarray, mesh_grid: np.ndarray -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Returns the mappings between a set of slimmed sub-grid pixels and pixelization pixels, using information on - how the pixels hosting each sub-pixel map to their closest pixelization pixel on the slim grid in the data-plane - and the pixelization's pixel centres. + y_min = y_coords[0] + x_min = x_coords[0] - To determine the complete set of slim sub-pixel to pixelization pixel mappings, we must pair every sub-pixel to - its nearest pixel. Using a full nearest neighbor search to do this is slow, thus the pixel neighbors (derived via - the Voronoi grid) are used to localize each nearest neighbor search by using a graph search. + # shape (N_irregular, 2) + irregular = source_plane_data_grid - Parameters - ---------- - grid - The grid of (y,x) scaled coordinates at the centre of every unmasked pixel, which has been traced to - to an irgrid via lens. - slim_index_for_sub_slim_index - The mappings between the data slimmed sub-pixels and their regular pixels. - mesh_grid - The (y,x) centre of every Voronoi pixel in arc-seconds. - neighbors - An array of length (voronoi_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (voronoi_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - """ + # Compute normalized mesh coordinates (floating indices) + fy = (irregular[:, 0] - y_min) / dy + fx = (irregular[:, 1] - x_min) / dx - try: - from autoarray.util.nn import nn_py - except ImportError as e: - raise ImportError( - "In order to use the Voronoi pixelization you must install the " - "Natural Neighbor Interpolation c package.\n\n" - "" - "See: https://github.com/Jammy2211/PyAutoArray/tree/main/autoarray/util/nn" - ) from e - - max_nneighbours = conf.instance["general"]["pixelization"][ - "voronoi_nn_max_interpolation_neighbors" - ] - - ( - pix_weights_for_sub_slim_index, - pix_indexes_for_sub_slim_index, - ) = nn_py.natural_interpolation_weights( - x_in=mesh_grid[:, 1], - y_in=mesh_grid[:, 0], - x_target=grid[:, 1], - y_target=grid[:, 0], - max_nneighbours=max_nneighbours, - ) + # Integer indices of top-left corners + ix = jnp.floor(fx).astype(jnp.int32) + iy = jnp.floor(fy).astype(jnp.int32) - bad_indexes = np.argwhere(np.sum(pix_weights_for_sub_slim_index < 0.0, axis=1) > 0) - - ( - pix_weights_for_sub_slim_index, - pix_indexes_for_sub_slim_index, - ) = remove_bad_entries_voronoi_nn( - bad_indexes=bad_indexes, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - grid=np.array(grid), - mesh_grid=np.array(mesh_grid), - ) + # Clip to stay within bounds + ix = jnp.clip(ix, 0, Nx - 2) + iy = jnp.clip(iy, 0, Ny - 2) - bad_indexes = np.argwhere(pix_indexes_for_sub_slim_index[:, 0] == -1) - - ( - pix_weights_for_sub_slim_index, - pix_indexes_for_sub_slim_index, - ) = remove_bad_entries_voronoi_nn( - bad_indexes=bad_indexes, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - grid=np.array(grid), - mesh_grid=np.array(mesh_grid), - ) + # Local coordinates inside the cell (0 <= tx, ty <= 1) + tx = fx - ix + ty = fy - iy - pix_indexes_for_sub_slim_index_sizes = np.sum( - pix_indexes_for_sub_slim_index != -1, axis=1 - ) + # Bilinear weights + w00 = (1 - tx) * (1 - ty) + w10 = tx * (1 - ty) + w01 = (1 - tx) * ty + w11 = tx * ty - if np.max(pix_indexes_for_sub_slim_index_sizes) > max_nneighbours: - raise exc.MeshException( - f""" - The number of Voronoi natural neighbours interpolations in one or more pixelization pixel's - exceeds the maximum allowed: max_nneighbors = {max_nneighbours}. - - To fix this, increase the value of `voronoi_nn_max_interpolation_neighbors` in the [pixelization] - section of the `general.ini` config file. - """ - ) - - return ( - pix_indexes_for_sub_slim_index, - pix_indexes_for_sub_slim_index_sizes, - pix_weights_for_sub_slim_index, - ) + weights = jnp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4) + # Compute indices of 4 surrounding pixels in the flattened mesh + i00 = iy * Nx + ix + i10 = iy * Nx + (ix + 1) + i01 = (iy + 1) * Nx + ix + i11 = (iy + 1) * Nx + (ix + 1) -@numba_util.jit() -def remove_bad_entries_voronoi_nn( - bad_indexes, - pix_weights_for_sub_slim_index, - pix_indexes_for_sub_slim_index, - grid, - mesh_grid, -): - """ - The nearest neighbor interpolation can return invalid or bad entries which are removed from the mapping arrays. The - current circumstances this arises are: + mappings = jnp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4) - 1) If a point is outside the whole Voronoi region, some weights have negative values. In this case, we reset its - neighbor to its closest neighbor. + return mappings, weights - 2) The nearest neighbor interpolation code may not return even a single neighbor. We mark these as a bad grid by - settings their neighbors to the closest ones. - Parameters - ---------- - bad_indexes - pix_weights_for_sub_slim_index - pix_indexes_for_sub_slim_index - grid - mesh_grid +def nearest_pixelization_index_for_slim_index_from_kdtree(grid, mesh_grid): + from scipy.spatial import cKDTree - Returns - ------- + kdtree = cKDTree(mesh_grid) - """ + sparse_index_for_slim_index = [] - for item in bad_indexes: - ind = item[0] - pix_indexes_for_sub_slim_index[ind] = -1 - pix_indexes_for_sub_slim_index[ind][0] = np.argmin( - np.sum((grid[ind] - mesh_grid) ** 2.0, axis=1) - ) - pix_weights_for_sub_slim_index[ind] = 0.0 - pix_weights_for_sub_slim_index[ind][0] = 1.0 + for i in range(grid.shape[0]): + input_point = [grid[i, [0]], grid[i, 1]] + index = kdtree.query(input_point)[1] + sparse_index_for_slim_index.append(index) - return pix_weights_for_sub_slim_index, pix_indexes_for_sub_slim_index + return sparse_index_for_slim_index -@numba_util.jit() def adaptive_pixel_signals_from( pixels: int, pixel_weights: np.ndarray, @@ -443,33 +320,47 @@ def adaptive_pixel_signals_from( The image of the galaxy which is used to compute the weigghted pixel signals. """ - pixel_signals = np.zeros((pixels,)) - pixel_sizes = np.zeros((pixels,)) + M_sub, B = pix_indexes_for_sub_slim_index.shape + + # 1) Flatten the per‐mapping tables: + flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,) + flat_weights = pixel_weights.reshape(-1) # (M_sub*B,) - for sub_slim_index in range(len(pix_indexes_for_sub_slim_index)): - vertices_indexes = pix_indexes_for_sub_slim_index[sub_slim_index] + # 2) Build a matching “parent‐slim” index for each flattened entry: + I_sub = jnp.repeat(jnp.arange(M_sub), B) # (M_sub*B,) - mask_1d_index = slim_index_for_sub_slim_index[sub_slim_index] + # 3) Mask out any k >= pix_size_for_sub_slim_index[i] + valid = I_sub < 0 # dummy to get shape + # better: + valid = (jnp.arange(B)[None, :] < pix_size_for_sub_slim_index[:, None]).reshape(-1) - pix_size_tem = pix_size_for_sub_slim_index[sub_slim_index] + flat_weights = jnp.where(valid, flat_weights, 0.0) + flat_pixidx = jnp.where( + valid, flat_pixidx, pixels + ) # send invalid indices to an out-of-bounds slot - if pix_size_tem > 1: - pixel_signals[vertices_indexes[:pix_size_tem]] += ( - adapt_data[mask_1d_index] * pixel_weights[sub_slim_index] - ) - pixel_sizes[vertices_indexes] += 1 - else: - pixel_signals[vertices_indexes[0]] += adapt_data[mask_1d_index] - pixel_sizes[vertices_indexes[0]] += 1 + # 4) Look up data & multiply by mapping weights: + flat_data_vals = adapt_data[slim_index_for_sub_slim_index][I_sub] # (M_sub*B,) + flat_contrib = flat_data_vals * flat_weights # (M_sub*B,) - pixel_sizes[pixel_sizes == 0] = 1 - pixel_signals /= pixel_sizes - pixel_signals /= np.max(pixel_signals) + # 5) Scatter‐add into signal sums and counts: + pixel_signals = jnp.zeros((pixels + 1,)).at[flat_pixidx].add(flat_contrib) + pixel_counts = jnp.zeros((pixels + 1,)).at[flat_pixidx].add(valid.astype(float)) + # 6) Drop the extra “out-of-bounds” slot: + pixel_signals = pixel_signals[:pixels] + pixel_counts = pixel_counts[:pixels] + + # 7) Normalize + pixel_counts = jnp.where(pixel_counts > 0, pixel_counts, 1.0) + pixel_signals = pixel_signals / pixel_counts + max_sig = jnp.max(pixel_signals) + pixel_signals = jnp.where(max_sig > 0, pixel_signals / max_sig, pixel_signals) + + # 8) Exponentiate return pixel_signals**signal_scale -@numba_util.jit() def mapping_matrix_from( pix_indexes_for_sub_slim_index: np.ndarray, pix_size_for_sub_slim_index: np.ndarray, @@ -550,87 +441,110 @@ def mapping_matrix_from( sub_fraction The fractional area each sub-pixel takes up in an pixel. """ + M_sub, B = pix_indexes_for_sub_slim_index.shape + M = total_mask_pixels + S = pixels + + # 1) Flatten + flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,) + flat_w = pix_weights_for_sub_slim_index.reshape(-1) # (M_sub*B,) + flat_parent = jnp.repeat(slim_index_for_sub_slim_index, B) # (M_sub*B,) + flat_count = jnp.repeat(pix_size_for_sub_slim_index, B) # (M_sub*B,) + + # 2) Build valid mask: k < pix_size[i] + k = jnp.tile(jnp.arange(B), M_sub) # (M_sub*B,) + valid = k < flat_count # (M_sub*B,) - mapping_matrix = np.zeros((total_mask_pixels, pixels)) + # 3) Zero out invalid weights + flat_w = flat_w * valid.astype(flat_w.dtype) - for sub_slim_index in range(slim_index_for_sub_slim_index.shape[0]): - slim_index = slim_index_for_sub_slim_index[sub_slim_index] + # 4) Redirect -1 indices to extra bin S + OUT = S + flat_pixidx = jnp.where(flat_pixidx < 0, OUT, flat_pixidx) - for pix_count in range(pix_size_for_sub_slim_index[sub_slim_index]): - pix_index = pix_indexes_for_sub_slim_index[sub_slim_index, pix_count] - pix_weight = pix_weights_for_sub_slim_index[sub_slim_index, pix_count] + # 5) Multiply by sub_fraction of the slim row + flat_frac = sub_fraction[flat_parent] # (M_sub*B,) + flat_contrib = flat_w * flat_frac # (M_sub*B,) - mapping_matrix[slim_index][pix_index] += ( - sub_fraction[slim_index] * pix_weight - ) + # 6) Scatter into (M × (S+1)), summing duplicates + mat = jnp.zeros((M, S + 1), dtype=flat_contrib.dtype) + mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib) - return mapping_matrix + # 7) Drop the extra column and return + return mat[:, :S] -@numba_util.jit() def mapped_to_source_via_mapping_matrix_from( mapping_matrix: np.ndarray, array_slim: np.ndarray ) -> np.ndarray: """ - Map a masked 2d image in the image domain to the source domain and sum up all mappings on the source-pixels. + Map a masked 2D image (in slim form) into the source plane by summing and averaging + each image-pixel's contribution to its mapped source-pixels. - For example, suppose we have an image and a mapper. We can map every image-pixel to its corresponding mapper's - source pixel and sum the values based on these mappings. - - This will produce something similar to a `reconstruction`, albeit it bypasses the linear algebra / inversion. + Each row i of `mapping_matrix` describes how image-pixel i is distributed (with + weights) across the source-pixels j. `array_slim[i]` is then multiplied by those + weights and summed over i to give each source-pixel’s total mapped value; finally, + we divide by the number of nonzero contributions to form an average. Parameters ---------- - mapping_matrix - The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels. - array_slim - The masked 2D array of values in its slim representation (e.g. the image data) which are mapped to the - source domain in order to compute their average values. - """ - - mapped_to_source = np.zeros(mapping_matrix.shape[1]) + mapping_matrix : ndarray of shape (M, N) + mapping_matrix[i, j] ≥ 0 is the weight by which image-pixel i contributes to + source-pixel j. Zero means “no contribution.” + array_slim : ndarray of shape (M,) + The slimmed image values for each image-pixel i. - source_pixel_count = np.zeros(mapping_matrix.shape[1]) + Returns + ------- + mapped_to_source : ndarray of shape (N,) + The averaged, mapped values on each of the N source-pixels. + """ + # weighted sums: sum over i of array_slim[i] * mapping_matrix[i, j] + # ==> vector‐matrix multiply: (1×M) dot (M×N) → (N,) + mapped_to_source = array_slim @ mapping_matrix - for i in range(mapping_matrix.shape[0]): - for j in range(mapping_matrix.shape[1]): - if mapping_matrix[i, j] > 0: - mapped_to_source[j] += array_slim[i] * mapping_matrix[i, j] - source_pixel_count[j] += 1 + # count how many nonzero contributions each source-pixel j received + counts = np.count_nonzero(mapping_matrix > 0.0, axis=0) - for j in range(mapping_matrix.shape[1]): - if source_pixel_count[j] > 0: - mapped_to_source[j] /= source_pixel_count[j] + # avoid division by zero: only divide where counts > 0 + nonzero = counts > 0 + mapped_to_source[nonzero] /= counts[nonzero] return mapped_to_source -@numba_util.jit() def data_weight_total_for_pix_from( - pix_indexes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index: np.ndarray, + pix_indexes_for_sub_slim_index: np.ndarray, # shape (M, B) + pix_weights_for_sub_slim_index: np.ndarray, # shape (M, B) pixels: int, ) -> np.ndarray: """ - Returns the total weight of every pixelization pixel, which is the sum of the weights of all data-points that - map to that pixel. + Returns the total weight of every pixelization pixel, which is the sum of + the weights of all data‐points (sub‐pixels) that map to that pixel. Parameters ---------- - pix_indexes_for_sub_slim_index - The mappings from a data sub-pixel index to a pixelization pixel index. - pix_weights_for_sub_slim_index - The weights of the mappings of every data sub-pixel and pixelization pixel. - pixels - The number of pixels in the pixelization. - """ + pix_indexes_for_sub_slim_index : np.ndarray, shape (M, B), int + For each of M sub‐slim indexes, the B pixelization‐pixel indices it maps to. + pix_weights_for_sub_slim_index : np.ndarray, shape (M, B), float + For each of those mappings, the corresponding interpolation weight. + pixels : int + The total number of pixelization pixels N. - pix_weight_total = np.zeros(pixels) + Returns + ------- + np.ndarray, shape (N,) + The per‐pixel total weight: for each j in [0..N-1], the sum of all + pix_weights_for_sub_slim_index[i,k] such that pix_indexes_for_sub_slim_index[i,k] == j. + """ + # Flatten arrays + flat_idxs = pix_indexes_for_sub_slim_index.ravel() + flat_weights = pix_weights_for_sub_slim_index.ravel() - for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index): - for pix_index, weight in zip( - pix_indexes, pix_weights_for_sub_slim_index[slim_index] - ): - pix_weight_total[int(pix_index)] += weight + # Filter out -1 (invalid mappings) + valid_mask = flat_idxs >= 0 + flat_idxs = flat_idxs[valid_mask] + flat_weights = flat_weights[valid_mask] - return pix_weight_total + # Sum weights by pixel index + return np.bincount(flat_idxs, weights=flat_weights, minlength=pixels) diff --git a/autoarray/inversion/pixelization/mappers/rectangular.py b/autoarray/inversion/pixelization/mappers/rectangular.py index 9a78c1b8a..14fd3fd9f 100644 --- a/autoarray/inversion/pixelization/mappers/rectangular.py +++ b/autoarray/inversion/pixelization/mappers/rectangular.py @@ -1,4 +1,4 @@ -import numpy as np +import jax.numpy as jnp from typing import Tuple from autoconf import cached_property @@ -6,8 +6,7 @@ from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights -from autoarray.numba_util import profile_func -from autoarray.geometry import geometry_util +from autoarray.inversion.pixelization.mappers import mapper_util class MapperRectangular(AbstractMapper): @@ -56,8 +55,6 @@ class MapperRectangular(AbstractMapper): regularization The regularization scheme which may be applied to this linear object in order to smooth its solution, which for a mapper smooths neighboring pixels on the mesh. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ @property @@ -65,7 +62,6 @@ def shape_native(self) -> Tuple[int, ...]: return self.source_plane_mesh_grid.shape_native @cached_property - @profile_func def pix_sub_weights(self) -> PixSubWeights: """ Computes the following three quantities describing the mappings between of every sub-pixel in the masked data @@ -99,19 +95,51 @@ def pix_sub_weights(self) -> PixSubWeights: dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index` are equal to 1.0. """ - mappings = geometry_util.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled), - shape_native=self.source_plane_mesh_grid.shape_native, - pixel_scales=self.source_plane_mesh_grid.pixel_scales, - origin=self.source_plane_mesh_grid.origin, - ).astype("int") - - mappings = mappings.reshape((len(mappings), 1)) + mappings, weights = ( + mapper_util.adaptive_rectangular_mappings_weights_via_interpolation_from( + source_grid_size=self.shape_native[0], + source_plane_data_grid=self.source_plane_data_grid.array, + source_plane_data_grid_over_sampled=jnp.array( + self.source_plane_data_grid.over_sampled + ), + ) + ) return PixSubWeights( mappings=mappings, - sizes=np.ones(len(mappings), dtype="int"), - weights=np.ones( - (len(self.source_plane_data_grid.over_sampled), 1), dtype="int" - ), + sizes=4 * jnp.ones(len(mappings), dtype="int"), + weights=weights, + ) + + @cached_property + def areas_transformed(self): + """ + A class packing the ndarrays describing the neighbors of every pixel in the rectangular pixelization (see + `Neighbors` for a complete description of the neighboring scheme). + + The neighbors of a rectangular pixelization are computed by exploiting the uniform and symmetric nature of the + rectangular grid, as described in the method `mesh_util.rectangular_neighbors_from`. + """ + return mapper_util.adaptive_rectangular_areas_from( + source_grid_size=self.shape_native[0], + source_plane_data_grid=self.source_plane_data_grid.array, + ) + + @cached_property + def edges_transformed(self): + """ + A class packing the ndarrays describing the neighbors of every pixel in the rectangular pixelization (see + `Neighbors` for a complete description of the neighboring scheme). + + The neighbors of a rectangular pixelization are computed by exploiting the uniform and symmetric nature of the + rectangular grid, as described in the method `mesh_util.rectangular_neighbors_from`. + """ + + # edges defined in 0 -> 1 space, there is one more edge than pixel centers on each side + edges = jnp.linspace(0, 1, self.shape_native[0] + 1) + edges_reshaped = jnp.stack([edges, edges]).T + + return mapper_util.adaptive_rectangular_transformed_grid_from( + source_plane_data_grid=self.source_plane_data_grid.array, + grid=edges_reshaped, ) diff --git a/autoarray/inversion/pixelization/mappers/rectangular_uniform.py b/autoarray/inversion/pixelization/mappers/rectangular_uniform.py new file mode 100644 index 000000000..3c58813cb --- /dev/null +++ b/autoarray/inversion/pixelization/mappers/rectangular_uniform.py @@ -0,0 +1,106 @@ +import jax.numpy as jnp + +from autoconf import cached_property + +from autoarray.inversion.pixelization.mappers.rectangular import MapperRectangular +from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights + +from autoarray.inversion.pixelization.mappers import mapper_util + + +class MapperRectangularUniform(MapperRectangular): + """ + To understand a `Mapper` one must be familiar `Mesh` objects and the `mesh` and `pixelization` packages, where + the four grids grouped in a `MapperGrids` object are explained (`image_plane_data_grid`, `source_plane_data_grid`, + `image_plane_mesh_grid`,`source_plane_mesh_grid`) + + If you are unfamliar withe above objects, read through the docstrings of the `pixelization`, `mesh` and + `mapper_grids` packages. + + A `Mapper` determines the mappings between the masked data grid's pixels (`image_plane_data_grid` and + `source_plane_data_grid`) and the mesh's pixels (`image_plane_mesh_grid` and `source_plane_mesh_grid`). + + The 1D Indexing of each grid is identical in the `data` and `source` frames (e.g. the transformation does not + change the indexing, such that `source_plane_data_grid[0]` corresponds to the transformed value + of `image_plane_data_grid[0]` and so on). + + A mapper therefore only needs to determine the index mappings between the `grid_slim` and `mesh_grid`, + noting that associations are made by pairing `source_plane_mesh_grid` with `source_plane_data_grid`. + + Mappings are represented in the 2D ndarray `pix_indexes_for_sub_slim_index`, whereby the index of + a pixel on the `mesh_grid` maps to the index of a pixel on the `grid_slim` as follows: + + - pix_indexes_for_sub_slim_index[0, 0] = 0: the data's 1st sub-pixel maps to the mesh's 1st pixel. + - pix_indexes_for_sub_slim_index[1, 0] = 3: the data's 2nd sub-pixel maps to the mesh's 4th pixel. + - pix_indexes_for_sub_slim_index[2, 0] = 1: the data's 3rd sub-pixel maps to the mesh's 2nd pixel. + + The second dimension of this array (where all three examples above are 0) is used for cases where a + single pixel on the `grid_slim` maps to multiple pixels on the `mesh_grid`. For example, a + `Delaunay` triangulation, where every `grid_slim` pixel maps to three Delaunay pixels (the corners of the + triangles) with varying interpolation weights . + + For a `Rectangular` mesh every pixel in the masked data maps to only one pixel, thus the second + dimension of `pix_indexes_for_sub_slim_index` is always of size 1. + + The mapper allows us to create a mapping matrix, which is a matrix representing the mapping between every + unmasked data pixel annd the pixels of a mesh. This matrix is the basis of performing an `Inversion`, + which reconstructs the data using the `source_plane_mesh_grid`. + + Parameters + ---------- + mapper_grids + An object containing the data grid and mesh grid in both the data-frame and source-frame used by the + mapper to map data-points to linear object parameters. + regularization + The regularization scheme which may be applied to this linear object in order to smooth its solution, + which for a mapper smooths neighboring pixels on the mesh. + """ + + @cached_property + def pix_sub_weights(self) -> PixSubWeights: + """ + Computes the following three quantities describing the mappings between of every sub-pixel in the masked data + and pixel in the `Rectangular` mesh. + + - `pix_indexes_for_sub_slim_index`: the mapping of every data pixel (given its `sub_slim_index`) + to mesh pixels (given their `pix_indexes`). + + - `pix_sizes_for_sub_slim_index`: the number of mappings of every data pixel to mesh pixels. + + - `pix_weights_for_sub_slim_index`: the interpolation weights of every data pixel's mesh + pixel mapping + + These are packaged into the class `PixSubWeights` with attributes `mappings`, `sizes` and `weights`. + + The `sub_slim_index` refers to the masked data sub-pixels and `pix_indexes` the mesh pixel indexes, + for example: + + - `pix_indexes_for_sub_slim_index[0, 0] = 2`: The data's first (index 0) sub-pixel maps to the Rectangular + mesh's third (index 2) pixel. + + - `pix_indexes_for_sub_slim_index[2, 0] = 4`: The data's third (index 2) sub-pixel maps to the Rectangular + mesh's fifth (index 4) pixel. + + The second dimension of the array `pix_indexes_for_sub_slim_index`, which is 0 in both examples above, is used + for cases where a data pixel maps to more than one mesh pixel (for example a `Delaunay` triangulation + where each data pixel maps to 3 Delaunay triangles with interpolation weights). The weights of multiple mappings + are stored in the array `pix_weights_for_sub_slim_index`. + + For a Rectangular pixelization each data sub-pixel maps to a single mesh pixel, thus the second + dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index` + are equal to 1.0. + """ + + mappings, weights = ( + mapper_util.rectangular_mappings_weights_via_interpolation_from( + shape_native=self.shape_native, + source_plane_mesh_grid=self.source_plane_mesh_grid.array, + source_plane_data_grid=self.source_plane_data_grid.over_sampled, + ) + ) + + return PixSubWeights( + mappings=mappings, + sizes=4 * jnp.ones(len(mappings), dtype="int"), + weights=weights, + ) diff --git a/autoarray/inversion/pixelization/mappers/voronoi.py b/autoarray/inversion/pixelization/mappers/voronoi.py index c8e54cbf3..db364319e 100644 --- a/autoarray/inversion/pixelization/mappers/voronoi.py +++ b/autoarray/inversion/pixelization/mappers/voronoi.py @@ -7,8 +7,7 @@ from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.numba_util import profile_func -from autoarray.inversion.pixelization.mappers import mapper_util +from autoarray.inversion.pixelization.mappers import mapper_numba_util class MapperVoronoi(AbstractMapper): @@ -57,8 +56,6 @@ class MapperVoronoi(AbstractMapper): regularization The regularization scheme which may be applied to this linear object in order to smooth its solution, which for a mapper smooths neighboring pixels on the mesh. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ @property @@ -78,7 +75,7 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights: This property returns a unique set of `PixSubWeights` used for these regularization schemes which compute mappings and weights at each point on the split cross. """ - (mappings, sizes, weights) = mapper_util.pix_size_weights_voronoi_nn_from( + (mappings, sizes, weights) = mapper_numba_util.pix_size_weights_voronoi_nn_from( grid=self.source_plane_mesh_grid.split_cross, mesh_grid=self.source_plane_mesh_grid, ) @@ -86,7 +83,6 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights: return PixSubWeights(mappings=mappings, sizes=sizes, weights=weights) @cached_property - @profile_func def pix_sub_weights(self) -> PixSubWeights: """ Computes the following three quantities describing the mappings between of every sub-pixel in the masked data @@ -129,7 +125,7 @@ def pix_sub_weights(self) -> PixSubWeights: The interpolation weights of these multiple mappings are stored in the array `pix_weights_for_sub_slim_index`. """ - mappings, sizes, weights = mapper_util.pix_size_weights_voronoi_nn_from( + mappings, sizes, weights = mapper_numba_util.pix_size_weights_voronoi_nn_from( grid=self.source_plane_data_grid.over_sampled, mesh_grid=self.source_plane_mesh_grid, ) @@ -172,5 +168,8 @@ def interpolated_array_from( is input. """ return self.source_plane_mesh_grid.interpolated_array_from( - values=values, shape_native=shape_native, extent=extent, use_nn=True + values=np.array(values), + shape_native=shape_native, + extent=extent, + use_nn=True, ) diff --git a/autoarray/inversion/pixelization/mesh/__init__.py b/autoarray/inversion/pixelization/mesh/__init__.py index a14f53f69..28f35f116 100644 --- a/autoarray/inversion/pixelization/mesh/__init__.py +++ b/autoarray/inversion/pixelization/mesh/__init__.py @@ -1,4 +1,5 @@ from .abstract import AbstractMesh as Mesh from .rectangular import Rectangular +from .rectangular_uniform import RectangularUniform from .voronoi import Voronoi from .delaunay import Delaunay diff --git a/autoarray/inversion/pixelization/mesh/abstract.py b/autoarray/inversion/pixelization/mesh/abstract.py index 95d3d1ce3..772e02c05 100644 --- a/autoarray/inversion/pixelization/mesh/abstract.py +++ b/autoarray/inversion/pixelization/mesh/abstract.py @@ -6,14 +6,11 @@ from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.grids.irregular_2d import Grid2DIrregular -from autoarray.numba_util import profile_func - class AbstractMesh: def __eq__(self, other): return self.__dict__ == other.__dict__ and self.__class__ is other.__class__ - @profile_func def relocated_grid_from( self, border_relocator: BorderRelocator, @@ -44,9 +41,14 @@ def relocated_grid_from( """ if border_relocator is not None: return border_relocator.relocated_grid_from(grid=source_plane_data_grid) - return source_plane_data_grid - @profile_func + return Grid2D( + values=source_plane_data_grid.array, + mask=source_plane_data_grid.mask, + over_sample_size=source_plane_data_grid.over_sampler.sub_size, + over_sampled=source_plane_data_grid.over_sampled.array, + ) + def relocated_mesh_grid_from( self, border_relocator: Optional[BorderRelocator], @@ -95,7 +97,6 @@ def mapper_grids_from( source_plane_mesh_grid: Optional[Grid2DIrregular] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: np.ndarray = None, - run_time_dict: Optional[Dict] = None, ) -> MapperGrids: raise NotImplementedError diff --git a/autoarray/inversion/pixelization/mesh/delaunay.py b/autoarray/inversion/pixelization/mesh/delaunay.py index ca8142b38..93675f023 100644 --- a/autoarray/inversion/pixelization/mesh/delaunay.py +++ b/autoarray/inversion/pixelization/mesh/delaunay.py @@ -1,8 +1,6 @@ from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay from autoarray.inversion.pixelization.mesh.triangulation import Triangulation -from autoarray.numba_util import profile_func - class Delaunay(Triangulation): def __init__(self): @@ -31,7 +29,6 @@ def __init__(self): """ super().__init__() - @profile_func def mesh_grid_from( self, source_plane_data_grid=None, diff --git a/autoarray/inversion/pixelization/mesh/mesh_numba_util.py b/autoarray/inversion/pixelization/mesh/mesh_numba_util.py new file mode 100644 index 000000000..e3fc2232f --- /dev/null +++ b/autoarray/inversion/pixelization/mesh/mesh_numba_util.py @@ -0,0 +1,301 @@ +import numpy as np + +from typing import List, Tuple, Union + +from autoarray import numba_util + + +@numba_util.jit() +def delaunay_triangle_area_from( + corner_0: Tuple[float, float], + corner_1: Tuple[float, float], + corner_2: Tuple[float, float], +) -> float: + """ + Returns the area within a Delaunay triangle where the three corners are located at the (x,y) coordinates given by + the inputs `corner_a` `corner_b` and `corner_c`. + + This function actually returns the area of any triangle, but the term `delaunay` is included in the title to + separate it from the `rectangular` and `voronoi` methods in `mesh_util.py`. + + Parameters + ---------- + corner_0 + The (x,y) coordinates of the triangle's first corner. + corner_1 + The (x,y) coordinates of the triangle's second corner. + corner_2 + The (x,y) coordinates of the triangle's third corner. + + Returns + ------- + The area of the triangle given the input (x,y) corners. + """ + + x1 = corner_0[0] + y1 = corner_0[1] + x2 = corner_1[0] + y2 = corner_1[1] + x3 = corner_2[0] + y3 = corner_2[1] + + return 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3) + + +def delaunay_interpolated_array_from( + shape_native: Tuple[int, int], + interpolation_grid_slim: np.ndarray, + pixel_values: np.ndarray, + delaunay: "scipy.spatial.Delaunay", +) -> np.ndarray: + """ + Given a Delaunay triangulation and 1D values at the node of each Delaunay pixel (e.g. the connecting points where + triangles meet), interpolate these values to a uniform 2D (y,x) grid. + + By mapping the delaunay's value to a regular grid this enables a source reconstruction of an inversion to be + output to a .fits file. + + The `grid_interpolate_slim`, which gives the (y,x) coordinates the values are evaluated at for interpolation, + need not be regular and can have undergone coordinate transforms (e.g. it can be the `source_plane_mesh_grid`) + of a `Mapper`. + + The shape of `grid_interpolate_slim` therefore must be equal to `shape_native[0] * shape_native[1]`, but the (y,x) + coordinates themselves do not need to be uniform. + + Parameters + ---------- + shape_native + The 2D (y,x) shape of the uniform grid the values are interpolated on too. + interpolation_grid_slim + A 1D grid of (y,x) coordinates where each interpolation is evaluated. The shape of this grid must be equal to + shape_native[0] * shape_native[1], but it does not need to be uniform itself. + pixel_values + The values of the Delaunay nodes (e.g. the connecting points where triangles meet) which are interpolated + to compute the value in each pixel on the `interpolated_grid`. + delaunay + A `scipy.spatial.Delaunay` object which contains all functionality describing the Delaunay triangulation. + + Returns + ------- + The input values interpolated to the `grid_interpolate_slim` (y,x) coordintes given the Delaunay triangulation. + + """ + simplex_index_for_interpolate_index = delaunay.find_simplex(interpolation_grid_slim) + + simplices = delaunay.simplices + pixel_points = delaunay.points + + interpolated_array = np.zeros(len(interpolation_grid_slim)) + + for slim_index in range(len(interpolation_grid_slim)): + simplex_index = simplex_index_for_interpolate_index[slim_index] + interpolating_point = tuple(interpolation_grid_slim[slim_index]) + + if simplex_index == -1: + cloest_pixel_index = np.argmin( + np.sum((pixel_points - interpolating_point) ** 2.0, axis=1) + ) + interpolated_array[slim_index] = pixel_values[cloest_pixel_index] + else: + triangle_points = pixel_points[simplices[simplex_index]] + triangle_values = pixel_values[simplices[simplex_index]] + + area_0 = delaunay_triangle_area_from( + corner_0=triangle_points[1], + corner_1=triangle_points[2], + corner_2=interpolating_point, + ) + area_1 = delaunay_triangle_area_from( + corner_0=triangle_points[0], + corner_1=triangle_points[2], + corner_2=interpolating_point, + ) + area_2 = delaunay_triangle_area_from( + corner_0=triangle_points[0], + corner_1=triangle_points[1], + corner_2=interpolating_point, + ) + norm = area_0 + area_1 + area_2 + + weight_abc = np.array([area_0, area_1, area_2]) / norm + + interpolated_array[slim_index] = np.sum(weight_abc * triangle_values) + + return interpolated_array.reshape(shape_native) + + +@numba_util.jit() +def voronoi_neighbors_from( + pixels: int, ridge_points: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """ + Returns the adjacent neighbors of every pixel on a Voronoi mesh as an ndarray of shape + [total_pixels, voronoi_pixel_with_max_neighbors], using the `ridge_points` output from the `scipy.spatial.Voronoi()` + object. + + Entries with values of `-1` signify edge pixels which do not have neighbors. This function therefore also returns + an ndarray with the number of neighbors of every pixel, `neighbors_sizes`, which is iterated over when using + the `neighbors` ndarray. + + Indexing is defined in an arbritrary manner due to the irregular nature of a Voronoi mesh. + + For example, if `neighbors[0,:] = [1, 5, 36, 2, -1, -1]`, this informs us that the first Voronoi pixel has + 4 neighbors which have indexes 1, 5, 36, 2. Correspondingly `neighbors_sizes[0] = 4`. + + Parameters + ---------- + pixels + The number of pixels on the Voronoi mesh. + ridge_points + Contains the information on every Voronoi source pixel and its neighbors. + + Returns + ------- + The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. + """ + neighbors_sizes = np.zeros(shape=(pixels)) + + for ridge_index in range(ridge_points.shape[0]): + pair0 = ridge_points[ridge_index, 0] + pair1 = ridge_points[ridge_index, 1] + neighbors_sizes[pair0] += 1 + neighbors_sizes[pair1] += 1 + + neighbors_index = np.zeros(shape=(pixels)) + neighbors = -1 * np.ones(shape=(pixels, int(np.max(neighbors_sizes)))) + + for ridge_index in range(ridge_points.shape[0]): + pair0 = ridge_points[ridge_index, 0] + pair1 = ridge_points[ridge_index, 1] + neighbors[pair0, int(neighbors_index[pair0])] = pair1 + neighbors[pair1, int(neighbors_index[pair1])] = pair0 + neighbors_index[pair0] += 1 + neighbors_index[pair1] += 1 + + return neighbors, neighbors_sizes + + +def voronoi_edge_pixels_from(regions: np.ndarray, point_region: np.ndarray) -> List: + """ + Returns the edge pixels of a Voronoi mesh, where the edge pixels are defined as those pixels which are on the + edge of the Voronoi diagram. + + Parameters + ---------- + regions + Indices of the Voronoi vertices forming each Voronoi region, where -1 indicates vertex outside the Voronoi + diagram. + """ + + voronoi_edge_pixel_list = [] + + for index, i in enumerate(point_region): + if -1 in regions[i]: + voronoi_edge_pixel_list.append(index) + + return voronoi_edge_pixel_list + + +def voronoi_revised_from( + voronoi: "scipy.spatial.Voronoi", +) -> Union[List[Tuple], np.ndarray]: + """ + To plot a Voronoi mesh using the `matplotlib.fill()` function a revised Voronoi mesh must be + computed, where 2D infinite voronoi regions are converted to finite 2D regions. + + This function returns a list of tuples containing the indices of the vertices of each revised Voronoi cell and + a list of tuples containing the revised Voronoi vertex vertices. + + Parameters + ---------- + voronoi + The input Voronoi diagram that is being plotted. + """ + + if voronoi.points.shape[1] != 2: + raise ValueError("Requires 2D input") + + region_list = [] + vertex_list = voronoi.vertices.tolist() + + center = voronoi.points.mean(axis=0) + radius = np.ptp(voronoi.points).max() * 2 + + # Construct a map containing all ridges for a given point + all_ridges = {} + for (p1, p2), (v1, v2) in zip(voronoi.ridge_points, voronoi.ridge_vertices): + all_ridges.setdefault(p1, []).append((p2, v1, v2)) + all_ridges.setdefault(p2, []).append((p1, v1, v2)) + + # Reconstruct infinite regions + for p1, region in enumerate(voronoi.point_region): + vertices = voronoi.regions[region] + + if all(v >= 0 for v in vertices): + # finite region + region_list.append(vertices) + continue + + # reconstruct a non-finite region + ridges = all_ridges[p1] + region = [v for v in vertices if v >= 0] + + for p2, v1, v2 in ridges: + if v2 < 0: + v1, v2 = v2, v1 + if v1 >= 0: + # finite ridge: already in the region + continue + + # Compute the missing endpoint of an infinite ridge + + t = voronoi.points[p2] - voronoi.points[p1] # tangent + t /= np.linalg.norm(t) + n = np.array([-t[1], t[0]]) + + midpoint = voronoi.points[[p1, p2]].mean(axis=0) + direction = np.sign(np.dot(midpoint - center, n)) * n + far_point = voronoi.vertices[v2] + direction * radius + + region.append(len(vertex_list)) + vertex_list.append(far_point.tolist()) + + # sort region counterclockwise + vs = np.asarray([vertex_list[v] for v in region]) + c = vs.mean(axis=0) + angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0]) + region = np.array(region)[np.argsort(angles)] + + # finish + region_list.append(region.tolist()) + + return region_list, np.asarray(vertex_list) + + +def voronoi_nn_interpolated_array_from( + shape_native: Tuple[int, int], + interpolation_grid_slim: np.ndarray, + pixel_values: np.ndarray, + voronoi: "scipy.spatial.Voronoi", +) -> np.ndarray: + try: + from autoarray.util.nn import nn_py + except ImportError as e: + raise ImportError( + "In order to use the Voronoi pixelization you must install the " + "Natural Neighbor Interpolation c package.\n\n" + "" + "See: https://github.com/Jammy2211/PyAutoArray/tree/main/autoarray/util/nn" + ) from e + + pixel_points = voronoi.points + + interpolated_array = nn_py.natural_interpolation( + pixel_points[:, 0], + pixel_points[:, 1], + pixel_values, + interpolation_grid_slim[:, 1], + interpolation_grid_slim[:, 0], + ) + + return interpolated_array.reshape(shape_native) diff --git a/autoarray/inversion/pixelization/mesh/mesh_util.py b/autoarray/inversion/pixelization/mesh/mesh_util.py index 305b56b72..86e2c4110 100644 --- a/autoarray/inversion/pixelization/mesh/mesh_util.py +++ b/autoarray/inversion/pixelization/mesh/mesh_util.py @@ -1,11 +1,9 @@ +import jax.numpy as jnp import numpy as np -import scipy.spatial -from typing import List, Tuple, Union -from autoarray import numba_util +from typing import List, Tuple -@numba_util.jit() def rectangular_neighbors_from( shape_native: Tuple[int, int], ) -> Tuple[np.ndarray, np.ndarray]: @@ -68,7 +66,6 @@ def rectangular_neighbors_from( return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_corner_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -113,7 +110,6 @@ def rectangular_corner_neighbors( return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_top_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -136,17 +132,20 @@ def rectangular_top_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[1] - 1): - pixel_index = pix - neighbors[pixel_index, 0:3] = np.array( - [pixel_index - 1, pixel_index + 1, pixel_index + shape_native[1]] - ) - neighbors_sizes[pixel_index] = 3 + """ + Vectorized version of the top edge neighbor update using NumPy arithmetic. + """ + # Pixels along the top edge, excluding corners + top_edge_pixels = np.arange(1, shape_native[1] - 1) + + neighbors[top_edge_pixels, 0] = top_edge_pixels - 1 + neighbors[top_edge_pixels, 1] = top_edge_pixels + 1 + neighbors[top_edge_pixels, 2] = top_edge_pixels + shape_native[1] + neighbors_sizes[top_edge_pixels] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_left_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -169,21 +168,20 @@ def rectangular_left_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[0] - 1): - pixel_index = pix * shape_native[1] - neighbors[pixel_index, 0:3] = np.array( - [ - pixel_index - shape_native[1], - pixel_index + 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 3 + # Row indices (excluding top and bottom corners) + rows = np.arange(1, shape_native[0] - 1) + + # Convert to flat pixel indices for the left edge (first column) + pixel_indices = rows * shape_native[1] + + neighbors[pixel_indices, 0] = pixel_indices - shape_native[1] + neighbors[pixel_indices, 1] = pixel_indices + 1 + neighbors[pixel_indices, 2] = pixel_indices + shape_native[1] + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_right_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -206,21 +204,20 @@ def rectangular_right_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for pix in range(1, shape_native[0] - 1): - pixel_index = pix * shape_native[1] + shape_native[1] - 1 - neighbors[pixel_index, 0:3] = np.array( - [ - pixel_index - shape_native[1], - pixel_index - 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 3 + # Rows excluding the top and bottom corners + rows = np.arange(1, shape_native[0] - 1) + + # Flat indices for the right edge pixels + pixel_indices = rows * shape_native[1] + shape_native[1] - 1 + + neighbors[pixel_indices, 0] = pixel_indices - shape_native[1] + neighbors[pixel_indices, 1] = pixel_indices - 1 + neighbors[pixel_indices, 2] = pixel_indices + shape_native[1] + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_bottom_edge_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -243,19 +240,21 @@ def rectangular_bottom_edge_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - pixels = int(shape_native[0] * shape_native[1]) + n_rows, n_cols = shape_native + pixels = n_rows * n_cols + + # Horizontal pixel positions along bottom row, excluding corners + cols = np.arange(1, n_cols - 1) + pixel_indices = pixels - cols - 1 # Reverse order from right to left - for pix in range(1, shape_native[1] - 1): - pixel_index = pixels - pix - 1 - neighbors[pixel_index, 0:3] = np.array( - [pixel_index - shape_native[1], pixel_index - 1, pixel_index + 1] - ) - neighbors_sizes[pixel_index] = 3 + neighbors[pixel_indices, 0] = pixel_indices - n_cols + neighbors[pixel_indices, 1] = pixel_indices - 1 + neighbors[pixel_indices, 2] = pixel_indices + 1 + neighbors_sizes[pixel_indices] = 3 return neighbors, neighbors_sizes -@numba_util.jit() def rectangular_central_neighbors( neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int] ) -> Tuple[np.ndarray, np.ndarray]: @@ -279,339 +278,110 @@ def rectangular_central_neighbors( ------- The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. """ - for x in range(1, shape_native[0] - 1): - for y in range(1, shape_native[1] - 1): - pixel_index = x * shape_native[1] + y - neighbors[pixel_index, 0:4] = np.array( - [ - pixel_index - shape_native[1], - pixel_index - 1, - pixel_index + 1, - pixel_index + shape_native[1], - ] - ) - neighbors_sizes[pixel_index] = 4 + n_rows, n_cols = shape_native - return neighbors, neighbors_sizes + # Grid coordinates excluding edges + xs = np.arange(1, n_rows - 1) + ys = np.arange(1, n_cols - 1) + # 2D grid of central pixel indices + grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij") + pixel_indices = grid_x * n_cols + grid_y + pixel_indices = pixel_indices.ravel() -def rectangular_edge_pixel_list_from(neighbors: np.ndarray) -> List: - """ - Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization. + # Compute neighbor indices + neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up + neighbors[pixel_indices, 1] = pixel_indices - 1 # Left + neighbors[pixel_indices, 2] = pixel_indices + 1 # Right + neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down - This is computed by searching the `neighbors` array for pixels that have a neighbor with index -1, meaning there - is at least one neighbor from the 4 expected missing. + neighbors_sizes[pixel_indices] = 4 - Parameters - ---------- - neighbors - An array of dimensions [total_pixels, 4] which provides the index of all neighbors of every pixel in the - rectangular pixelization (entries of -1 correspond to no neighbor). - - Returns - ------- - A list of the 1D indices of all pixels on the edge of a rectangular pixelization. - """ - edge_pixel_list = [] - - for i, neighbors in enumerate(neighbors): - if -1 in neighbors: - edge_pixel_list.append(i) - - return edge_pixel_list + return neighbors, neighbors_sizes -@numba_util.jit() -def delaunay_triangle_area_from( - corner_0: Tuple[float, float], - corner_1: Tuple[float, float], - corner_2: Tuple[float, float], -) -> float: +def rectangular_edges_from(shape_native, pixel_scales): """ - Returns the area within a Delaunay triangle where the three corners are located at the (x,y) coordinates given by - the inputs `corner_a` `corner_b` and `corner_c`. - - This function actually returns the area of any triangle, but the term `delaunay` is included in the title to - separate it from the `rectangular` and `voronoi` methods in `mesh_util.py`. - - Parameters - ---------- - corner_0 - The (x,y) coordinates of the triangle's first corner. - corner_1 - The (x,y) coordinates of the triangle's second corner. - corner_2 - The (x,y) coordinates of the triangle's third corner. - - Returns - ------- - The area of the triangle given the input (x,y) corners. + Returns all pixel edges for a rectangular grid as a JAX array of shape (N, 4, 2, 2), + where N = Ny * Nx. Edge order per pixel matches the user's convention: + + 0: (x1, y0) -> (x1, y1) + 1: (x1, y1) -> (x0, y1) + 2: (x0, y1) -> (x0, y0) + 3: (x0, y0) -> (x1, y0) + + Notes + ----- + - x is flipped so that the leftmost column has the largest +x (e.g. centres start at x=+1.0). + - y increases upward (top row has the most negative y when dy>0). """ - - x1 = corner_0[0] - y1 = corner_0[1] - x2 = corner_1[0] - y2 = corner_1[1] - x3 = corner_2[0] - y3 = corner_2[1] - - return 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3) - - -def delaunay_interpolated_array_from( - shape_native: Tuple[int, int], - interpolation_grid_slim: np.ndarray, - pixel_values: np.ndarray, - delaunay: scipy.spatial.Delaunay, -) -> np.ndarray: + Ny, Nx = shape_native + dy, dx = pixel_scales + + # Grid edge coordinates. Flip x so leftmost column has largest +x, matching your convention. + x_edges = ((jnp.arange(Nx + 1) - Nx / 2) * dx)[::-1] + y_edges = (jnp.arange(Ny + 1) - Ny / 2) * dy + + edges_list = [] + + # Pixel order: row-major (y outer, x inner). If you want column-major, swap the loop nesting. + for j in range(Ny): + for i in range(Nx): + y0, y1 = y_edges[i], y_edges[i + 1] + xa, xb = ( + x_edges[j], + x_edges[j + 1], + ) # xa is the "right" boundary in your convention + + # Edge order to match your pytest: [(xa,y0)->(xa,y1), (xa,y1)->(xb,y1), (xb,y1)->(xb,y0), (xb,y0)->(xa,y0)] + e0 = jnp.array( + [[xa, y0], [xa, y1]] + ) # "top" in your test (vertical at x=xa) + e1 = jnp.array( + [[xa, y1], [xb, y1]] + ) # "right" in your test (horizontal at y=y1) + e2 = jnp.array( + [[xb, y1], [xb, y0]] + ) # "bottom" in your test (vertical at x=xb) + e3 = jnp.array( + [[xb, y0], [xa, y0]] + ) # "left" in your test (horizontal at y=y0) + + edges_list.append(jnp.stack([e0, e1, e2, e3], axis=0)) + + return jnp.stack(edges_list, axis=0) + + +def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]: """ - Given a Delaunay triangulation and 1D values at the node of each Delaunay pixel (e.g. the connecting points where - triangles meet), interpolate these values to a uniform 2D (y,x) grid. - - By mapping the delaunay's value to a regular grid this enables a source reconstruction of an inversion to be - output to a .fits file. - - The `grid_interpolate_slim`, which gives the (y,x) coordinates the values are evaluated at for interpolation, - need not be regular and can have undergone coordinate transforms (e.g. it can be the `source_plane_mesh_grid`) - of a `Mapper`. - - The shape of `grid_interpolate_slim` therefore must be equal to `shape_native[0] * shape_native[1]`, but the (y,x) - coordinates themselves do not need to be uniform. + Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization, + based on its 2D shape. Parameters ---------- shape_native - The 2D (y,x) shape of the uniform grid the values are interpolated on too. - interpolation_grid_slim - A 1D grid of (y,x) coordinates where each interpolation is evaluated. The shape of this grid must be equal to - shape_native[0] * shape_native[1], but it does not need to be uniform itself. - pixel_values - The values of the Delaunay nodes (e.g. the connecting points where triangles meet) which are interpolated - to compute the value in each pixel on the `interpolated_grid`. - delaunay - A `scipy.spatial.Delaunay` object which contains all functionality describing the Delaunay triangulation. - - Returns - ------- - The input values interpolated to the `grid_interpolate_slim` (y,x) coordintes given the Delaunay triangulation. - - """ - simplex_index_for_interpolate_index = delaunay.find_simplex(interpolation_grid_slim) - - simplices = delaunay.simplices - pixel_points = delaunay.points - - interpolated_array = np.zeros(len(interpolation_grid_slim)) - - for slim_index in range(len(interpolation_grid_slim)): - simplex_index = simplex_index_for_interpolate_index[slim_index] - interpolating_point = tuple(interpolation_grid_slim[slim_index]) - - if simplex_index == -1: - cloest_pixel_index = np.argmin( - np.sum((pixel_points - interpolating_point) ** 2.0, axis=1) - ) - interpolated_array[slim_index] = pixel_values[cloest_pixel_index] - else: - triangle_points = pixel_points[simplices[simplex_index]] - triangle_values = pixel_values[simplices[simplex_index]] - - area_0 = delaunay_triangle_area_from( - corner_0=triangle_points[1], - corner_1=triangle_points[2], - corner_2=interpolating_point, - ) - area_1 = delaunay_triangle_area_from( - corner_0=triangle_points[0], - corner_1=triangle_points[2], - corner_2=interpolating_point, - ) - area_2 = delaunay_triangle_area_from( - corner_0=triangle_points[0], - corner_1=triangle_points[1], - corner_2=interpolating_point, - ) - norm = area_0 + area_1 + area_2 - - weight_abc = np.array([area_0, area_1, area_2]) / norm - - interpolated_array[slim_index] = np.sum(weight_abc * triangle_values) - - return interpolated_array.reshape(shape_native) - - -@numba_util.jit() -def voronoi_neighbors_from( - pixels: int, ridge_points: np.ndarray -) -> Tuple[np.ndarray, np.ndarray]: - """ - Returns the adjacent neighbors of every pixel on a Voronoi mesh as an ndarray of shape - [total_pixels, voronoi_pixel_with_max_neighbors], using the `ridge_points` output from the `scipy.spatial.Voronoi()` - object. - - Entries with values of `-1` signify edge pixels which do not have neighbors. This function therefore also returns - an ndarray with the number of neighbors of every pixel, `neighbors_sizes`, which is iterated over when using - the `neighbors` ndarray. - - Indexing is defined in an arbritrary manner due to the irregular nature of a Voronoi mesh. - - For example, if `neighbors[0,:] = [1, 5, 36, 2, -1, -1]`, this informs us that the first Voronoi pixel has - 4 neighbors which have indexes 1, 5, 36, 2. Correspondingly `neighbors_sizes[0] = 4`. - - Parameters - ---------- - pixels - The number of pixels on the Voronoi mesh. - ridge_points - Contains the information on every Voronoi source pixel and its neighbors. + The (rows, cols) shape of the rectangular 2D pixel grid. Returns ------- - The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has. + A list of the 1D indices of all edge pixels. """ - neighbors_sizes = np.zeros(shape=(pixels)) - - for ridge_index in range(ridge_points.shape[0]): - pair0 = ridge_points[ridge_index, 0] - pair1 = ridge_points[ridge_index, 1] - neighbors_sizes[pair0] += 1 - neighbors_sizes[pair1] += 1 + rows, cols = shape_native - neighbors_index = np.zeros(shape=(pixels)) - neighbors = -1 * np.ones(shape=(pixels, int(np.max(neighbors_sizes)))) - - for ridge_index in range(ridge_points.shape[0]): - pair0 = ridge_points[ridge_index, 0] - pair1 = ridge_points[ridge_index, 1] - neighbors[pair0, int(neighbors_index[pair0])] = pair1 - neighbors[pair1, int(neighbors_index[pair1])] = pair0 - neighbors_index[pair0] += 1 - neighbors_index[pair1] += 1 - - return neighbors, neighbors_sizes + # Top row + top = np.arange(0, cols) + # Bottom row + bottom = np.arange((rows - 1) * cols, rows * cols) -def voronoi_edge_pixels_from(regions: np.ndarray, point_region: np.ndarray) -> List: - """ - Returns the edge pixels of a Voronoi mesh, where the edge pixels are defined as those pixels which are on the - edge of the Voronoi diagram. - - Parameters - ---------- - regions - Indices of the Voronoi vertices forming each Voronoi region, where -1 indicates vertex outside the Voronoi - diagram. - """ - - voronoi_edge_pixel_list = [] - - for index, i in enumerate(point_region): - if -1 in regions[i]: - voronoi_edge_pixel_list.append(index) - - return voronoi_edge_pixel_list - - -def voronoi_revised_from( - voronoi: scipy.spatial.Voronoi, -) -> Union[List[Tuple], np.ndarray]: - """ - To plot a Voronoi mesh using the `matplotlib.fill()` function a revised Voronoi mesh must be - computed, where 2D infinite voronoi regions are converted to finite 2D regions. + # Left column (excluding corners) + left = np.arange(1, rows - 1) * cols - This function returns a list of tuples containing the indices of the vertices of each revised Voronoi cell and - a list of tuples containing the revised Voronoi vertex vertices. + # Right column (excluding corners) + right = (np.arange(1, rows - 1) + 1) * cols - 1 - Parameters - ---------- - voronoi - The input Voronoi diagram that is being plotted. - """ - - if voronoi.points.shape[1] != 2: - raise ValueError("Requires 2D input") - - region_list = [] - vertex_list = voronoi.vertices.tolist() - - center = voronoi.points.mean(axis=0) - radius = np.ptp(voronoi.points).max() * 2 - - # Construct a map containing all ridges for a given point - all_ridges = {} - for (p1, p2), (v1, v2) in zip(voronoi.ridge_points, voronoi.ridge_vertices): - all_ridges.setdefault(p1, []).append((p2, v1, v2)) - all_ridges.setdefault(p2, []).append((p1, v1, v2)) - - # Reconstruct infinite regions - for p1, region in enumerate(voronoi.point_region): - vertices = voronoi.regions[region] - - if all(v >= 0 for v in vertices): - # finite region - region_list.append(vertices) - continue - - # reconstruct a non-finite region - ridges = all_ridges[p1] - region = [v for v in vertices if v >= 0] - - for p2, v1, v2 in ridges: - if v2 < 0: - v1, v2 = v2, v1 - if v1 >= 0: - # finite ridge: already in the region - continue - - # Compute the missing endpoint of an infinite ridge - - t = voronoi.points[p2] - voronoi.points[p1] # tangent - t /= np.linalg.norm(t) - n = np.array([-t[1], t[0]]) - - midpoint = voronoi.points[[p1, p2]].mean(axis=0) - direction = np.sign(np.dot(midpoint - center, n)) * n - far_point = voronoi.vertices[v2] + direction * radius - - region.append(len(vertex_list)) - vertex_list.append(far_point.tolist()) - - # sort region counterclockwise - vs = np.asarray([vertex_list[v] for v in region]) - c = vs.mean(axis=0) - angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0]) - region = np.array(region)[np.argsort(angles)] - - # finish - region_list.append(region.tolist()) - - return region_list, np.asarray(vertex_list) - - -def voronoi_nn_interpolated_array_from( - shape_native: Tuple[int, int], - interpolation_grid_slim: np.ndarray, - pixel_values: np.ndarray, - voronoi: scipy.spatial.Voronoi, -) -> np.ndarray: - try: - from autoarray.util.nn import nn_py - except ImportError as e: - raise ImportError( - "In order to use the Voronoi pixelization you must install the " - "Natural Neighbor Interpolation c package.\n\n" - "" - "See: https://github.com/Jammy2211/PyAutoArray/tree/main/autoarray/util/nn" - ) from e - - pixel_points = voronoi.points - - interpolated_array = nn_py.natural_interpolation( - pixel_points[:, 0], - pixel_points[:, 1], - pixel_values, - interpolation_grid_slim[:, 1], - interpolation_grid_slim[:, 0], - ) + # Concatenate all edge indices + edge_pixel_indices = np.concatenate([top, left, right, bottom]) - return interpolated_array.reshape(shape_native) + # Sort and return + return np.sort(edge_pixel_indices).tolist() diff --git a/autoarray/inversion/pixelization/mesh/rectangular.py b/autoarray/inversion/pixelization/mesh/rectangular.py index b1b5a7017..6f9cc3af6 100644 --- a/autoarray/inversion/pixelization/mesh/rectangular.py +++ b/autoarray/inversion/pixelization/mesh/rectangular.py @@ -1,7 +1,8 @@ import numpy as np -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple +from autoarray.structures.grids.irregular_2d import Grid2DIrregular from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular @@ -10,7 +11,6 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray import exc -from autoarray.numba_util import profile_func class Rectangular(AbstractMesh): @@ -52,8 +52,6 @@ def __init__(self, shape: Tuple[int, int] = (3, 3)): self.pixels = self.shape[0] * self.shape[1] super().__init__() - self.run_time_dict = {} - def mapper_grids_from( self, mask, @@ -62,7 +60,6 @@ def mapper_grids_from( source_plane_mesh_grid: Grid2D = None, image_plane_mesh_grid: Grid2D = None, adapt_data: np.ndarray = None, - run_time_dict: Optional[Dict] = None, ) -> MapperGrids: """ Mapper objects describe the mappings between pixels in the masked 2D data and the pixels in a pixelization, @@ -94,12 +91,8 @@ def mapper_grids_from( Not used for a rectangular pixelization. adapt_data Not used for a rectangular pixelization. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ - self.run_time_dict = run_time_dict - relocated_grid = self.relocated_grid_from( border_relocator=border_relocator, source_plane_data_grid=source_plane_data_grid, @@ -113,10 +106,8 @@ def mapper_grids_from( source_plane_mesh_grid=mesh_grid, image_plane_mesh_grid=image_plane_mesh_grid, adapt_data=adapt_data, - run_time_dict=run_time_dict, ) - @profile_func def mesh_grid_from( self, source_plane_data_grid: Optional[Grid2D] = None, @@ -136,7 +127,8 @@ def mesh_grid_from( by overlaying the `source_plane_data_grid` with the rectangular pixelization. """ return Mesh2DRectangular.overlay_grid( - shape_native=self.shape, grid=source_plane_data_grid.over_sampled + shape_native=self.shape, + grid=Grid2DIrregular(source_plane_data_grid.over_sampled), ) @property diff --git a/autoarray/inversion/pixelization/mesh/rectangular_uniform.py b/autoarray/inversion/pixelization/mesh/rectangular_uniform.py new file mode 100644 index 000000000..3f291068b --- /dev/null +++ b/autoarray/inversion/pixelization/mesh/rectangular_uniform.py @@ -0,0 +1,34 @@ +from autoarray.inversion.pixelization.mesh.rectangular import Rectangular + +from typing import Optional + + +from autoarray.structures.grids.irregular_2d import Grid2DIrregular +from autoarray.structures.grids.uniform_2d import Grid2D +from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform + + +class RectangularUniform(Rectangular): + + def mesh_grid_from( + self, + source_plane_data_grid: Optional[Grid2D] = None, + source_plane_mesh_grid: Optional[Grid2D] = None, + ) -> Mesh2DRectangularUniform: + """ + Return the rectangular `source_plane_mesh_grid` as a `Mesh2DRectangular` object, which provides additional + functionality for perform operatons that exploit the geometry of a rectangular pixelization. + + Parameters + ---------- + source_plane_data_grid + The (y,x) grid of coordinates over which the rectangular pixelization is overlaid, where this grid may have + had exterior pixels relocated to its edge via the border. + source_plane_mesh_grid + Not used for a rectangular pixelization, because the pixelization grid in the `source` frame is computed + by overlaying the `source_plane_data_grid` with the rectangular pixelization. + """ + return Mesh2DRectangularUniform.overlay_grid( + shape_native=self.shape, + grid=Grid2DIrregular(source_plane_data_grid.over_sampled), + ) diff --git a/autoarray/inversion/pixelization/mesh/triangulation.py b/autoarray/inversion/pixelization/mesh/triangulation.py index a72236e50..3e7d6cd2c 100644 --- a/autoarray/inversion/pixelization/mesh/triangulation.py +++ b/autoarray/inversion/pixelization/mesh/triangulation.py @@ -17,7 +17,6 @@ def mapper_grids_from( source_plane_mesh_grid: Optional[Grid2DIrregular] = None, image_plane_mesh_grid: Optional[Grid2DIrregular] = None, adapt_data: np.ndarray = None, - run_time_dict: Optional[Dict] = None, ) -> MapperGrids: """ Mapper objects describe the mappings between pixels in the masked 2D data and the pixels in a mesh, @@ -59,12 +58,8 @@ def mapper_grids_from( transformation applied to it to create the `source_plane_mesh_grid`. adapt_data Not used for a rectangular mesh. - run_time_dict - A dictionary which contains timing of certain functions calls which is used for profiling. """ - self.run_time_dict = run_time_dict - relocated_grid = self.relocated_grid_from( border_relocator=border_relocator, source_plane_data_grid=source_plane_data_grid, @@ -90,5 +85,4 @@ def mapper_grids_from( source_plane_mesh_grid=source_plane_mesh_grid, image_plane_mesh_grid=image_plane_mesh_grid, adapt_data=adapt_data, - run_time_dict=run_time_dict, ) diff --git a/autoarray/inversion/pixelization/mesh/voronoi.py b/autoarray/inversion/pixelization/mesh/voronoi.py index dc8d9310f..99954d850 100644 --- a/autoarray/inversion/pixelization/mesh/voronoi.py +++ b/autoarray/inversion/pixelization/mesh/voronoi.py @@ -1,8 +1,6 @@ from autoarray.structures.mesh.voronoi_2d import Mesh2DVoronoi from autoarray.inversion.pixelization.mesh.triangulation import Triangulation -from autoarray.numba_util import profile_func - class Voronoi(Triangulation): def __init__(self): @@ -33,7 +31,6 @@ def __init__(self): """ super().__init__() - @profile_func def mesh_grid_from( self, source_plane_data_grid=None, diff --git a/autoarray/inversion/pixelization/pixelization.py b/autoarray/inversion/pixelization/pixelization.py index 4f6084168..b6b09f664 100644 --- a/autoarray/inversion/pixelization/pixelization.py +++ b/autoarray/inversion/pixelization/pixelization.py @@ -139,7 +139,6 @@ def __init__( # The example below shows how a `Pixelization` is used in modeling. - import autofit as af import autogalaxy as ag mesh = af.Model(ag.mesh.Rectangular) diff --git a/autoarray/inversion/plot/inversion_plotters.py b/autoarray/inversion/plot/inversion_plotters.py index 98cede938..506388609 100644 --- a/autoarray/inversion/plot/inversion_plotters.py +++ b/autoarray/inversion/plot/inversion_plotters.py @@ -3,9 +3,8 @@ from autoconf import conf from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.structures.arrays.uniform_2d import Array2D @@ -14,13 +13,12 @@ from autoarray.inversion.plot.mapper_plotters import MapperPlotter -class InversionPlotter(Plotter): +class InversionPlotter(AbstractPlotter): def __init__( self, inversion: AbstractInversion, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, residuals_symmetric_cmap: bool = True, ): """ @@ -32,8 +30,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Inversion` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Inversion` and plotted via the visuals object. Parameters ---------- @@ -43,35 +40,12 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Inversion` are extracted and plotted as visuals for 2D plots. - """ - super().__init__( - mat_plot_2d=mat_plot_2d, include_2d=include_2d, visuals_2d=visuals_2d - ) + super().__init__(mat_plot_2d=mat_plot_2d, visuals_2d=visuals_2d) self.inversion = inversion self.residuals_symmetric_cmap = residuals_symmetric_cmap - def get_visuals_2d_for_data(self) -> Visuals2D: - try: - mapper = self.inversion.cls_list_from(cls=AbstractMapper)[0] - - visuals = self.get_2d.via_mapper_for_data_from(mapper=mapper) - - if self.visuals_2d.pix_indexes is not None: - indexes = mapper.pix_indexes_for_slim_indexes( - pix_indexes=self.visuals_2d.pix_indexes - ) - - visuals.indexes = indexes - - return visuals - - except (AttributeError, IndexError): - return self.visuals_2d - def mapper_plotter_from(self, mapper_index: int) -> MapperPlotter: """ Returns a `MapperPlotter` corresponding to the `Mapper` in the `Inversion`'s `linear_obj_list` given an input @@ -91,7 +65,6 @@ def mapper_plotter_from(self, mapper_index: int) -> MapperPlotter: mapper=self.inversion.cls_list_from(cls=AbstractMapper)[mapper_index], mat_plot_2d=self.mat_plot_2d, visuals_2d=self.visuals_2d, - include_2d=self.include_2d, ) def figures_2d(self, reconstructed_image: bool = False): @@ -109,7 +82,7 @@ def figures_2d(self, reconstructed_image: bool = False): if reconstructed_image: self.mat_plot_2d.plot_array( array=self.inversion.mapped_reconstructed_image, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Reconstructed Image", filename="reconstructed_image" ), @@ -183,7 +156,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=array, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, auto_labels=AutoLabels( title="Data Subtracted", filename="data_subtracted" @@ -199,7 +172,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=array, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, grid_indexes=mapper_plotter.mapper.over_sampler.uniform_over_sampled, auto_labels=AutoLabels( title="Reconstructed Image", filename="reconstructed_image" @@ -292,7 +265,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=sub_size, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Sub Pixels Per Image Pixels", filename="sub_pixels_per_image_pixels", @@ -307,7 +280,7 @@ def figures_2d_of_pixelization( self.mat_plot_2d.plot_array( array=mesh_pixels_per_image_pixels, - visuals_2d=self.get_visuals_2d_for_data(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels( title="Mesh Pixels Per Image Pixels", filename="mesh_pixels_per_image_pixels", @@ -350,10 +323,6 @@ def subplot_of_mapper( if self.mat_plot_2d.use_log10: self.mat_plot_2d.contour = False - mapper_image_plane_mesh_grid = self.include_2d._mapper_image_plane_mesh_grid - - self.include_2d._mapper_image_plane_mesh_grid = False - self.figures_2d_of_pixelization( pixelization_index=mapper_index, data_subtracted=True ) @@ -371,15 +340,21 @@ def subplot_of_mapper( self.mat_plot_2d.use_log10 = False - self.include_2d._mapper_image_plane_mesh_grid = mapper_image_plane_mesh_grid - self.include_2d._mapper_image_plane_mesh_grid = True + mapper = self.inversion.cls_list_from(cls=AbstractMapper)[mapper_index] + + self.visuals_2d += Visuals2D( + mesh_grid=mapper.mapper_grids.image_plane_mesh_grid + ) + self.set_title(label="Mesh Pixel Grid Overlaid") self.figures_2d_of_pixelization( pixelization_index=mapper_index, reconstructed_image=True ) self.set_title(label=None) - self.include_2d._mapper_image_plane_mesh_grid = False + self.visuals_2d.mesh_grid = None + + # self.include_2d._mapper_image_plane_mesh_grid = False self.figures_2d_of_pixelization( pixelization_index=mapper_index, reconstruction=True @@ -436,8 +411,6 @@ def subplot_mappings( ): self.open_subplot_figure(number_subplots=4) - self.include_2d._mapper_image_plane_mesh_grid = False - self.figures_2d_of_pixelization( pixelization_index=pixelization_index, data_subtracted=True ) @@ -456,9 +429,9 @@ def subplot_mappings( total_pixels=total_pixels, filter_neighbors=True ) - self.visuals_2d.pix_indexes = [ - [index] for index in pix_indexes[pixelization_index] - ] + indexes = mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes) + + self.visuals_2d.indexes = indexes self.figures_2d_of_pixelization( pixelization_index=pixelization_index, reconstructed_image=True diff --git a/autoarray/inversion/plot/mapper_plotters.py b/autoarray/inversion/plot/mapper_plotters.py index 9fd6608d7..08b53a710 100644 --- a/autoarray/inversion/plot/mapper_plotters.py +++ b/autoarray/inversion/plot/mapper_plotters.py @@ -1,9 +1,8 @@ import numpy as np from typing import Union -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels from autoarray.structures.arrays.uniform_2d import Array2D @@ -16,13 +15,12 @@ logger = logging.getLogger(__name__) -class MapperPlotter(Plotter): +class MapperPlotter(AbstractPlotter): def __init__( self, mapper: MapperRectangular, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots the attributes of `Mapper` objects using the matplotlib method `imshow()` and many other matplotlib @@ -33,8 +31,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Mapper` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Mapper` and plotted via the visuals object. Parameters ---------- @@ -44,23 +41,13 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Mapper` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - visuals_2d=visuals_2d, include_2d=include_2d, mat_plot_2d=mat_plot_2d - ) + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) self.mapper = mapper - def get_visuals_2d_for_data(self) -> Visuals2D: - return self.get_2d.via_mapper_for_data_from(mapper=self.mapper) - - def get_visuals_2d_for_source(self) -> Visuals2D: - return self.get_2d.via_mapper_for_source_from(mapper=self.mapper) - def figure_2d( - self, interpolate_to_uniform: bool = True, solution_vector: bool = None + self, interpolate_to_uniform: bool = False, solution_vector: bool = None ): """ Plots the plotter's `Mapper` object in 2D. @@ -76,7 +63,7 @@ def figure_2d( """ self.mat_plot_2d.plot_mapper( mapper=self.mapper, - visuals_2d=self.get_2d.via_mapper_for_source_from(mapper=self.mapper), + visuals_2d=self.visuals_2d, interpolate_to_uniform=interpolate_to_uniform, pixel_values=solution_vector, auto_labels=AutoLabels( @@ -84,8 +71,19 @@ def figure_2d( ), ) + def figure_2d_image(self, image): + + self.mat_plot_2d.plot_array( + array=image, + visuals_2d=self.visuals_2d, + grid_indexes=self.mapper.mapper_grids.image_plane_data_grid.over_sampled, + auto_labels=AutoLabels( + title="Image (Image-Plane)", filename="mapper_image" + ), + ) + def subplot_image_and_mapper( - self, image: Array2D, interpolate_to_uniform: bool = True + self, image: Array2D, interpolate_to_uniform: bool = False ): """ Make a subplot of an input image and the `Mapper`'s source-plane reconstruction. @@ -105,22 +103,7 @@ def subplot_image_and_mapper( """ self.open_subplot_figure(number_subplots=2) - self.mat_plot_2d.plot_array( - array=image, - visuals_2d=self.get_visuals_2d_for_data(), - auto_labels=AutoLabels(title="Image (Image-Plane)"), - ) - - if self.visuals_2d.pix_indexes is not None: - indexes = self.mapper.pix_indexes_for_slim_indexes( - pix_indexes=self.visuals_2d.pix_indexes - ) - - self.mat_plot_2d.index_scatter.scatter_grid_indexes( - grid=self.mapper.over_sampler.uniform_over_sampled, - indexes=indexes, - ) - + self.figure_2d_image(image=image) self.figure_2d(interpolate_to_uniform=interpolate_to_uniform) self.mat_plot_2d.output.subplot_to_figure( @@ -154,7 +137,7 @@ def plot_source_from( try: self.mat_plot_2d.plot_mapper( mapper=self.mapper, - visuals_2d=self.get_visuals_2d_for_source(), + visuals_2d=self.visuals_2d, auto_labels=auto_labels, pixel_values=pixel_values, zoom_to_brightest=zoom_to_brightest, diff --git a/autoarray/inversion/regularization/abstract.py b/autoarray/inversion/regularization/abstract.py index 47cadc302..838eaf942 100644 --- a/autoarray/inversion/regularization/abstract.py +++ b/autoarray/inversion/regularization/abstract.py @@ -5,13 +5,6 @@ if TYPE_CHECKING: from autoarray.inversion.linear_obj.linear_obj import LinearObj -try: - import pylops - - PyLopsOperator = pylops.LinearOperator -except ModuleNotFoundError: - PyLopsOperator = object - class AbstractRegularization: def __init__(self): @@ -174,19 +167,3 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ raise NotImplementedError - - -class RegularizationLop(PyLopsOperator): - def __init__(self, regularization_matrix): - self.regularization_matrix = regularization_matrix - self.pixels = regularization_matrix.shape[0] - self.dims = self.pixels - self.shape = (self.pixels, self.pixels) - self.dtype = dtype - self.explicit = False - - def _matvec(self, x): - return np.dot(self.regularization_matrix, x) - - def _rmatvec(self, x): - return np.dot(self.regularization_matrix.T, x) diff --git a/autoarray/inversion/regularization/adaptive_brightness.py b/autoarray/inversion/regularization/adaptive_brightness.py index d7c322817..c0ba845d0 100644 --- a/autoarray/inversion/regularization/adaptive_brightness.py +++ b/autoarray/inversion/regularization/adaptive_brightness.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,114 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def adaptive_regularization_weights_from( + inner_coefficient: float, outer_coefficient: float, pixel_signals: jnp.ndarray +) -> jnp.ndarray: + """ + Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). + + The weights define the effective regularization coefficient of every mesh parameter (typically pixels + of a ``Mapper``). + + They are computed using an estimate of the expected signal in each pixel. + + Two regularization coefficients are used, corresponding to the: + + 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization). + 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization). + + Parameters + ---------- + inner_coefficient + The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction + in the inner regions of a mesh's reconstruction. + outer_coefficient + The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction + in the outer regions of a mesh's reconstruction. + pixel_signals + The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal + and low signal pixelizations. + + Returns + ------- + jnp.ndarray + The adaptive regularization weights which act as the effective regularization coefficients of + every source pixel. + """ + return ( + inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals) + ) ** 2.0 + + +def weighted_regularization_matrix_from( + regularization_weights: jnp.ndarray, + neighbors: jnp.ndarray, +) -> jnp.ndarray: + """ + Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). + + This matrix is computed using the regularization weights of every mesh pixel, which are computed using the + function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of + every mesh pixel. + + The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate + neighbor calculation of the corresponding ``Mapper`` class. + + Parameters + ---------- + regularization_weights + The regularization weight of each pixel, adaptively governing the degree of gradient regularization + applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). + neighbors + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the mesh grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + jnp.ndarray + The regularization matrix computed using an adaptive regularization scheme where the effective regularization + coefficient of every source pixel is different. + """ + S, P = neighbors.shape + reg_w = regularization_weights**2 + + # 1) Flatten the (i→j) neighbor pairs + I = jnp.repeat(jnp.arange(S), P) # (S*P,) + J = neighbors.reshape(-1) # (S*P,) + + # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0 + OUT = S + J = jnp.where(J < 0, OUT, J) + + # 3) Build an extended weight vector with a zero at index S + reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0) + w_ij = reg_w_ext[J] # (S*P,) + + # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely + mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype) + + # 5) Scatter into the diagonal: + # - the tiny 1e-8 floor on each i < S + # - sum_j reg_w[j] into diag[i] + # - sum contributions reg_w[j] into diag[j] + # (diagonal at OUT=S picks up zeros only) + diag_updates_i = jnp.concatenate( + [jnp.full((S,), 1e-8), jnp.zeros((1,))], axis=0 # out‐of‐bounds slot stays zero + ) + mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i) + mat = mat.at[I, I].add(w_ij) + mat = mat.at[J, J].add(w_ij) + + # 6) Scatter the off‐diagonal subtractions: + mat = mat.at[I, J].add(-w_ij) + mat = mat.at[J, I].add(-w_ij) + + # 7) Drop the extra row/column S and return the S×S result + return mat[:S, :S] class AdaptiveBrightness(AbstractRegularization): @@ -70,7 +177,7 @@ def __init__( self.outer_coefficient = outer_coefficient self.signal_scale = signal_scale - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -91,13 +198,13 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: """ pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale) - return regularization_util.adaptive_regularization_weights_from( + return adaptive_regularization_weights_from( inner_coefficient=self.inner_coefficient, outer_coefficient=self.outer_coefficient, pixel_signals=pixel_signals, ) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -112,8 +219,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ regularization_weights = self.regularization_weights_from(linear_obj=linear_obj) - return regularization_util.weighted_regularization_matrix_from( + return weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=linear_obj.source_plane_mesh_grid.neighbors, - neighbors_sizes=linear_obj.source_plane_mesh_grid.neighbors.sizes, ) diff --git a/autoarray/inversion/regularization/adaptive_brightness_split.py b/autoarray/inversion/regularization/adaptive_brightness_split.py index 7b09993db..b781ef7a6 100644 --- a/autoarray/inversion/regularization/adaptive_brightness_split.py +++ b/autoarray/inversion/regularization/adaptive_brightness_split.py @@ -22,8 +22,7 @@ def __init__( adapted to the data being fitted to smooth an inversion's solution. An adaptive regularization scheme which splits every source pixel into a cross of four regularization points - and interpolates to these points in order - to smooth an inversion's solution. + and interpolates to these points in order to smooth an inversion's solution. The size of this cross is determined via the size of the source-pixel, for example if the source pixel is a Voronoi pixel the area of the pixel is computed and the distance of each point of the cross is given by diff --git a/autoarray/inversion/regularization/brightness_zeroth.py b/autoarray/inversion/regularization/brightness_zeroth.py index 4cab4e6d2..6cd765aec 100644 --- a/autoarray/inversion/regularization/brightness_zeroth.py +++ b/autoarray/inversion/regularization/brightness_zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -10,6 +10,60 @@ from autoarray.inversion.regularization import regularization_util +def brightness_zeroth_regularization_weights_from( + coefficient: float, pixel_signals: jnp.ndarray +) -> jnp.ndarray: + """ + Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). + + The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels + of a ``Mapper``). + + They are computed using an estimate of the expected signal in each pixel. + + The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are + the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in + the pixelization). + + Parameters + ---------- + coefficient + The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``), + with the degree applied varying based on the ``pixel_signals``. + pixel_signals + The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal + and low signal pixelizations. + + Returns + ------- + jnp.ndarray + The zeroth order regularization weights which act as the effective level of zeroth order regularization + applied to every mesh parameter. + """ + return coefficient * (1.0 - pixel_signals) + + +def brightness_zeroth_regularization_matrix_from( + regularization_weights: jnp.ndarray, +) -> jnp.ndarray: + """ + Returns the regularization matrix for the zeroth-order brightness regularization scheme. + + Parameters + ---------- + regularization_weights + The regularization weights for each pixel, governing the strength of zeroth-order + regularization applied per inversion parameter. + + Returns + ------- + A diagonal regularization matrix where each diagonal element is the squared regularization weight + for that pixel. + """ + regularization_weight_squared = regularization_weights**2.0 + return jnp.diag(regularization_weight_squared) + + class BrightnessZeroth(AbstractRegularization): def __init__( self, @@ -45,7 +99,7 @@ def __init__( self.coefficient = coefficient self.signal_scale = signal_scale - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of the ``BrightnessZeroth`` regularization scheme. @@ -65,11 +119,11 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: """ pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale) - return regularization_util.brightness_zeroth_regularization_weights_from( + return brightness_zeroth_regularization_weights_from( coefficient=self.coefficient, pixel_signals=pixel_signals ) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -84,6 +138,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ regularization_weights = self.regularization_weights_from(linear_obj=linear_obj) - return regularization_util.brightness_zeroth_regularization_matrix_from( + return brightness_zeroth_regularization_matrix_from( regularization_weights=regularization_weights ) diff --git a/autoarray/inversion/regularization/constant.py b/autoarray/inversion/regularization/constant.py index 690b248bd..d9737d075 100644 --- a/autoarray/inversion/regularization/constant.py +++ b/autoarray/inversion/regularization/constant.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,57 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def constant_regularization_matrix_from( + coefficient: float, + neighbors: jnp.ndarray[[int, int], jnp.int64], + neighbors_sizes: jnp.ndarray[[int], jnp.int64], +) -> jnp.ndarray[[int, int], jnp.float64]: + """ + From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. + + A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` + class in the module `autoarray.inversion.regularization`. + + Memory requirement: 2SP + S^2 + FLOPS: 1 + 2S + 2SP + + Parameters + ---------- + coefficient + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + neighbors : ndarray, shape (S, P), dtype=int64 + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes : ndarray, shape (S,), dtype=int64 + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + regularization_matrix : ndarray, shape (S, S), dtype=float64 + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + return ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[ + I_IDX, neighbors + ] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) class Constant(AbstractRegularization): @@ -38,7 +88,7 @@ def __init__(self, coefficient: float = 1.0): super().__init__() - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -57,9 +107,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -73,7 +123,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ - return regularization_util.constant_regularization_matrix_from( + return constant_regularization_matrix_from( coefficient=self.coefficient, neighbors=linear_obj.neighbors, neighbors_sizes=linear_obj.neighbors.sizes, diff --git a/autoarray/inversion/regularization/constant_zeroth.py b/autoarray/inversion/regularization/constant_zeroth.py index 5e3d8acb3..11d7b9808 100644 --- a/autoarray/inversion/regularization/constant_zeroth.py +++ b/autoarray/inversion/regularization/constant_zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,61 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def constant_zeroth_regularization_matrix_from( + coefficient: float, + coefficient_zeroth: float, + neighbors: jnp.ndarray, + neighbors_sizes: jnp.ndarray[[int], jnp.int64], +) -> jnp.ndarray: + """ + From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. + + A complete description of regularizatin and the ``regularization_matrix`` can be found in the ``Regularization`` + class in the module ``autoarray.inversion.regularization``. + + Parameters + ---------- + coefficients + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + neighbors + An array of length (total_pixels) which provides the index of all neighbors of every pixel in + the Voronoi grid (entries of -1 correspond to no neighbor). + neighbors_sizes + An array of length (total_pixels) which gives the number of neighbors of every pixel in the + Voronoi grid. + + Returns + ------- + jnp.ndarray + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + S, P = neighbors.shape + # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do) + OUT_OF_BOUND_IDX = S + regularization_coefficient = coefficient * coefficient + + # flatten it for feeding into the matrix as j indices + neighbors = neighbors.flatten() + # now create the corresponding i indices + I_IDX = jnp.repeat(jnp.arange(S), P) + # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index. + # This ensures that JAX can efficiently drop these entries during matrix updates. + neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors) + const = ( + jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[ + I_IDX, neighbors + ] + # unique indices should be guranteed by neighbors-spec + .add(-regularization_coefficient, mode="drop", unique_indices=True) + ) + + reg_coeff = coefficient_zeroth**2.0 + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + zeroth = jnp.eye(P) * reg_coeff + + return const + zeroth class ConstantZeroth(AbstractRegularization): @@ -17,7 +71,7 @@ def __init__(self, coefficient_neighbor=1.0, coefficient_zeroth=1.0): self.coefficient_neighbor = coefficient_neighbor self.coefficient_zeroth = coefficient_zeroth - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -36,9 +90,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient_neighbor * np.ones(linear_obj.params) + return self.coefficient_neighbor * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -51,9 +105,8 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization matrix. """ - return regularization_util.constant_zeroth_regularization_matrix_from( + return constant_zeroth_regularization_matrix_from( coefficient=self.coefficient_neighbor, coefficient_zeroth=self.coefficient_zeroth, neighbors=linear_obj.neighbors, - neighbors_sizes=linear_obj.neighbors.sizes, ) diff --git a/autoarray/inversion/regularization/exponential_kernel.py b/autoarray/inversion/regularization/exponential_kernel.py index 73ead006d..cfb03186b 100644 --- a/autoarray/inversion/regularization/exponential_kernel.py +++ b/autoarray/inversion/regularization/exponential_kernel.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,52 +7,44 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray import numba_util - -@numba_util.jit() def exp_cov_matrix_from( scale: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points: jnp.ndarray, # shape (N, 2) +) -> jnp.ndarray: # shape (N, N) """ - Consutruct the source brightness covariance matrix, which is used to determined the regularization - pattern (i.e, how the different source pixels are smoothed). + Construct the source brightness covariance matrix using an exponential kernel: + + cov[i,j] = exp(- d_{ij} / scale) - The covariance matrix includes one non-linear parameters, the scale coefficient, which is used to determine - the typical scale of the regularization pattern. + with a tiny jitter 1e-8 added on the diagonal for numerical stability. Parameters ---------- scale - The typical scale of the regularization pattern . + The length‐scale of the exponential kernel. pixel_points - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). - Something like [[y1,x1], [y2,x2], ...] + Array of shape (N, 2) giving the (y,x) coordinates of each source‐plane pixel. Returns ------- - np.ndarray - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. + jnp.ndarray, shape (N, N) + The exponential covariance matrix. """ + # pairwise differences: shape (N, N, 2) + diff = pixel_points[:, None, :] - pixel_points[None, :, :] - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) + # Euclidean distances: shape (N, N) + d = jnp.linalg.norm(diff, axis=-1) - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j + # exponential kernel + cov = jnp.exp(-d / scale) - covariance_matrix[i, j] += np.exp(-1.0 * d_ij / scale) + # add a small jitter on the diagonal + N = pixel_points.shape[0] + cov = cov + jnp.eye(N) * 1e-8 - return covariance_matrix + return cov class ExponentialKernel(AbstractRegularization): @@ -83,7 +75,7 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0): super().__init__() - def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization weights of this regularization scheme. @@ -102,9 +94,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -119,7 +111,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: """ covariance_matrix = exp_cov_matrix_from( scale=self.scale, - pixel_points=np.array(linear_obj.source_plane_mesh_grid), + pixel_points=linear_obj.source_plane_mesh_grid.array, ) - return self.coefficient * np.linalg.inv(covariance_matrix) + return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/gaussian_kernel.py b/autoarray/inversion/regularization/gaussian_kernel.py index e133a22a2..4b600fba5 100644 --- a/autoarray/inversion/regularization/gaussian_kernel.py +++ b/autoarray/inversion/regularization/gaussian_kernel.py @@ -1,4 +1,5 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np from typing import TYPE_CHECKING @@ -7,52 +8,46 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray import numba_util - -@numba_util.jit() def gauss_cov_matrix_from( scale: float, - pixel_points: np.ndarray, -) -> np.ndarray: + pixel_points: jnp.ndarray, # shape (N, 2) +) -> jnp.ndarray: """ - Consutruct the source brightness covariance matrix, which is used to determined the regularization - pattern (i.e, how the different source pixels are smoothed). + Construct the source‐pixel Gaussian covariance matrix for regularization. + + For N source‐pixels at coordinates (y_i, x_i), we define - the covariance matrix includes one non-linear parameters, the scale coefficient, which is used to - determine the typical scale of the regularization pattern. + C_ij = exp( -||p_i - p_j||^2 / (2 scale^2) ) + + plus a tiny diagonal “jitter” (1e-8) to ensure numerical stability. Parameters ---------- scale - the typical scale of the regularization pattern . + The characteristic length scale of the Gaussian kernel. pixel_points - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). - Something like [[y1,x1], [y2,x2], ...] + Array of shape (N, 2), giving the (y, x) coordinates of each source pixel. Returns ------- - np.ndarray - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. + cov : jnp.ndarray, shape (N, N) + The Gaussian covariance matrix. """ + # Ensure array: + pts = jnp.asarray(pixel_points) # (N, 2) + # Compute squared distances: ||p_i - p_j||^2 + diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2) + d2 = jnp.sum(diffs**2, axis=-1) # (N, N) - pixels = len(pixel_points) - covariance_matrix = np.zeros(shape=(pixels, pixels)) - - for i in range(pixels): - covariance_matrix[i, i] += 1e-8 - for j in range(pixels): - xi = pixel_points[i, 1] - yi = pixel_points[i, 0] - xj = pixel_points[j, 1] - yj = pixel_points[j, 0] - d_ij = np.sqrt( - (xi - xj) ** 2 + (yi - yj) ** 2 - ) # distance between the pixel i and j + # Gaussian kernel + cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N) - covariance_matrix[i, j] += np.exp(-1.0 * d_ij**2 / (2 * scale**2)) + # Add tiny jitter on the diagonal + N = pts.shape[0] + cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8 - return covariance_matrix + return cov class GaussianKernel(AbstractRegularization): @@ -117,7 +112,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: The regularization matrix. """ covariance_matrix = gauss_cov_matrix_from( - scale=self.scale, pixel_points=np.array(linear_obj.source_plane_mesh_grid) + scale=self.scale, pixel_points=linear_obj.source_plane_mesh_grid.array ) - return self.coefficient * np.linalg.inv(covariance_matrix) + return self.coefficient * jnp.linalg.inv(covariance_matrix) diff --git a/autoarray/inversion/regularization/regularization_util.py b/autoarray/inversion/regularization/regularization_util.py index 291e91928..8cedca034 100644 --- a/autoarray/inversion/regularization/regularization_util.py +++ b/autoarray/inversion/regularization/regularization_util.py @@ -2,296 +2,29 @@ from typing import Tuple from autoarray import exc -from autoarray import numba_util - -@numba_util.jit() -def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> np.ndarray: - """ - Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms - to the regularization matrix. - - A complete description of regularization and the `regularization_matrix` can be found in the `Regularization` - class in the module `autoarray.inversion.regularization`. - - Parameters - ---------- - pixels - The number of pixels in the linear object which is to be regularized, being used to in the inversion. - coefficient - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - - regularization_matrix = np.zeros(shape=(pixels, pixels)) - - regularization_coefficient = coefficient**2.0 - - for i in range(pixels): - regularization_matrix[i, i] += regularization_coefficient - - return regularization_matrix - - -@numba_util.jit() -def constant_regularization_matrix_from( - coefficient: float, neighbors: np.ndarray, neighbors_sizes: np.ndarray -) -> np.ndarray: - """ - From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. - - A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization` - class in the module `autoarray.inversion.regularization`. - - Parameters - ---------- - coefficient - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - - parameters = len(neighbors) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - - regularization_coefficient = coefficient**2.0 - - for i in range(parameters): - regularization_matrix[i, i] += 1e-8 - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_coefficient - regularization_matrix[i, neighbor_index] -= regularization_coefficient - - return regularization_matrix - - -@numba_util.jit() -def constant_zeroth_regularization_matrix_from( - coefficient: float, - coefficient_zeroth: float, - neighbors: np.ndarray, - neighbors_sizes: np.ndarray, -) -> np.ndarray: - """ - From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme. - - A complete description of regularizatin and the ``regularization_matrix`` can be found in the ``Regularization`` - class in the module ``autoarray.inversion.regularization``. - - Parameters - ---------- - coefficients - The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the Voronoi grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using Regularization where the effective regularization - coefficient of every source pixel is the same. - """ - - pixels = len(neighbors) - - regularization_matrix = np.zeros(shape=(pixels, pixels)) - - regularization_coefficient = coefficient**2.0 - regularization_coefficient_zeroth = coefficient_zeroth**2.0 - - for i in range(pixels): - regularization_matrix[i, i] += 1e-8 - regularization_matrix[i, i] += regularization_coefficient_zeroth - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_coefficient - regularization_matrix[i, neighbor_index] -= regularization_coefficient - - return regularization_matrix - - -def adaptive_regularization_weights_from( - inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray -) -> np.ndarray: - """ - Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). - - The weights define the effective regularization coefficient of every mesh parameter (typically pixels - of a ``Mapper``). - - They are computed using an estimate of the expected signal in each pixel. - - Two regularization coefficients are used, corresponding to the: - - 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization). - 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization). - - Parameters - ---------- - inner_coefficient - The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction - in the inner regions of a mesh's reconstruction. - outer_coefficient - The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction - in the outer regions of a mesh's reconstruction. - pixel_signals - The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal - and low signal pixelizations. - - Returns - ------- - np.ndarray - The adaptive regularization weights which act as the effective regularization coefficients of - every source pixel. - """ - return ( - inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals) - ) ** 2.0 - - -def brightness_zeroth_regularization_weights_from( - coefficient: float, pixel_signals: np.ndarray -) -> np.ndarray: - """ - Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). - - The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels - of a ``Mapper``). - - They are computed using an estimate of the expected signal in each pixel. - - The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are - the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in - the pixelization). - - Parameters - ---------- - coefficient - The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``), - with the degree applied varying based on the ``pixel_signals``. - pixel_signals - The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal - and low signal pixelizations. - - Returns - ------- - np.ndarray - The zeroth order regularization weights which act as the effective level of zeroth order regularization - applied to every mesh parameter. - """ - return coefficient * (1.0 - pixel_signals) - - -@numba_util.jit() -def weighted_regularization_matrix_from( - regularization_weights: np.ndarray, - neighbors: np.ndarray, - neighbors_sizes: np.ndarray, -) -> np.ndarray: - """ - Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``). - - This matrix is computed using the regularization weights of every mesh pixel, which are computed using the - function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of - every mesh pixel. - - The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate - neighbor calculation of the corresponding ``Mapper`` class. - - Parameters - ---------- - regularization_weights - The regularization weight of each pixel, adaptively governing the degree of gradient regularization - applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). - neighbors - An array of length (total_pixels) which provides the index of all neighbors of every pixel in - the mesh grid (entries of -1 correspond to no neighbor). - neighbors_sizes - An array of length (total_pixels) which gives the number of neighbors of every pixel in the - Voronoi grid. - - Returns - ------- - np.ndarray - The regularization matrix computed using an adaptive regularization scheme where the effective regularization - coefficient of every source pixel is different. - """ - - parameters = len(regularization_weights) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - - regularization_weight = regularization_weights**2.0 - - for i in range(parameters): - regularization_matrix[i, i] += 1e-8 - for j in range(neighbors_sizes[i]): - neighbor_index = neighbors[i, j] - regularization_matrix[i, i] += regularization_weight[neighbor_index] - regularization_matrix[ - neighbor_index, neighbor_index - ] += regularization_weight[neighbor_index] - regularization_matrix[i, neighbor_index] -= regularization_weight[ - neighbor_index - ] - regularization_matrix[neighbor_index, i] -= regularization_weight[ - neighbor_index - ] - - return regularization_matrix - - -@numba_util.jit() -def brightness_zeroth_regularization_matrix_from( - regularization_weights: np.ndarray, -) -> np.ndarray: - """ - Returns the regularization matrix of the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``). - - Parameters - ---------- - regularization_weights - The regularization weight of each pixel, adaptively governing the degree of zeroth order regularization - applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``). - - Returns - ------- - np.ndarray - The regularization matrix computed using an adaptive regularization scheme where the effective regularization - coefficient of every source pixel is different. - """ - - parameters = len(regularization_weights) - - regularization_matrix = np.zeros(shape=(parameters, parameters)) - - regularization_weight = regularization_weights**2.0 - - for i in range(parameters): - regularization_matrix[i, i] += regularization_weight[i] - - return regularization_matrix +from autoarray.inversion.regularization.adaptive_brightness import ( + adaptive_regularization_weights_from, +) +from autoarray.inversion.regularization.adaptive_brightness import ( + weighted_regularization_matrix_from, +) +from autoarray.inversion.regularization.brightness_zeroth import ( + brightness_zeroth_regularization_matrix_from, +) +from autoarray.inversion.regularization.brightness_zeroth import ( + brightness_zeroth_regularization_weights_from, +) +from autoarray.inversion.regularization.constant import ( + constant_regularization_matrix_from, +) +from autoarray.inversion.regularization.constant_zeroth import ( + constant_zeroth_regularization_matrix_from, +) +from autoarray.inversion.regularization.exponential_kernel import exp_cov_matrix_from +from autoarray.inversion.regularization.gaussian_kernel import gauss_cov_matrix_from +from autoarray.inversion.regularization.matern_kernel import matern_kernel +from autoarray.inversion.regularization.zeroth import zeroth_regularization_matrix_from def reg_split_from( @@ -357,43 +90,60 @@ def reg_split_from( return splitted_mappings, splitted_sizes, splitted_weights -@numba_util.jit() def pixel_splitted_regularization_matrix_from( regularization_weights: np.ndarray, splitted_mappings: np.ndarray, splitted_sizes: np.ndarray, splitted_weights: np.ndarray, ) -> np.ndarray: - # I'm not sure what is the best way to add surface brightness weight to the regularization scheme here. - # Currently, I simply mulitply the i-th weight to the i-th source pixel, but there should be different ways. - # Need to keep an eye here. + """ + Returns the regularization matrix for the adaptive split-pixel regularization scheme. - parameters = int(len(splitted_mappings) / 4) + This scheme splits each source pixel into a cross of four regularization points and interpolates + to those points to smooth the inversion solution. It is designed to mitigate stochasticity in + the regularization that can arise when the number of neighboring pixels varies across a + mesh (e.g., in a Voronoi tessellation). - regularization_matrix = np.zeros(shape=(parameters, parameters)) + A visual description and further details are provided in the appendix of He et al. (2024): + https://arxiv.org/abs/2403.16253 + Parameters + ---------- + regularization_weights + The regularization weight per pixel, adaptively controlling the strength of regularization + applied to each inversion parameter. + splitted_mappings + The image pixel index mappings for each of the four regularization points into which each source pixel is split. + splitted_sizes + The number of neighbors or interpolation terms associated with each regularization point. + splitted_weights + The interpolation weights corresponding to each mapping entry, used to apply regularization + between split points. + + Returns + ------- + The regularization matrix of shape [source_pixels, source_pixels]. + """ + + parameters = splitted_mappings.shape[0] // 4 + regularization_matrix = np.zeros((parameters, parameters)) regularization_weight = regularization_weights**2.0 - for i in range(parameters): - regularization_matrix[i, i] += 2e-8 + # Add small constant to diagonal + np.fill_diagonal(regularization_matrix, 2e-8) + # Compute regularization contributions + for i in range(parameters): + reg_w = regularization_weight[i] for j in range(4): k = i * 4 + j - size = splitted_sizes[k] - mapping = splitted_mappings[k] - weight = splitted_weights[k] + mapping = splitted_mappings[k][:size] + weight = splitted_weights[k][:size] - for l in range(size): - for m in range(size - l): - regularization_matrix[mapping[l], mapping[l + m]] += ( - weight[l] * weight[l + m] * regularization_weight[i] - ) - regularization_matrix[mapping[l + m], mapping[l]] += ( - weight[l] * weight[l + m] * regularization_weight[i] - ) - - for i in range(parameters): - regularization_matrix[i, i] /= 2.0 + # Outer product of weights and symmetric updates + outer = np.outer(weight, weight) * reg_w + rows, cols = np.meshgrid(mapping, mapping, indexing="ij") + regularization_matrix[rows, cols] += outer return regularization_matrix diff --git a/autoarray/inversion/regularization/zeroth.py b/autoarray/inversion/regularization/zeroth.py index e30b1222e..04f61ad0e 100644 --- a/autoarray/inversion/regularization/zeroth.py +++ b/autoarray/inversion/regularization/zeroth.py @@ -1,5 +1,5 @@ from __future__ import annotations -import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -7,7 +7,34 @@ from autoarray.inversion.regularization.abstract import AbstractRegularization -from autoarray.inversion.regularization import regularization_util + +def zeroth_regularization_matrix_from(coefficient: float, pixels: int) -> jnp.ndarray: + """ + Apply zeroth order regularization which penalizes every pixel's deviation from zero by addiing non-zero terms + to the regularization matrix. + + A complete description of regularization and the `regularization_matrix` can be found in the `Regularization` + class in the module `autoarray.inversion.regularization`. + + Parameters + ---------- + pixels + The number of pixels in the linear object which is to be regularized, being used to in the inversion. + coefficient + The regularization coefficients which controls the degree of smoothing of the inversion reconstruction. + + Returns + ------- + np.ndarray + The regularization matrix computed using Regularization where the effective regularization + coefficient of every source pixel is the same. + """ + + reg_coeff = coefficient**2.0 + + # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T + + return jnp.eye(pixels) * reg_coeff class Zeroth(AbstractRegularization): @@ -60,9 +87,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization weights. """ - return self.coefficient * np.ones(linear_obj.params) + return self.coefficient * jnp.ones(linear_obj.params) - def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: + def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray: """ Returns the regularization matrix with shape [pixels, pixels]. @@ -75,6 +102,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: ------- The regularization matrix. """ - return regularization_util.zeroth_regularization_matrix_from( + return zeroth_regularization_matrix_from( coefficient=self.coefficient, pixels=linear_obj.params ) diff --git a/autoarray/mask/derive/grid_2d.py b/autoarray/mask/derive/grid_2d.py index 225c9a43b..702195a74 100644 --- a/autoarray/mask/derive/grid_2d.py +++ b/autoarray/mask/derive/grid_2d.py @@ -158,12 +158,12 @@ def unmasked(self) -> Grid2D: """ from autoarray.structures.grids.uniform_2d import Grid2D - grid_1d = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=np.array(self.mask), + grid_2d = grid_2d_util.grid_2d_slim_via_mask_from( + mask_2d=self.mask, pixel_scales=self.mask.pixel_scales, origin=self.mask.origin, ) - return Grid2D(values=grid_1d, mask=self.mask) + return Grid2D(values=grid_2d, mask=self.mask) @property def edge(self) -> Grid2D: diff --git a/autoarray/mask/derive/indexes_2d.py b/autoarray/mask/derive/indexes_2d.py index 0d0a36b26..de45fd4bb 100644 --- a/autoarray/mask/derive/indexes_2d.py +++ b/autoarray/mask/derive/indexes_2d.py @@ -2,6 +2,8 @@ import logging import numpy as np +from autoconf import cached_property + from autoarray.numpy_wrapper import register_pytree_node_class from typing import TYPE_CHECKING @@ -110,7 +112,7 @@ def unmasked_slim(self) -> np.ndarray: print(derive_indexes_2d.unmasked_slim) """ return mask_2d_util.mask_slim_indexes_from( - mask_2d=np.array(self.mask), return_masked_indexes=False + mask_2d=self.mask, return_masked_indexes=False ).astype("int") @property @@ -152,7 +154,7 @@ def masked_slim(self) -> np.ndarray: print(derive_indexes_2d.masked_slim) """ return mask_2d_util.mask_slim_indexes_from( - mask_2d=np.array(self.mask), return_masked_indexes=True + mask_2d=self.mask, return_masked_indexes=True ).astype("int") @property @@ -200,7 +202,7 @@ def edge_slim(self) -> np.ndarray: print(derive_indexes_2d.edge_slim) """ return mask_2d_util.edge_1d_indexes_from( - mask_2d=np.array(self.mask).astype("bool") + mask_2d=self.mask.astype("bool") ).astype("int") @property @@ -302,7 +304,7 @@ def border_slim(self) -> np.ndarray: print(derive_indexes_2d.border_slim) """ return mask_2d_util.border_slim_indexes_from( - mask_2d=np.array(self.mask).astype("bool") + mask_2d=self.mask.astype("bool") ).astype("int") @property @@ -363,7 +365,7 @@ def border_native(self) -> np.ndarray: """ return self.native_for_slim[self.border_slim].astype("int") - @property + @cached_property def native_for_slim(self) -> np.ndarray: """ Derives a 1D ``ndarray`` which maps every 1D ``slim`` index of the ``Mask2D`` to its @@ -407,5 +409,5 @@ def native_for_slim(self) -> np.ndarray: print(derive_indexes_2d.native_for_slim) """ return mask_2d_util.native_index_for_slim_index_2d_from( - mask_2d=np.array(self.mask), + mask_2d=self.mask, ).astype("int") diff --git a/autoarray/mask/derive/mask_2d.py b/autoarray/mask/derive/mask_2d.py index 9332e273d..19e005362 100644 --- a/autoarray/mask/derive/mask_2d.py +++ b/autoarray/mask/derive/mask_2d.py @@ -204,7 +204,7 @@ def blurring_from(self, kernel_shape_native: Tuple[int, int]) -> Mask2D: raise exc.MaskException("psf_size of exterior region must be odd") blurring_mask = mask_2d_util.blurring_mask_2d_from( - mask_2d=np.array(self.mask), + mask_2d=self.mask, kernel_shape_native=kernel_shape_native, ) @@ -324,9 +324,9 @@ def edge_buffed(self) -> Mask2D: """ from autoarray.mask.mask_2d import Mask2D - edge_buffed_mask = mask_2d_util.buffed_mask_2d_from( - mask_2d=np.array(self.mask) - ).astype("bool") + edge_buffed_mask = mask_2d_util.buffed_mask_2d_from(mask_2d=self.mask).astype( + "bool" + ) return Mask2D( mask=edge_buffed_mask, diff --git a/autoarray/mask/derive/zoom_2d.py b/autoarray/mask/derive/zoom_2d.py new file mode 100644 index 000000000..af3db8fc5 --- /dev/null +++ b/autoarray/mask/derive/zoom_2d.py @@ -0,0 +1,325 @@ +from __future__ import annotations +import numpy as np +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.structures.arrays.rgb import Array2DRGB + +from autoarray.structures.arrays import array_2d_util +from autoarray.structures.grids import grid_2d_util + + +class Zoom2D: + + def __init__(self, mask: Union[np.ndarray, List]): + """ + Derives a zoomed in `Mask2D` object from a `Mask2D` object, which is typically used to visualize 2D arrays + zoomed in to only the unmasked region an analysis is performed on. + + A `Mask2D` masks values which are associated with a uniform 2D rectangular grid of pixels, where unmasked + entries (which are `False`) are used in subsequent calculations and masked values (which are `True`) are + omitted (for a full description see the :meth:`Mask2D` class API + documentation `). + + The `Zoom2D` object calculations many different zoomed in qu + + Parameters + ---------- + mask + The `Mask2D` from which zoomed in `Mask2D` objects are derived. + + Examples + -------- + + .. code-block:: python + + import autoarray as aa + + mask_2d = aa.Mask2D( + mask=[ + [True, True, True, True, True], + [True, False, False, False, True], + [True, False, False, False, True], + [True, False, False, False, True], + [True, True, True, True, True], + ], + pixel_scales=1.0, + ) + + zoom_2d = aa.Zoom2D(mask=mask_2d) + + print(zoom_2d.centre) + """ + self.mask = mask + + @property + def centre(self) -> Tuple[float, float]: + """ + Returns the centre of the zoomed in region, which is the average of the maximum and minimum y and x pixel values + of the unmasked region. + + The y and x pixel values are the pixel coordinates of the unmasked region, which are derived from the + `Mask2D` object. The pixel coordinates are in the same units as the pixel scales of the `Mask2D` object. + + Returns + ------- + The centre of the zoomed in region. + """ + from autoarray.structures.grids.uniform_2d import Grid2D + + grid = grid_2d_util.grid_2d_slim_via_mask_from( + mask_2d=np.array(self.mask), + pixel_scales=self.mask.pixel_scales, + origin=self.mask.origin, + ) + + grid = Grid2D(values=grid, mask=self.mask) + + extraction_grid_1d = self.mask.geometry.grid_pixels_2d_from(grid_scaled_2d=grid) + y_pixels_max = np.max(extraction_grid_1d[:, 0]) + y_pixels_min = np.min(extraction_grid_1d[:, 0]) + x_pixels_max = np.max(extraction_grid_1d[:, 1]) + x_pixels_min = np.min(extraction_grid_1d[:, 1]) + + return ( + ((y_pixels_max + y_pixels_min - 1.0) / 2.0), + ((x_pixels_max + x_pixels_min - 1.0) / 2.0), + ) + + @property + def offset_pixels(self) -> Tuple[float, float]: + """ + Returns the offset of the centred of the zoomed in region from the centre of the `Mask2D` object in pixel + units. + + This is computed by subtracting the pixel coordinates of the `Mask2D` object from the pixel coordinates of + the zoomed in region. + + Returns + ------- + The offset of the zoomed in region from the centre of the `Mask2D` object in pixel units. + """ + if self.mask.pixel_scales is None: + return self.mask.geometry.central_pixel_coordinates + + return ( + self.centre[0] - self.mask.geometry.central_pixel_coordinates[0], + self.centre[1] - self.mask.geometry.central_pixel_coordinates[1], + ) + + @property + def offset_scaled(self) -> Tuple[float, float]: + """ + Returns the offset of the centred of the zoomed in region from the centre of the `Mask2D` object in scaled + units. + + This is computed by subtracting the pixel coordinates of the `Mask2D` object from the pixel coordinates of + the zoomed in region. + + Returns + ------- + The offset of the zoomed in region from the centre of the `Mask2D` object in scaled units. + """ + return ( + -self.mask.pixel_scales[0] * self.offset_pixels[0], + self.mask.pixel_scales[1] * self.offset_pixels[1], + ) + + @property + def region(self) -> List[int]: + """ + The zoomed region corresponding to the square encompassing all unmasked values. + + This is used to zoom in on the region of an image that is used in an analysis for visualization. + + This zoomed extraction region is a square, even if the mask is rectangular, so that extraction regions are + always squares which is important for ensuring visualization does not have aspect ratio issues. + """ + + where = np.array(np.where(np.invert(self.mask.astype("bool")))) + y0, x0 = np.amin(where, axis=1) + y1, x1 = np.amax(where, axis=1) + + # Have to convert mask to bool for invert function to work. + + ylength = y1 - y0 + xlength = x1 - x0 + + if ylength > xlength: + length_difference = ylength - xlength + x1 += int(length_difference / 2.0) + x0 -= int(length_difference / 2.0) + elif xlength > ylength: + length_difference = xlength - ylength + y1 += int(length_difference / 2.0) + y0 -= int(length_difference / 2.0) + + return [y0, y1 + 1, x0, x1 + 1] + + @property + def shape_native(self) -> Tuple[int, int]: + """ + The shape of the zoomed in region in pixels. + + This is computed by subtracting the minimum and maximum y and x pixel values of the unmasked region. + + Returns + ------- + The shape of the zoomed in region in pixels. + """ + region = self.region + return (region[1] - region[0], region[3] - region[2]) + + def extent_from(self, buffer: int = 1) -> np.ndarray: + """ + For an extracted zoomed array computed from the method *zoomed_around_mask* compute its extent in scaled + coordinates. + + The extent of the grid in scaled units returned as an ``ndarray`` of the form [x_min, x_max, y_min, y_max]. + + This is used visualize zoomed and extracted arrays via the imshow() method. + + Parameters + ---------- + buffer + The number pixels around the extracted array used as a buffer. + """ + from autoarray.mask.mask_2d import Mask2D + + extracted_array_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(self.mask), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + mask = Mask2D.all_false( + shape_native=extracted_array_2d.shape, + pixel_scales=self.mask.pixel_scales, + origin=self.centre, + ) + + return mask.geometry.extent + + def mask_2d_from(self, buffer: int = 1) -> "Mask2D": + """ + Extract the 2D region of a mask corresponding to the rectangle encompassing all unmasked values. + + This is used to extract and visualize only the region of an image that is used in an analysis. + + Parameters + ---------- + buffer + The number pixels around the extracted array used as a buffer. + """ + from autoarray.mask.mask_2d import Mask2D + + extracted_mask_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(self.mask), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + return Mask2D( + mask=extracted_mask_2d, + pixel_scales=self.mask.pixel_scales, + origin=self.mask.origin, + ) + + def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D: + """ + Extract the 2D region of an array corresponding to the rectangle encompassing all unmasked values. + + This is used to extract and visualize only the region of an image that is used in an analysis. + + Parameters + ---------- + buffer + The number pixels around the extracted array used as a buffer. + """ + from autoarray.structures.arrays.uniform_2d import Array2D + from autoarray.structures.arrays.rgb import Array2DRGB + from autoarray.mask.mask_2d import Mask2D + + if isinstance(array, Array2DRGB): + return self.array_2d_rgb_from(array=array, buffer=buffer) + + extracted_array_2d = array_2d_util.extracted_array_2d_from( + array_2d=array.native.array, + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + extracted_mask_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(self.mask), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + mask = Mask2D( + mask=extracted_mask_2d, + pixel_scales=array.pixel_scales, + origin=array.mask.mask_centre, + ) + + arr = array_2d_util.convert_array_2d(array_2d=extracted_array_2d, mask_2d=mask) + + return Array2D(values=arr, mask=mask, header=array.header).native + + def array_2d_rgb_from(self, array: Array2DRGB, buffer: int = 1) -> Array2DRGB: + """ + Extract the 2D region of an RGB array corresponding to the rectangle encompassing all unmasked values. + + This works the same as the `array_2d_from` method, but for RGB arrays, meaning that it iterates over the three + channels of the RGB array and extracts the region for each channel separately. + + This is used to extract and visualize only the region of an RGB image that is used in an analysis. + + Parameters + ---------- + buffer + The number pixels around the extracted array used as a buffer. + """ + from autoarray.structures.arrays.rgb import Array2DRGB + from autoarray.mask.mask_2d import Mask2D + + for i in range(3): + + extracted_array_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(array.native[:, :, i]), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + if i == 0: + array_2d_rgb = np.zeros( + (extracted_array_2d.shape[0], extracted_array_2d.shape[1], 3) + ) + + array_2d_rgb[:, :, i] = extracted_array_2d + + extracted_mask_2d = array_2d_util.extracted_array_2d_from( + array_2d=np.array(self.mask), + y0=self.region[0] - buffer, + y1=self.region[1] + buffer, + x0=self.region[2] - buffer, + x1=self.region[3] + buffer, + ) + + mask = Mask2D( + mask=extracted_mask_2d, + pixel_scales=array.pixel_scales, + origin=array.mask.mask_centre, + ) + + return Array2DRGB(values=array_2d_rgb.astype("int"), mask=mask) diff --git a/autoarray/mask/mask_1d.py b/autoarray/mask/mask_1d.py index 8c36d8866..55d8d1f8c 100644 --- a/autoarray/mask/mask_1d.py +++ b/autoarray/mask/mask_1d.py @@ -59,7 +59,7 @@ def __init__( mask = np.asarray(mask).astype("bool") if invert: - mask = np.invert(mask) + mask = ~mask if type(pixel_scales) is float: pixel_scales = (pixel_scales,) @@ -153,7 +153,9 @@ def from_fits( """ return cls( - array_1d_util.numpy_array_1d_via_fits_from(file_path=file_path, hdu=hdu), + mask=array_1d_util.numpy_array_1d_via_fits_from( + file_path=file_path, hdu=hdu + ), pixel_scales=pixel_scales, origin=origin, ) diff --git a/autoarray/mask/mask_1d_util.py b/autoarray/mask/mask_1d_util.py index 3d9943c19..add58c823 100644 --- a/autoarray/mask/mask_1d_util.py +++ b/autoarray/mask/mask_1d_util.py @@ -1,43 +1,7 @@ +import jax.numpy as jnp import numpy as np -from autoarray import numba_util - -@numba_util.jit() -def total_pixels_1d_from(mask_1d: np.ndarray) -> int: - """ - Returns the total number of unmasked pixels in a mask. - - Parameters - ---------- - mask_1d - A 2D array of bools, where `False` values are unmasked and included when counting pixels. - - Returns - ------- - int - The total number of pixels that are unmasked. - - Examples - -------- - - mask = np.array([[True, False, True], - [False, False, False] - [True, False, True]]) - - total_regular_pixels = total_regular_pixels_from(mask=mask) - """ - - total_regular_pixels = 0 - - for x in range(mask_1d.shape[0]): - if not mask_1d[x]: - total_regular_pixels += 1 - - return total_regular_pixels - - -@numba_util.jit() def native_index_for_slim_index_1d_from( mask_1d: np.ndarray, ) -> np.ndarray: @@ -70,14 +34,6 @@ def native_index_for_slim_index_1d_from( """ - total_pixels = total_pixels_1d_from(mask_1d=mask_1d) - native_index_for_slim_index_1d = np.zeros(shape=total_pixels) - - slim_index = 0 - - for x in range(mask_1d.shape[0]): - if not mask_1d[x]: - native_index_for_slim_index_1d[slim_index] = x - slim_index += 1 - - return native_index_for_slim_index_1d + if isinstance(mask_1d, np.ndarray): + return np.flatnonzero(~mask_1d) + return jnp.flatnonzero(~mask_1d) diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 05bd77852..0f4fe30f9 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -21,6 +21,7 @@ from autoarray.mask.derive.mask_2d import DeriveMask2D from autoarray.mask.derive.grid_2d import DeriveGrid2D from autoarray.mask.derive.indexes_2d import DeriveIndexes2D +from autoarray.mask.derive.zoom_2d import Zoom2D from autoarray.structures.arrays import array_2d_util from autoarray.geometry import geometry_util @@ -200,11 +201,8 @@ def __init__( if type(mask) is list: mask = np.asarray(mask).astype("bool") - if not isinstance(mask, np.ndarray): - mask = mask._array - if invert: - mask = np.invert(mask) + mask = ~mask pixel_scales = geometry_util.convert_pixel_scales_2d(pixel_scales=pixel_scales) @@ -217,6 +215,10 @@ def __init__( pixel_scales=pixel_scales, ) + @cached_property + def native_for_slim(self): + return self.derive_indexes.native_for_slim + __no_flatten__ = ("derive_indexes",) def __array_finalize__(self, obj): @@ -241,7 +243,7 @@ def geometry(self) -> Geometry2D: origin=self.origin, ) - @property + @cached_property def derive_indexes(self) -> DeriveIndexes2D: return DeriveIndexes2D(mask=self) @@ -253,6 +255,10 @@ def derive_mask(self) -> DeriveMask2D: def derive_grid(self) -> DeriveGrid2D: return DeriveGrid2D(mask=self) + @property + def zoom(self) -> Zoom2D: + return Zoom2D(mask=self) + @classmethod def all_false( cls, @@ -676,7 +682,7 @@ def header_dict(self) -> Dict: @property def mask_centre(self) -> Tuple[float, float]: grid = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=np.array(self), + mask_2d=self, pixel_scales=self.pixel_scales, origin=self.origin, ) @@ -693,7 +699,7 @@ def shape_native_masked_pixels(self) -> Tuple[int, int]: and 12 False entries going horizontally in the central regions of the mask, then shape_masked_pixels=(15,12). """ - where = np.array(np.where(np.invert(self.astype("bool")))) + where = np.where(np.invert(self.astype("bool"))) y0, x0 = np.amin(where, axis=1) y1, x1 = np.amax(where, axis=1) @@ -743,7 +749,7 @@ def rescaled_from(self, rescale_factor) -> Mask2D: from autoarray.mask.mask_2d import Mask2D rescaled_mask = mask_2d_util.rescaled_mask_2d_from( - mask_2d=np.array(self), + mask_2d=self.array, rescale_factor=rescale_factor, ) @@ -802,7 +808,7 @@ def resized_from(self, new_shape, pad_value: int = 0.0) -> Mask2D: """ resized_mask = array_2d_util.resized_array_2d_from( - array_2d=np.array(self._array), + array_2d=self.array, resized_shape=new_shape, pad_value=pad_value, ).astype("bool") @@ -813,95 +819,6 @@ def resized_from(self, new_shape, pad_value: int = 0.0) -> Mask2D: origin=self.origin, ) - @property - def zoom_centre(self) -> Tuple[float, float]: - from autoarray.structures.grids.uniform_2d import Grid2D - - grid = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=np.array(self), - pixel_scales=self.pixel_scales, - origin=self.origin, - ) - - grid = Grid2D(values=grid, mask=self) - - extraction_grid_1d = self.geometry.grid_pixels_2d_from(grid_scaled_2d=grid) - y_pixels_max = np.max(extraction_grid_1d[:, 0]) - y_pixels_min = np.min(extraction_grid_1d[:, 0]) - x_pixels_max = np.max(extraction_grid_1d[:, 1]) - x_pixels_min = np.min(extraction_grid_1d[:, 1]) - - return ( - ((y_pixels_max + y_pixels_min - 1.0) / 2.0), - ((x_pixels_max + x_pixels_min - 1.0) / 2.0), - ) - - @property - def zoom_offset_pixels(self) -> Tuple[float, float]: - if self.pixel_scales is None: - return self.geometry.central_pixel_coordinates - - return ( - self.zoom_centre[0] - self.geometry.central_pixel_coordinates[0], - self.zoom_centre[1] - self.geometry.central_pixel_coordinates[1], - ) - - @property - def zoom_offset_scaled(self) -> Tuple[float, float]: - return ( - -self.pixel_scales[0] * self.zoom_offset_pixels[0], - self.pixel_scales[1] * self.zoom_offset_pixels[1], - ) - - @property - def zoom_region(self) -> List[int]: - """ - The zoomed rectangular region corresponding to the square encompassing all unmasked values. This zoomed - extraction region is a squuare, even if the mask is rectangular. - - This is used to zoom in on the region of an image that is used in an analysis for visualization. - """ - - where = np.array(np.where(np.invert(self.astype("bool")))) - y0, x0 = np.amin(where, axis=1) - y1, x1 = np.amax(where, axis=1) - - # Have to convert mask to bool for invert function to work. - - ylength = y1 - y0 - xlength = x1 - x0 - - if ylength > xlength: - length_difference = ylength - xlength - x1 += int(length_difference / 2.0) - x0 -= int(length_difference / 2.0) - elif xlength > ylength: - length_difference = xlength - ylength - y1 += int(length_difference / 2.0) - y0 -= int(length_difference / 2.0) - - return [y0, y1 + 1, x0, x1 + 1] - - @property - def zoom_shape_native(self) -> Tuple[int, int]: - region = self.zoom_region - return (region[1] - region[0], region[3] - region[2]) - - @property - def zoom_mask_unmasked(self) -> "Mask2D": - """ - The scaled-grid of (y,x) coordinates of every pixel. - - This is defined from the top-left corner, such that the first pixel at location [0, 0] will have a negative x - value y value in scaled units. - """ - - return Mask2D.all_false( - shape_native=self.zoom_shape_native, - pixel_scales=self.pixel_scales, - origin=self.zoom_offset_scaled, - ) - @property def is_circular(self) -> bool: """ diff --git a/autoarray/mask/mask_2d_util.py b/autoarray/mask/mask_2d_util.py index 10a40b473..60eb0a25a 100644 --- a/autoarray/mask/mask_2d_util.py +++ b/autoarray/mask/mask_2d_util.py @@ -1,6 +1,5 @@ import numpy as np import jax.numpy as jnp -from scipy.ndimage import convolve from typing import Tuple import warnings @@ -51,7 +50,14 @@ def native_index_for_slim_index_2d_from( native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d) """ - return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T + + if isinstance(mask_2d, jnp.ndarray): + # JAX branch (assume jnp.ndarray) + rows, cols = jnp.where(~mask_2d.astype(bool)) + return jnp.stack([rows, cols], axis=1) + + rows, cols = np.where(~mask_2d.astype(bool)) + return np.stack([rows, cols], axis=1) def mask_2d_centres_from( @@ -497,6 +503,7 @@ def blurring_mask_2d_from( blurring_mask = blurring_from(mask=mask) """ + from scipy.ndimage import convolve # Get the distance from False values to edges y_distance, x_distance = min_false_distance_to_edge(mask_2d) diff --git a/autoarray/mock.py b/autoarray/mock.py index 2261ba5e2..d6dbe40b0 100644 --- a/autoarray/mock.py +++ b/autoarray/mock.py @@ -15,8 +15,8 @@ from autoarray.fit.mock.mock_fit_imaging import MockFitImaging from autoarray.fit.mock.mock_fit_interferometer import MockFitInterferometer from autoarray.mask.mock.mock_mask import MockMask +from autoarray.operators.mock.mock_psf import MockPSF from autoarray.structures.mock.mock_grid import MockGrid2DMesh from autoarray.structures.mock.mock_grid import MockMeshGrid -from autoarray.structures.mock.mock_decorators import MockGridRadialMinimum from autoarray.structures.mock.mock_decorators import MockGrid1DLikeObj from autoarray.structures.mock.mock_decorators import MockGrid2DLikeObj diff --git a/autoarray/numba_util.py b/autoarray/numba_util.py index 9e0298b73..475777f31 100644 --- a/autoarray/numba_util.py +++ b/autoarray/numba_util.py @@ -69,86 +69,3 @@ def wrapper(func): return func return wrapper - - -def profile_func(func: Callable): - """ - Time every function called in a class and averages over repeated calls for profiling likelihood functions. - - The timings are stored in the variable `_run_time_dict` of the class(s) from which each function is called, - which are collected at the end of the profiling process via recursion. - - Parameters - ---------- - func : (obj, grid, *args, **kwargs) -> Object - A function which is used in the likelihood function.. - - Returns - ------- - A function that times the function being called. - """ - - @wraps(func) - def wrapper(obj, *args, **kwargs): - """ - Time a function and average over repeated calls for profiling an `Analysis` class's likelihood function. The - time is stored in a `run_time_dict` attribute. - - It is possible for multiple functions with the `profile_func` decorator to be called. In this circumstance, - we risk repeated profiling of the same functionality in these nested functions. Thus, before added - the time to the run_time_dict, the keys of the dictionary are iterated over in reverse, subtracting off the - times of nested functions (which will already have been added to the profiling dict). - - Returns - ------- - The result of the function being timed. - """ - - if not hasattr(obj, "run_time_dict"): - return func(obj, *args, **kwargs) - - if obj.run_time_dict is None: - return func(obj, *args, **kwargs) - - repeats = conf.instance["general"]["profiling"]["repeats"] - - last_key_before_call = ( - list(obj.run_time_dict)[-1] if obj.run_time_dict else None - ) - - start = time.time() - for i in range(repeats): - result = func(obj, *args, **kwargs) - - time_func = (time.time() - start) / repeats - - last_key_after_call = list(obj.run_time_dict)[-1] if obj.run_time_dict else None - - profile_call_max = 5 - - for i in range(profile_call_max): - key_func = f"{func.__name__}_{i}" - - if key_func not in obj.run_time_dict: - if last_key_before_call == last_key_after_call: - obj.run_time_dict[key_func] = time_func - else: - for key, value in reversed(list(obj.run_time_dict.items())): - if last_key_before_call == key: - obj.run_time_dict[key_func] = time_func - break - - time_func -= obj.run_time_dict[key] - - break - - if i == 5: - raise exc.ProfilingException( - f"Attempt to make profiling dict failed, because a function has been" - f"called more than {profile_call_max} times, exceed the number of times" - f"a profiled function may be called" - ) - - return result - - return wrapper diff --git a/autoarray/numpy_wrapper.py b/autoarray/numpy_wrapper.py index 3f534d995..9a23ba5ed 100644 --- a/autoarray/numpy_wrapper.py +++ b/autoarray/numpy_wrapper.py @@ -2,7 +2,9 @@ from os import environ -use_jax = environ.get("USE_JAX", "0") == "1" +from autoconf import conf + +use_jax = conf.instance["general"]["jax"]["use_jax"] if use_jax: try: diff --git a/autoarray/operators/contour.py b/autoarray/operators/contour.py index 2de247d3c..693b14378 100644 --- a/autoarray/operators/contour.py +++ b/autoarray/operators/contour.py @@ -1,9 +1,6 @@ from __future__ import annotations import numpy as np import jax.numpy as jnp -from skimage import measure -from scipy.spatial import ConvexHull -from scipy.spatial import QhullError from autoarray.structures.grids.irregular_2d import Grid2DIrregular @@ -55,9 +52,15 @@ def contour_array(self): @property def contour_list(self): # make sure to use base numpy to convert JAX array back to a normal array - contour_indices_list = measure.find_contours( - np.array(self.contour_array), 0 - ) + + from skimage import measure + + if isinstance(self.contour_array, jnp.ndarray): + contour_array = np.array(self.contour_array) + else: + contour_array = np.array(self.contour_array.array) + + contour_indices_list = measure.find_contours(contour_array, 0) if len(contour_indices_list) == 0: return [] @@ -82,14 +85,22 @@ def contour_list(self): def hull( self, ): + + from scipy.spatial import ConvexHull + from scipy.spatial import QhullError + if self.grid.shape[0] < 3: return None # cast JAX arrays to base numpy arrays grid_convex = np.zeros((len(self.grid), 2)) - grid_convex[:, 0] = np.array(self.grid[:, 1]) - grid_convex[:, 1] = np.array(self.grid[:, 0]) + try: + grid_convex[:, 0] = np.array(self.grid.array[:, 1]) + grid_convex[:, 1] = np.array(self.grid.array[:, 0]) + except AttributeError: + grid_convex[:, 0] = np.array(self.grid[:, 1]) + grid_convex[:, 1] = np.array(self.grid[:, 0]) try: hull = ConvexHull(grid_convex) @@ -101,9 +112,6 @@ def hull( hull_x = grid_convex[hull_vertices, 0] hull_y = grid_convex[hull_vertices, 1] - grid_hull = jnp.zeros((len(hull_vertices), 2)) - - grid_hull[:, 1] = hull_x - grid_hull[:, 0] = hull_y + grid_hull = jnp.stack((hull_y, hull_x), axis=-1) return grid_hull diff --git a/autoarray/operators/mock/mock_convolver.py b/autoarray/operators/mock/mock_psf.py similarity index 72% rename from autoarray/operators/mock/mock_convolver.py rename to autoarray/operators/mock/mock_psf.py index 7dc5dfcbb..e89d2b732 100644 --- a/autoarray/operators/mock/mock_convolver.py +++ b/autoarray/operators/mock/mock_psf.py @@ -2,5 +2,5 @@ class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix - def convolve_mapping_matrix(self, mapping_matrix): + def convolve_mapping_matrix(self, mapping_matrix, mask): return self.operated_mapping_matrix diff --git a/autoarray/operators/over_sampling/over_sample_util.py b/autoarray/operators/over_sampling/over_sample_util.py index 084d0bd0d..75465bd3f 100644 --- a/autoarray/operators/over_sampling/over_sample_util.py +++ b/autoarray/operators/over_sampling/over_sample_util.py @@ -1,6 +1,6 @@ from __future__ import annotations import numpy as np -from typing import TYPE_CHECKING, Union, List, Tuple +from typing import TYPE_CHECKING, Union from typing import List, Tuple from autoarray.structures.arrays.uniform_2d import Array2D @@ -8,11 +8,8 @@ if TYPE_CHECKING: from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.geometry import geometry_util from autoarray.mask.mask_2d import Mask2D -from autoarray import numba_util -from autoarray.mask import mask_2d_util from autoarray import type as ty @@ -45,12 +42,11 @@ def over_sample_size_convert_to_array_2d_from( if isinstance(over_sample_size, int): over_sample_size = np.full( fill_value=over_sample_size, shape=mask.pixels_in_mask - ).astype("int") + ) - return Array2D(values=over_sample_size, mask=mask) + return Array2D(values=np.array(over_sample_size).astype("int"), mask=mask) -@numba_util.jit() def total_sub_pixels_2d_from(sub_size: np.ndarray) -> int: """ Returns the total number of sub-pixels in unmasked pixels in a mask. @@ -79,90 +75,6 @@ def total_sub_pixels_2d_from(sub_size: np.ndarray) -> int: return int(np.sum(sub_size**2)) -@numba_util.jit() -def native_sub_index_for_slim_sub_index_2d_from( - mask_2d: np.ndarray, sub_size: np.ndarray -) -> np.ndarray: - """ - Returns an array of shape [total_unmasked_pixels*sub_size] that maps every unmasked sub-pixel to its - corresponding native 2D pixel using its (y,x) pixel indexes. - - For example, for the following ``Mask2D`` for ``sub_size=1``: - - :: - [[True, True, True, True] - [True, False, False, True], - [True, False, True, True], - [True, True, True, True]] - - This has three unmasked (``False`` values) which have the ``slim`` indexes: - - :: - [0, 1, 2] - - The array ``native_index_for_slim_index_2d`` is therefore: - - :: - [[1,1], [1,2], [2,1]] - - For a ``Mask2D`` with ``sub_size=2`` each unmasked ``False`` entry is split into a sub-pixel of size 2x2 and - there are therefore 12 ``slim`` indexes: - - :: - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - - The array ``native_index_for_slim_index_2d`` is therefore: - - :: - [[2,2], [2,3], [2,4], [2,5], [3,2], [3,3], [3,4], [3,5], [4,2], [4,3], [5,2], [5,3]] - - Parameters - ---------- - mask_2d - A 2D array of bools, where `False` values are unmasked. - sub_size - The size of the sub-grid in each mask pixel. - - Returns - ------- - ndarray - An array that maps pixels from a slimmed array of shape [total_unmasked_pixels*sub_size] to its native array - of shape [total_pixels*sub_size, total_pixels*sub_size]. - - Examples - -------- - mask_2d = np.array([[True, True, True], - [True, False, True] - [True, True, True]]) - - sub_native_index_for_sub_slim_index_2d = sub_native_index_for_sub_slim_index_via_mask_2d_from(mask_2d=mask_2d, sub_size=1) - """ - - total_sub_pixels = total_sub_pixels_2d_from(sub_size=sub_size) - sub_native_index_for_sub_slim_index_2d = np.zeros(shape=(total_sub_pixels, 2)) - - slim_index = 0 - sub_slim_index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - sub = sub_size[slim_index] - - for y1 in range(sub): - for x1 in range(sub): - sub_native_index_for_sub_slim_index_2d[sub_slim_index, :] = ( - (y * sub) + y1, - (x * sub) + x1, - ) - sub_slim_index += 1 - - slim_index += 1 - - return sub_native_index_for_sub_slim_index_2d - - -@numba_util.jit() def slim_index_for_sub_slim_index_via_mask_2d_from( mask_2d: np.ndarray, sub_size: np.ndarray ) -> np.ndarray: @@ -195,131 +107,20 @@ def slim_index_for_sub_slim_index_via_mask_2d_from( slim_index_for_sub_slim_index = slim_index_for_sub_slim_index_via_mask_2d_from(mask_2d=mask_2d, sub_size=2) """ - total_sub_pixels = total_sub_pixels_2d_from(sub_size=sub_size) - - slim_index_for_sub_slim_index = np.zeros(shape=total_sub_pixels) - slim_index = 0 - sub_slim_index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - sub = sub_size[slim_index] + # Step 1: Identify unmasked (False) pixels + unmasked_indices = np.argwhere(~mask_2d) + n_unmasked = unmasked_indices.shape[0] - for y1 in range(sub): - for x1 in range(sub): - slim_index_for_sub_slim_index[sub_slim_index] = slim_index - sub_slim_index += 1 + # Step 2: Compute total number of sub-pixels + sub_pixels_per_pixel = sub_size**2 - slim_index += 1 + # Step 3: Repeat slim indices for each sub-pixel + slim_indices = np.arange(n_unmasked) + slim_index_for_sub_slim_index = np.repeat(slim_indices, sub_pixels_per_pixel) return slim_index_for_sub_slim_index -@numba_util.jit() -def sub_slim_index_for_sub_native_index_from(sub_mask_2d: np.ndarray): - """ - Returns a 2D array which maps every `False` entry of a 2D mask to its sub slim mask array. Every - True entry is given a value -1. - - This is used as a convenience tool for creating structures util between different grids and structures. - - For example, if we had a 3x4 mask: - - [[False, True, False, False], - [False, True, False, False], - [False, False, False, True]]] - - The sub_slim_index_for_sub_native_index array would be: - - [[0, -1, 2, 3], - [4, -1, 5, 6], - [7, 8, 9, -1]] - - Parameters - ---------- - sub_mask_2d - The 2D mask that the util array is created for. - - Returns - ------- - ndarray - The 2D array mapping 2D mask entries to their 1D masked array indexes. - - Examples - -------- - mask = np.full(fill_value=False, shape=(9,9)) - sub_two_to_one = mask_to_mask_1d_index_from(mask=mask) - """ - - sub_slim_index_for_sub_native_index = -1 * np.ones(shape=sub_mask_2d.shape) - - sub_mask_1d_index = 0 - - for sub_mask_y in range(sub_mask_2d.shape[0]): - for sub_mask_x in range(sub_mask_2d.shape[1]): - if sub_mask_2d[sub_mask_y, sub_mask_x] == False: - sub_slim_index_for_sub_native_index[sub_mask_y, sub_mask_x] = ( - sub_mask_1d_index - ) - sub_mask_1d_index += 1 - - return sub_slim_index_for_sub_native_index - - -@numba_util.jit() -def oversample_mask_2d_from(mask: np.ndarray, sub_size: int) -> np.ndarray: - """ - Returns a new mask of shape (mask.shape[0] * sub_size, mask.shape[1] * sub_size) where all boolean values are - expanded according to the `sub_size`. - - For example, if the input mask is: - - mask = np.array([ - [True, True, True], - [True, False, True], - [True, True, True] - ]) - - and the sub_size is 2, the output mask would be: - - expanded_mask = np.array([ - [True, True, True, True, True, True], - [True, True, True, True, True, True], - [True, True, False, False, True, True], - [True, True, False, False, True, True], - [True, True, True, True, True, True], - [True, True, True, True, True, True] - ]) - - This is used throughout the code to handle uniform oversampling calculations. - - Parameters - ---------- - mask - The mask from which the over sample mask is computed. - sub_size - The factor by which the mask is oversampled. - - Returns - ------- - The mask oversampled by the input sub_size. - """ - oversample_mask = np.full( - (mask.shape[0] * sub_size, mask.shape[1] * sub_size), True - ) - - for y in range(mask.shape[0]): - for x in range(mask.shape[1]): - if not mask[y, x]: - oversample_mask[ - y * sub_size : (y + 1) * sub_size, x * sub_size : (x + 1) * sub_size - ] = False - - return oversample_mask - - -@numba_util.jit() def sub_size_radial_bins_from( radial_grid: np.ndarray, sub_size_list: np.ndarray, @@ -357,22 +158,18 @@ def sub_size_radial_bins_from( the centre of the mask. """ - sub_size = sub_size_list[-1] * np.ones(radial_grid.shape) + # Use np.searchsorted to find the first index where radial_grid[i] < radial_list[j] + bin_indices = np.searchsorted(radial_list, radial_grid, side="left") + + # Clip indices to stay within bounds of sub_size_list + bin_indices = np.clip(bin_indices, 0, len(sub_size_list) - 1) - for i in range(radial_grid.shape[0]): - for j in range(len(radial_list)): - if radial_grid[i] < radial_list[j]: - # if use_jax: - # # while this makes it run, it is very, very slow - # sub_size = sub_size.at[i].set(sub_size_list[j]) - # else: - sub_size[i] = sub_size_list[j] - break + return sub_size_list[bin_indices] - return sub_size + +from autoarray.geometry import geometry_util -@numba_util.jit() def grid_2d_slim_over_sampled_via_mask_from( mask_2d: np.ndarray, pixel_scales: ty.PixelScales, @@ -418,11 +215,16 @@ def grid_2d_slim_over_sampled_via_mask_from( grid_slim = grid_2d_slim_over_sampled_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), sub_size=1, origin=(0.0, 0.0)) """ + pixels_in_mask = (np.size(mask_2d) - np.sum(mask_2d)).astype(int) + + if isinstance(sub_size, int): + sub_size = np.full(fill_value=sub_size, shape=pixels_in_mask) + total_sub_pixels = np.sum(sub_size**2) grid_slim = np.zeros(shape=(total_sub_pixels, 2)) - centres_scaled = geometry_util.central_scaled_coordinate_2d_numba_from( + centres_scaled = geometry_util.central_scaled_coordinate_2d_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) @@ -445,23 +247,6 @@ def grid_2d_slim_over_sampled_via_mask_from( for y1 in range(sub): for x1 in range(sub): - # if use_jax: - # # while this makes it run, it is very, very slow - # grid_slim = grid_slim.at[sub_index, 0].set( - # -( - # y_scaled - # - y_sub_half - # + y1 * y_sub_step - # + (y_sub_step / 2.0) - # ) - # ) - # grid_slim = grid_slim.at[sub_index, 1].set( - # x_scaled - # - x_sub_half - # + x1 * x_sub_step - # + (x_sub_step / 2.0) - # ) - # else: grid_slim[sub_index, 0] = -( y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0) ) @@ -475,80 +260,97 @@ def grid_2d_slim_over_sampled_via_mask_from( return grid_slim -@numba_util.jit() -def binned_array_2d_from( - array_2d: np.ndarray, - mask_2d: np.ndarray, - sub_size: np.ndarray, -) -> np.ndarray: - """ - For a sub-grid, every unmasked pixel of its 2D mask with shape (total_y_pixels, total_x_pixels) is divided into - a finer uniform grid of shape (total_y_pixels*sub_size, total_x_pixels*sub_size). This routine computes the (y,x) - scaled coordinates a the centre of every sub-pixel defined by this 2D mask array. - - The sub-grid is returned on an array of shape (total_unmasked_pixels*sub_size**2, 2). y coordinates are - stored in the 0 index of the second dimension, x coordinates in the 1 index. Masked coordinates are therefore - removed and not included in the slimmed grid. - - Grid2D are defined from the top-left corner, where the first unmasked sub-pixel corresponds to index 0. - Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second - sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth. - - Parameters - ---------- - mask_2d - A 2D array of bools, where `False` values are unmasked and therefore included as part of the calculated - sub-grid. - pixel_scales - The (y,x) scaled units to pixel units conversion factor of the 2D mask array. - sub_size - The size of the sub-grid that each pixel of the 2D mask array is divided into. - origin - The (y,x) origin of the 2D array, which the sub-grid is shifted around. - - Returns - ------- - ndarray - A slimmed sub grid of (y,x) scaled coordinates at the centre of every pixel unmasked pixel on the 2D mask - array. The sub grid array has dimensions (total_unmasked_pixels*sub_size**2, 2). - - Examples - -------- - mask = np.array([[True, False, True], - [False, False, False] - [True, False, True]]) - grid_slim = grid_2d_slim_over_sampled_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), sub_size=1, origin=(0.0, 0.0)) - """ - - total_pixels = np.sum(~mask_2d) - - sub_fraction = 1.0 / sub_size**2 - - binned_array_2d_slim = np.zeros(shape=total_pixels) - - index = 0 - sub_index = 0 - - for y in range(mask_2d.shape[0]): - for x in range(mask_2d.shape[1]): - if not mask_2d[y, x]: - sub = sub_size[index] - - for y1 in range(sub): - for x1 in range(sub): - # if use_jax: - # binned_array_2d_slim = binned_array_2d_slim.at[index].add( - # array_2d[sub_index] * sub_fraction[index] - # ) - # else: - binned_array_2d_slim[index] += ( - array_2d[sub_index] * sub_fraction[index] - ) - sub_index += 1 - - index += 1 - - return binned_array_2d_slim +# + +# def grid_2d_slim_over_sampled_via_mask_from( +# mask_2d: np.ndarray, +# pixel_scales: ty.PixelScales, +# sub_size: np.ndarray, +# origin: Tuple[float, float] = (0.0, 0.0), +# ) -> np.ndarray: +# """ +# For a sub-grid, every unmasked pixel of its 2D mask with shape (total_y_pixels, total_x_pixels) is divided into +# a finer uniform grid of shape (total_y_pixels*sub_size, total_x_pixels*sub_size). This routine computes the (y,x) +# scaled coordinates at the centre of every sub-pixel defined by this 2D mask array. +# +# The sub-grid is returned on an array of shape (total_unmasked_pixels*sub_size**2, 2). y coordinates are +# stored in the 0 index of the second dimension, x coordinates in the 1 index. Masked coordinates are therefore +# removed and not included in the slimmed grid. +# +# Grid2D are defined from the top-left corner, where the first unmasked sub-pixel corresponds to index 0. +# Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second +# sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth. +# +# Parameters +# ---------- +# mask_2d +# A 2D array of bools, where `False` values are unmasked and therefore included as part of the calculated +# sub-grid. +# pixel_scales +# The (y,x) scaled units to pixel units conversion factor of the 2D mask array. +# sub_size +# The size of the sub-grid that each pixel of the 2D mask array is divided into. +# origin +# The (y,x) origin of the 2D array, which the sub-grid is shifted around. +# +# Returns +# ------- +# ndarray +# A slimmed sub grid of (y,x) scaled coordinates at the centre of every pixel unmasked pixel on the 2D mask +# array. The sub grid array has dimensions (total_unmasked_pixels*sub_size**2, 2). +# +# Examples +# -------- +# mask = np.array([[True, False, True], +# [False, False, False] +# [True, False, True]]) +# grid_slim = grid_2d_slim_over_sampled_via_mask_from(mask=mask, pixel_scales=(0.5, 0.5), sub_size=1, origin=(0.0, 0.0)) +# """ +# +# H, W = mask_2d.shape +# sy, sx = pixel_scales +# oy, ox = origin +# +# # 1) Find unmasked pixel indices in row-major order +# rows, cols = np.nonzero(~mask_2d) +# Npix = rows.size +# +# # 2) Broadcast or validate sub_size array +# sub_arr = np.asarray(sub_size) +# sub_arr = np.full(Npix, sub_arr, dtype=int) if sub_arr.size == 1 else sub_arr +# +# # 3) Compute pixel centers (y ↑ up, x → right) +# cy = (H - 1) / 2.0 +# cx = (W - 1) / 2.0 +# y_pix = (cy - rows) * sy + oy +# x_pix = (cols - cx) * sx + ox +# +# # 4) For each pixel, generate its sub-pixel coords and collect +# coords_list = [] +# for i in range(Npix): +# s = sub_arr[i] +# dy = sy / s +# dx = sx / s +# +# # y offsets: from top (+sy/2 - dy/2) down to bottom (-sy/2 + dy/2) +# y_off = np.linspace(+sy/2 - dy/2, -sy/2 + dy/2, s) +# # x offsets: left to right +# x_off = np.linspace(-sx/2 + dx/2, +sx/2 - dx/2, s) +# +# # build subgrid +# y_sub, x_sub = np.meshgrid(y_off, x_off, indexing="ij") +# y_sub = y_sub.ravel() +# x_sub = x_sub.ravel() +# +# # center + offsets +# y_center = y_pix[i] +# x_center = x_pix[i] +# coords = np.stack([y_center + y_sub, x_center + x_sub], axis=1) +# +# coords_list.append(coords) +# +# # 5) Concatenate all sub-pixel blocks in row-major pixel order +# return np.vstack(coords_list) def over_sample_size_via_radial_bins_from( @@ -601,7 +403,7 @@ def over_sample_size_via_radial_bins_from( radial_grid = grid.distances_to_coordinate_from(coordinate=centre) sub_size_of_centre = sub_size_radial_bins_from( - radial_grid=np.array(radial_grid), + radial_grid=np.array(radial_grid.array), sub_size_list=np.array(sub_size_list), radial_list=np.array(radial_list), ) diff --git a/autoarray/operators/over_sampling/over_sampler.py b/autoarray/operators/over_sampling/over_sampler.py index 9fda67bb7..5411fa93a 100644 --- a/autoarray/operators/over_sampling/over_sampler.py +++ b/autoarray/operators/over_sampling/over_sampler.py @@ -1,5 +1,6 @@ import numpy as np import jax.numpy as jnp +import jax from typing import Union from autoconf import conf @@ -128,7 +129,7 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): based on the sub-grid sizes. The over sampling class has functions dedicated to mapping between the sub-grid and pixel-grid, for example - `sub_mask_native_for_sub_mask_slim` and `slim_for_sub_slim`. + `slim_for_sub_slim`. The class `OverSampling` is used for the high level API, whereby this is where users input their preferred over-sampling configuration. This class, `OverSampler`, contains the functionality @@ -147,15 +148,39 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): over_sample_size=sub_size, mask=mask ) + self.sub_total = int(np.sum(self.sub_size**2)) + self.sub_length = self.sub_size**self.mask.dimensions + self.sub_fraction = Array2D( + values=jnp.array(1.0 / self.sub_length.array), mask=self.mask + ) + + # Used for JAX based adaptive over sampling. + + # Define group sizes + group_sizes = np.array(self.sub_size.array**2) + + # Compute the cumulative sum of group sizes to get split points + self.split_indices = np.cumsum(group_sizes) + + # Ensure correct concatenation by making 0 a JAX array + self.start_indices = np.concatenate((np.array([0]), self.split_indices[:-1])) + + # Compute segment ids for each element in the flattened array + self.segment_ids = np.empty(np.sum(sub_size**2), dtype=np.int32) + + for seg_id, (start, end) in enumerate( + zip(self.start_indices, self.split_indices) + ): + self.segment_ids[start:end] = seg_id + + self.segment_ids = jnp.array(self.segment_ids) @property def sub_is_uniform(self) -> bool: """ Returns True if the sub_size is uniform across all pixels in the mask. """ - return np.all( - np.isclose(self.sub_size.array, self.sub_size.array[0]) - ) + return np.all(np.isclose(self.sub_size, self.sub_size[0])) def tree_flatten(self): return (self.mask, self.sub_size), () @@ -164,32 +189,6 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): return cls(mask=children[0], sub_size=children[1]) - @property - def sub_total(self): - """ - The total number of sub-pixels in the entire mask. - """ - return int(np.sum(self.sub_size**2)) - - @property - def sub_length(self) -> Array2D: - """ - The total number of sub-pixels in a give pixel, - - For example, a sub-size of 3x3 means every pixel has 9 sub-pixels. - """ - return self.sub_size**self.mask.dimensions - - @property - def sub_fraction(self) -> Array2D: - """ - The fraction of the area of a pixel every sub-pixel contains. - - For example, a sub-size of 3x3 mean every pixel contains 1/9 the area. - """ - - return 1.0 / self.sub_length - @property def sub_pixel_areas(self) -> np.ndarray: """ @@ -222,97 +221,52 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": In **PyAutoCTI** all `Array2D` objects are used in their `native` representation without sub-gridding. Significant memory can be saved by only store this format, thus the `native_binned_only` config override can force this behaviour. It is recommended users do not use this option to avoid unexpected behaviour. + + Old docstring: + + For a sub-grid, every unmasked pixel of its 2D mask with shape (total_y_pixels, total_x_pixels) is divided into + a finer uniform grid of shape (total_y_pixels*sub_size, total_x_pixels*sub_size). This routine computes the (y,x) + scaled coordinates a the centre of every sub-pixel defined by this 2D mask array. + + The sub-grid is returned on an array of shape (total_unmasked_pixels*sub_size**2, 2). y coordinates are + stored in the 0 index of the second dimension, x coordinates in the 1 index. Masked coordinates are therefore + removed and not included in the slimmed grid. + + Grid2D are defined from the top-left corner, where the first unmasked sub-pixel corresponds to index 0. + Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second + sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth. """ if conf.instance["general"]["structures"]["native_binned_only"]: return self try: - array = array.slim + array = array.slim.array except AttributeError: pass if self.sub_is_uniform: + binned_array_2d = array.reshape( self.mask.shape_slim, self.sub_size[0] ** 2 ).mean(axis=1) - else: - - # Define group sizes - group_sizes = jnp.array(self.sub_size.array.astype("int") ** 2) - - # Compute the cumulative sum of group sizes to get split points - split_indices = jnp.cumsum(group_sizes) - # Ensure correct concatenation by making 0 a JAX array - start_indices = jnp.concatenate((jnp.array([0]), split_indices[:-1])) + else: # Compute the group means - binned_array_2d = jnp.array( - [array[start:end].mean() for start, end in zip(start_indices, split_indices)]) + + sums = jax.ops.segment_sum( + array, self.segment_ids, self.mask.pixels_in_mask + ) + counts = jax.ops.segment_sum( + jnp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask + ) + binned_array_2d = sums / counts return Array2D( values=binned_array_2d, mask=self.mask, ) - @cached_property - def sub_mask_native_for_sub_mask_slim(self) -> np.ndarray: - """ - Derives a 1D ``ndarray`` which maps every subgridded 1D ``slim`` index of the ``Mask2D`` to its - subgridded 2D ``native`` index. - - For example, for the following ``Mask2D`` for ``sub_size=1``: - - :: - [[True, True, True, True] - [True, False, False, True], - [True, False, True, True], - [True, True, True, True]] - - This has three unmasked (``False`` values) which have the ``slim`` indexes: - - :: - [0, 1, 2] - - The array ``sub_mask_native_for_sub_mask_slim`` is therefore: - - :: - [[1,1], [1,2], [2,1]] - - For a ``Mask2D`` with ``sub_size=2`` each unmasked ``False`` entry is split into a sub-pixel of size 2x2 and - there are therefore 12 ``slim`` indexes: - - :: - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - - The array ``native_for_slim`` is therefore: - - :: - [[2,2], [2,3], [2,4], [2,5], [3,2], [3,3], [3,4], [3,5], [4,2], [4,3], [5,2], [5,3]] - - Examples - -------- - - .. code-block:: python - - import autoarray as aa - - mask_2d = aa.Mask2D( - mask=[[True, True, True, True] - [True, False, False, True], - [True, False, True, True], - [True, True, True, True]] - pixel_scales=1.0, - ) - - derive_indexes_2d = aa.DeriveIndexes2D(mask=mask_2d) - - print(derive_indexes_2d.sub_mask_native_for_sub_mask_slim) - """ - return over_sample_util.native_sub_index_for_slim_sub_index_2d_from( - mask_2d=self.mask.array, sub_size=np.array(self.sub_size) - ).astype("int") - @cached_property def slim_for_sub_slim(self) -> np.ndarray: """ @@ -363,7 +317,7 @@ def slim_for_sub_slim(self) -> np.ndarray: print(derive_indexes_2d.slim_for_sub_slim) """ return over_sample_util.slim_index_for_sub_slim_index_via_mask_2d_from( - mask_2d=np.array(self.mask), sub_size=np.array(self.sub_size) + mask_2d=np.array(self.mask), sub_size=self.sub_size.array ).astype("int") @property @@ -386,7 +340,7 @@ def uniform_over_sampled(self): grid = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( mask_2d=np.array(self.mask), pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.sub_size).astype("int"), + sub_size=self.sub_size.array.astype("int"), origin=self.mask.origin, ) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index acbe6bb73..d2d8d3d2d 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -1,29 +1,21 @@ -from astropy import units import copy +import jax.numpy as jnp import numpy as np import warnings +from typing import Tuple class NUFFTPlaceholder: pass -class PyLopsPlaceholder: - pass - - try: from pynufft.linalg.nufft_cpu import NUFFT_cpu except ModuleNotFoundError: NUFFT_cpu = NUFFTPlaceholder -try: - import pylops - - PyLopsOperator = pylops.LinearOperator -except ModuleNotFoundError: - PyLopsOperator = PyLopsPlaceholder +from autoarray.mask.mask_2d import Mask2D from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.visibilities import Visibilities @@ -42,21 +34,50 @@ def pynufft_exception(): ) -def pylops_exception(): - raise ModuleNotFoundError( - "\n--------------------\n" - "You are attempting to perform interferometer analysis.\n\n" - "However, the optional library PyLops (https://github.com/PyLops/pylops) is not installed.\n\n" - "Install it via the command `pip install pylops==2.3.1`.\n\n" - "----------------------" - ) - - -class TransformerDFT(PyLopsOperator): - def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): - if isinstance(self, PyLopsPlaceholder): - pylops_exception() - +class TransformerDFT: + def __init__( + self, + uv_wavelengths: np.ndarray, + real_space_mask: Mask2D, + preload_transform: bool = True, + ): + """ + A direct Fourier transform (DFT) operator for radio interferometric imaging. + + This class performs the forward and inverse mapping between real-space images and + complex visibilities measured by an interferometer. It uses a direct implementation + of the Fourier transform (not FFT-based), making it suitable for irregular uv-coverage. + + Optionally, it precomputes and stores the sine and cosine terms used in the transform, + which can significantly improve performance for repeated operations but at the cost of memory. + + Parameters + ---------- + uv_wavelengths + The (u, v) coordinates in wavelengths of the measured visibilities. + real_space_mask + The real-space mask that defines the image grid and which pixels are valid. + preload_transform + If True, precomputes and stores the cosine and sine terms for the Fourier transform. + This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets). + + Attributes + ---------- + grid : ndarray + The unmasked real-space grid in radians. + total_visibilities : int + The number of measured visibilities. + total_image_pixels : int + The number of unmasked pixels in the real-space image grid. + preload_real_transforms : ndarray, optional + The precomputed cosine terms used in the real part of the DFT. + preload_imag_transforms : ndarray, optional + The precomputed sine terms used in the imaginary part of the DFT. + real_space_pixels : int + Alias for `total_image_pixels`. + adjoint_scaling : float + Scaling factor applied to the adjoint operator to normalize the inverse transform. + """ super().__init__() self.uv_wavelengths = uv_wavelengths.astype("float") @@ -69,58 +90,88 @@ def __init__(self, uv_wavelengths, real_space_mask, preload_transform=True): self.preload_transform = preload_transform if preload_transform: - self.preload_real_transforms = transformer_util.preload_real_transforms( - grid_radians=np.array(self.grid), - uv_wavelengths=self.uv_wavelengths, + + self.preload_real_transforms = ( + transformer_util.preload_real_transforms_from( + grid_radians=np.array(self.grid.array), + uv_wavelengths=self.uv_wavelengths, + ) ) - self.preload_imag_transforms = transformer_util.preload_imag_transforms( - grid_radians=np.array(self.grid), - uv_wavelengths=self.uv_wavelengths, + self.preload_imag_transforms = ( + transformer_util.preload_imag_transforms_from( + grid_radians=np.array(self.grid.array), + uv_wavelengths=self.uv_wavelengths, + ) ) self.real_space_pixels = self.real_space_mask.pixels_in_mask - self.shape = ( - int(np.prod(self.total_visibilities)), - int(np.prod(self.real_space_pixels)), - ) - self.dtype = "complex128" - self.explicit = False - # NOTE: This is the scaling factor that needs to be applied to the adjoint operator self.adjoint_scaling = (2.0 * self.grid.shape_native[0]) * ( 2.0 * self.grid.shape_native[1] ) - self.matvec_count = 0 - self.rmatvec_count = 0 - self.matmat_count = 0 - self.rmatmat_count = 0 + def visibilities_from(self, image: Array2D) -> Visibilities: + """ + Computes the visibilities from a real-space image using the direct Fourier transform (DFT). + + This method transforms the input image into the uv-plane (Fourier space), simulating the + measurements made by an interferometer at specified uv-wavelengths. + + If `preload_transform` is True, it uses precomputed sine and cosine terms to accelerate the computation. - def visibilities_from(self, image): + Parameters + ---------- + image + The real-space image to be transformed to the uv-plane. Must be defined on the + same grid and mask as this transformer's `real_space_mask`. + + Returns + ------- + The complex visibilities resulting from the Fourier transform of the input image. + """ if self.preload_transform: - visibilities = transformer_util.visibilities_via_preload_jit_from( - image_1d=np.array(image), + visibilities = transformer_util.visibilities_via_preload_from( + image_1d=image.array, preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, ) - else: - visibilities = transformer_util.visibilities_jit( - image_1d=np.array(image.slim), - grid_radians=np.array(self.grid), + visibilities = transformer_util.visibilities_from( + image_1d=image.slim.array, + grid_radians=self.grid.array, uv_wavelengths=self.uv_wavelengths, ) - return Visibilities(visibilities=visibilities) + return Visibilities(visibilities=jnp.array(visibilities)) - def image_from(self, visibilities, use_adjoint_scaling: bool = False): - image_slim = transformer_util.image_via_jit_from( - n_pixels=self.grid.shape[0], - grid_radians=np.array(self.grid), - uv_wavelengths=self.uv_wavelengths, + def image_from( + self, visibilities: Visibilities, use_adjoint_scaling: bool = False + ) -> Array2D: + """ + Computes the real-space image from a set of visibilities using the adjoint of the DFT. + + This is not a true inverse Fourier transform, but rather the adjoint operation, which maps + complex visibilities back into image space. This is typically used as the first step + in inverse imaging algorithms like CLEAN or regularized reconstruction. + + Parameters + ---------- + visibilities + The complex visibilities to be transformed into a real-space image. + use_adjoint_scaling + If True, the result is scaled by a normalization factor. Currently unused. + + Returns + ------- + The real-space image resulting from the adjoint DFT operation, defined on the same + mask as this transformer's `real_space_mask`. + """ + image_slim = transformer_util.image_direct_from( visibilities=visibilities.in_array, + grid_radians=self.grid.array, + uv_wavelengths=self.uv_wavelengths, ) image_native = array_2d_util.array_2d_native_from( @@ -130,30 +181,91 @@ def image_from(self, visibilities, use_adjoint_scaling: bool = False): return Array2D(values=image_native, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix): + def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: + """ + Applies the DFT to a mapping matrix that maps source pixels to image pixels. + + This is used in linear inversion frameworks, where the transform of each source basis function + (represented by a column of the mapping matrix) is computed individually. The result is a matrix + mapping source pixels directly to visibilities. + + If `preload_transform` is True, the computation is accelerated using precomputed sine and cosine terms. + + Parameters + ---------- + mapping_matrix + A 2D array of shape (n_image_pixels, n_source_pixels) that maps source pixels to image-plane pixels. + + Returns + ------- + A 2D complex-valued array of shape (n_visibilities, n_source_pixels) that maps source-plane basis + functions directly to the visibilities. + """ if self.preload_transform: - return transformer_util.transformed_mapping_matrix_via_preload_jit_from( + return transformer_util.transformed_mapping_matrix_via_preload_from( mapping_matrix=mapping_matrix, preloaded_reals=self.preload_real_transforms, preloaded_imags=self.preload_imag_transforms, ) - else: - return transformer_util.transformed_mapping_matrix_jit( - mapping_matrix=mapping_matrix, - grid_radians=np.array(self.grid), - uv_wavelengths=self.uv_wavelengths, - ) + return transformer_util.transformed_mapping_matrix_from( + mapping_matrix=mapping_matrix, + grid_radians=self.grid.array, + uv_wavelengths=self.uv_wavelengths, + ) + +class TransformerNUFFT(NUFFT_cpu): + def __init__(self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D, **kwargs): + """ + Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction. + + This transformer uses the PyNUFFT library to efficiently compute the Fourier transform + of an image defined on a regular real-space grid to a set of non-uniform uv-plane (Fourier space) + coordinates, as is typical in radio interferometry. + + It is initialized with the interferometer uv-wavelengths and a real-space mask, which defines + the pixelized image domain. + + Parameters + ---------- + uv_wavelengths + The uv-coordinates (Fourier-space sampling points) corresponding to the measured visibilities. + Should be an array of shape (n_vis, 2), where the two columns represent u and v coordinates in wavelengths. + + real_space_mask + The 2D mask defining the real-space pixel grid on which the image is defined. Used to create the + unmasked grid required for NUFFT planning. + + Notes + ----- + - The `initialize_plan()` method builds the internal NUFFT plan based on the input grid and uv sampling. + - A complex exponential `shift` factor is applied to align the center of the Fourier transform correctly, + accounting for the pixel-center offset in the real-space grid. + - The adjoint operation (used in inverse imaging) must be scaled by `adjoint_scaling` to normalize its output. + - This transformer inherits directly from PyNUFFT's `NUFFT_cpu` base class. + - If `NUFFTPlaceholder` is detected (indicating PyNUFFT is not available), an exception is raised. + + Attributes + ---------- + grid : Grid2D + The real-space pixel grid derived from the mask, in radians. + native_index_for_slim_index : np.ndarray + Index map converting from slim (1D) grid to native (2D) indexing, for image reshaping. + shift : np.ndarray + Complex exponential phase shift applied to account for real-space pixel centering. + real_space_pixels : int + Total number of valid real-space pixels defined by the mask. + total_visibilities : int + Total number of visibilities across all uv-wavelength components. + adjoint_scaling : float + Scaling factor for adjoint operations to normalize reconstructed images. + """ + from astropy import units -class TransformerNUFFT(NUFFT_cpu, PyLopsOperator): - def __init__(self, uv_wavelengths, real_space_mask): if isinstance(self, NUFFTPlaceholder): pynufft_exception() - if isinstance(self, PyLopsPlaceholder): - pylops_exception() - super(TransformerNUFFT, self).__init__() self.uv_wavelengths = uv_wavelengths @@ -189,27 +301,42 @@ def __init__(self, uv_wavelengths, real_space_mask): # NOTE: If reshaped the shape of the operator is (2 x Nvis, Np) else it is (Nvis, Np) self.total_visibilities = int(uv_wavelengths.shape[0] * uv_wavelengths.shape[1]) - self.shape = ( - int(np.prod(self.total_visibilities)), - int(np.prod(self.real_space_pixels)), - ) - - # NOTE: If the operator is reshaped then the output is real. - self.dtype = "float64" - - self.explicit = False - # NOTE: This is the scaling factor that needs to be applied to the adjoint operator self.adjoint_scaling = (2.0 * self.grid.shape_native[0]) * ( 2.0 * self.grid.shape_native[1] ) - self.matvec_count = 0 - self.rmatvec_count = 0 - self.matmat_count = 0 - self.rmatmat_count = 0 + def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6)): + """ + Initializes the PyNUFFT plan for performing the NUFFT operation. + + This method precomputes the interpolation structure and gridding + needed by the NUFFT algorithm to map between the regular real-space + image grid and the non-uniform uv-plane sampling defined by the + interferometric visibilities. + + Parameters + ---------- + ratio + The oversampling ratio used to pad the Fourier grid before interpolation. + A higher value improves accuracy at the cost of increased memory and computation. + Default is 2 (i.e., the Fourier grid is twice the size of the image grid). + + interp_kernel + The interpolation kernel size along each axis, given as (Jy, Jx). + This determines how many neighboring Fourier grid points are used + to interpolate each uv-point. + Default is (6, 6), a good trade-off between accuracy and performance. + + Notes + ----- + - The uv-coordinates are normalized and rescaled into the range expected by PyNUFFT + using the real-space grid’s pixel scale and the Nyquist frequency limit. + - The plan must be initialized before performing any NUFFT operations (e.g., forward or adjoint). + - This method modifies the internal state of the NUFFT object by calling `self.plan(...)`. + """ + from astropy import units - def initialize_plan(self, ratio=2, interp_kernel=(6, 6)): if not isinstance(ratio, int): ratio = int(ratio) @@ -233,9 +360,23 @@ def initialize_plan(self, ratio=2, interp_kernel=(6, 6)): Jd=interp_kernel, ) - def visibilities_from(self, image): + def visibilities_from(self, image: Array2D) -> Visibilities: """ - ... + Computes visibilities from a real-space image using the NUFFT forward transform. + + Parameters + ---------- + image + The input image in real space, represented as a 2D array object. + + Returns + ------- + The complex visibilities in the uv-plane computed via the NUFFT forward operation. + + Notes + ----- + - The image is flipped vertically before transformation to account for PyNUFFT’s internal data layout. + - Warnings during the NUFFT computation are suppressed for cleaner output. """ warnings.filterwarnings("ignore") @@ -246,7 +387,29 @@ def visibilities_from(self, image): ) # flip due to PyNUFFT internal flip ) - def image_from(self, visibilities, use_adjoint_scaling: bool = False): + def image_from( + self, visibilities: Visibilities, use_adjoint_scaling: bool = False + ) -> Array2D: + """ + Reconstructs a real-space image from visibilities using the NUFFT adjoint transform. + + Parameters + ---------- + visibilities + The complex visibilities in the uv-plane to be inverted. + use_adjoint_scaling + If True, apply a scaling factor to the adjoint result to improve accuracy. + Default is False. + + Returns + ------- + The reconstructed real-space image after applying the NUFFT adjoint transform. + + Notes + ----- + - The output image is flipped vertically to align with the input image orientation. + - Warnings during the adjoint operation are suppressed. + """ with warnings.catch_warnings(): warnings.simplefilter("ignore") image = np.real(self.adjoint(visibilities))[::-1, :] @@ -256,7 +419,25 @@ def image_from(self, visibilities, use_adjoint_scaling: bool = False): return Array2D(values=image, mask=self.real_space_mask) - def transform_mapping_matrix(self, mapping_matrix): + def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray: + """ + Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities. + + Parameters + ---------- + mapping_matrix + A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space. + + Returns + ------- + A complex-valued 2D array where each column contains the visibilities corresponding to the respective column + in the input mapping matrix. + + Notes + ----- + - Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation. + - This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive. + """ transformed_mapping_matrix = 0 + 0j * np.zeros( (self.uv_wavelengths.shape[0], mapping_matrix.shape[1]) ) @@ -274,61 +455,3 @@ def transform_mapping_matrix(self, mapping_matrix): transformed_mapping_matrix[:, source_pixel_1d_index] = visibilities return transformed_mapping_matrix - - def forward_lop(self, x): - """ - Forward NUFFT on CPU - :param x: The input numpy array, with the size of Nd or Nd + (batch,) - :type: numpy array with the dtype of numpy.complex64 - :return: y: The output numpy array, with the size of (M,) or (M, batch) - :rtype: numpy array with the dtype of numpy.complex64 - """ - - warnings.filterwarnings("ignore") - - x2d = array_2d_util.array_2d_native_complex_via_indexes_from( - array_2d_slim=x, - shape_native=self.real_space_mask.shape_native, - native_index_for_slim_index_2d=self.native_index_for_slim_index, - )[::-1, :] - - y = self.k2y(self.xx2k(self.x2xx(x2d))) - return np.concatenate((y.real, y.imag), axis=0) - - def adjoint_lop(self, y): - """ - Adjoint NUFFT on CPU - :param y: The input numpy array, with the size of (M,) or (M, batch) - :type: numpy array with the dtype of numpy.complex64 - :return: x: The output numpy array, - with the size of Nd or Nd + (batch, ) - :rtype: numpy array with the dtype of numpy.complex64 - """ - - warnings.filterwarnings("ignore") - - def a_complex_from(a_real, a_imag): - return a_real + 1j * a_imag - - y = a_complex_from( - a_real=y[: int(self.shape[0] / 2.0)], a_imag=y[int(self.shape[0] / 2.0) :] - ) - - x2d = np.real(self.xx2x(self.k2xx(self.y2k(y)))) - - x = array_2d_util.array_2d_slim_complex_from( - array_2d_native=x2d[::-1, :], - mask=np.array(self.real_space_mask), - ) - x = x.real # NOTE: - - # NOTE: - x *= self.adjoint_scaling - - return x - - def _matvec(self, x): - return self.forward_lop(x) - - def _rmatvec(self, x): - return self.adjoint_lop(x) diff --git a/autoarray/operators/transformer_util.py b/autoarray/operators/transformer_util.py index 395034794..34659510a 100644 --- a/autoarray/operators/transformer_util.py +++ b/autoarray/operators/transformer_util.py @@ -1,20 +1,18 @@ +import jax.numpy as jnp import numpy as np -from autoarray import numba_util - -@numba_util.jit() -def preload_real_transforms( +def preload_real_transforms_from( grid_radians: np.ndarray, uv_wavelengths: np.ndarray ) -> np.ndarray: """ - Sets up the real preloaded values used by the direct fourier transform (`TransformerDFT`) to speed up + Sets up the real preloaded values used by the direct Fourier transform (`TransformerDFT`) to speed up the Fourier transform calculations. The preloaded values are the cosine terms of every (y,x) radian coordinate on the real-space grid multiplied by - everu `uv_wavelength` value. + every `uv_wavelength` value. - For large numbers of visibilities (> 100000) this array requires large amounts of memory ( > 1 GB) and it is + For large numbers of visibilities (> 100000) this array requires large amounts of memory (> 1 GB) and it is recommended this preloading is not used. Parameters @@ -28,179 +26,267 @@ def preload_real_transforms( Returns ------- - np.ndarray - The preloaded values of the cosine terms in the calculation of real entries of the direct Fourier transform. - + The preloaded values of the cosine terms in the calculation of real entries of the direct Fourier transform. """ - - preloaded_real_transforms = np.zeros( - shape=(grid_radians.shape[0], uv_wavelengths.shape[0]) + # Compute the phase matrix: shape (n_pixels, n_visibilities) + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) ) - for image_1d_index in range(grid_radians.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - preloaded_real_transforms[image_1d_index, vis_1d_index] += np.cos( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) + # Compute cosine of the phase matrix + preloaded_real_transforms = np.cos(phase) return preloaded_real_transforms -@numba_util.jit() -def preload_imag_transforms(grid_radians, uv_wavelengths): - preloaded_imag_transforms = np.zeros( - shape=(grid_radians.shape[0], uv_wavelengths.shape[0]) +def preload_imag_transforms_from( + grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: + """ + Sets up the imaginary preloaded values used by the direct Fourier transform (`TransformerDFT`) to speed up + the Fourier transform calculations in interferometric imaging. + + The preloaded values are the sine terms of every (y,x) radian coordinate on the real-space grid multiplied by + every `uv_wavelength` value. These are used to compute the imaginary components of visibilities. + + For large numbers of visibilities (> 100000), this array can require significant memory (> 1 GB), so preloading + should be used with care. + + Parameters + ---------- + grid_radians + The grid in radians corresponding to the (y,x) coordinates in real space. + uv_wavelengths + The (u,v) coordinates in the Fourier plane (in units of wavelengths). + + Returns + ------- + The sine term preloads used in imaginary-part DFT calculations. + """ + # Compute the phase matrix: shape (n_pixels, n_visibilities) + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) ) - for image_1d_index in range(grid_radians.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - preloaded_imag_transforms[image_1d_index, vis_1d_index] += np.sin( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) + # Compute sine of the phase matrix + preloaded_imag_transforms = np.sin(phase) return preloaded_imag_transforms -@numba_util.jit() -def visibilities_via_preload_jit_from(image_1d, preloaded_reals, preloaded_imags): - visibilities = 0 + 0j * np.zeros(shape=(preloaded_reals.shape[1])) +def visibilities_via_preload_from( + image_1d: np.ndarray, preloaded_reals: np.ndarray, preloaded_imags: np.ndarray +) -> np.ndarray: + """ + Computes interferometric visibilities using preloaded real and imaginary DFT transform components. + + This function performs a direct Fourier transform (DFT) using precomputed cosine (real) and sine (imaginary) + terms. It is used in radio astronomy to compute visibilities from an image for a given interferometric + observation setup. + + Parameters + ---------- + image_1d : ndarray of shape (n_pixels,) + The 1D image vector (real-space brightness values). + preloaded_reals : ndarray of shape (n_pixels, n_visibilities) + The preloaded cosine terms (real part of DFT matrix). + preloaded_imags : ndarray of shape (n_pixels, n_visibilities) + The preloaded sine terms (imaginary part of DFT matrix). + + Returns + ------- + visibilities : ndarray of shape (n_visibilities,) + The complex visibilities computed by summing over all pixels. + """ + # Perform the dot product between the image and preloaded transform matrices + vis_real = jnp.dot(image_1d, preloaded_reals) # shape (n_visibilities,) + vis_imag = jnp.dot(image_1d, preloaded_imags) # shape (n_visibilities,) - for image_1d_index in range(image_1d.shape[0]): - for vis_1d_index in range(preloaded_reals.shape[1]): - vis_real = ( - image_1d[image_1d_index] * preloaded_reals[image_1d_index, vis_1d_index] - ) - vis_imag = ( - image_1d[image_1d_index] * preloaded_imags[image_1d_index, vis_1d_index] - ) - visibilities[vis_1d_index] += vis_real + 1j * vis_imag + visibilities = vis_real + 1j * vis_imag return visibilities -@numba_util.jit() -def visibilities_jit(image_1d, grid_radians, uv_wavelengths): - visibilities = 0 + 0j * np.zeros(shape=(uv_wavelengths.shape[0])) - - for image_1d_index in range(image_1d.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - vis_real = image_1d[image_1d_index] * np.cos( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) - vis_imag = image_1d[image_1d_index] * np.sin( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) - visibilities[vis_1d_index] += vis_real + 1j * vis_imag +def visibilities_from( + image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: + """ + Compute complex visibilities from an input sky image using the Fourier transform, + simulating the response of an astronomical radio interferometer. + + This function converts an image defined on a sky coordinate grid into its + visibility-space representation, given a set of (u,v) spatial frequency + coordinates (in wavelengths), as sampled by a radio interferometer. + + Parameters + ---------- + image_1d + The 1D flattened sky brightness values corresponding to each pixel in the grid. + grid_radians + The angular (y, x) positions of each image pixel in radians, matching image_1d. + uv_wavelengths + The (u, v) spatial frequencies in units of wavelengths, for each baseline + of the interferometer. + + Returns + ------- + visibilities + The complex visibilities (Fourier components) corresponding to each + (u, v) coordinate, representing the interferometer’s measurement. + """ + + # Compute the dot product for each pixel-uv pair + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + ) + ) # shape (n_pixels, n_vis) + + # Multiply image values with phase terms + vis_real = image_1d[:, None] * np.cos(phase) + vis_imag = image_1d[:, None] * np.sin(phase) + + # Sum over all pixels for each visibility + visibilities = np.sum(vis_real + 1j * vis_imag, axis=0) return visibilities -@numba_util.jit() -def image_via_jit_from(n_pixels, grid_radians, uv_wavelengths, visibilities): - image_1d = np.zeros(n_pixels) - - for image_1d_index in range(image_1d.shape[0]): - for vis_1d_index in range(uv_wavelengths.shape[0]): - image_1d[image_1d_index] += visibilities[vis_1d_index, 0] * np.cos( - 2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) - - image_1d[image_1d_index] -= visibilities[vis_1d_index, 1] * np.sin( - 2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] * uv_wavelengths[vis_1d_index, 1] - ) - ) +def image_direct_from( + visibilities: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: + """ + Reconstruct a real-valued sky image from complex interferometric visibilities + using an inverse Fourier transform approximation. - return image_1d + This function simulates the synthesis imaging equation of a radio interferometer + by summing sinusoidal components across all (u, v) spatial frequencies. + Parameters + ---------- + visibilities + The real and imaginary parts of the complex visibilities for each (u, v) point. -@numba_util.jit() -def transformed_mapping_matrix_via_preload_jit_from( - mapping_matrix, preloaded_reals, preloaded_imags -): - transfomed_mapping_matrix = 0 + 0j * np.zeros( - (preloaded_reals.shape[1], mapping_matrix.shape[1]) + grid_radians + The angular (y, x) coordinates of each pixel in radians. + + uv_wavelengths + The (u, v) spatial frequencies in units of wavelengths for each baseline. + + Returns + ------- + image_1d + The reconstructed real-valued image in sky coordinates. + """ + # Compute the phase term for each (pixel, visibility) pair + phase = ( + 2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) + ) ) - for pixel_1d_index in range(mapping_matrix.shape[1]): - for image_1d_index in range(mapping_matrix.shape[0]): - value = mapping_matrix[image_1d_index, pixel_1d_index] + real_part = np.dot(np.cos(phase), visibilities[:, 0]) + imag_part = np.dot(np.sin(phase), visibilities[:, 1]) + + image_1d = real_part - imag_part + + return image_1d + + +def transformed_mapping_matrix_via_preload_from( + mapping_matrix: np.ndarray, preloaded_reals: np.ndarray, preloaded_imags: np.ndarray +) -> np.ndarray: + """ + Computes the Fourier-transformed mapping matrix using preloaded sine and cosine terms for efficiency. + + This function transforms each source pixel's mapping to visibilities by using precomputed + real (cosine) and imaginary (sine) terms from the direct Fourier transform. + It is used in radio interferometric imaging where source-to-image mappings are projected + into the visibility space. + + Parameters + ---------- + mapping_matrix + The mapping matrix from image-plane pixels to source-plane pixels. + preloaded_reals + Precomputed cosine terms for each pixel-vis pair: cos(-2π(yu + xv)). + preloaded_imags + Precomputed sine terms for each pixel-vis pair: sin(-2π(yu + xv)). + + Returns + ------- + Complex-valued matrix mapping source pixels to visibilities. + """ - if value > 0: - for vis_1d_index in range(preloaded_reals.shape[1]): - vis_real = value * preloaded_reals[image_1d_index, vis_1d_index] - vis_imag = value * preloaded_imags[image_1d_index, vis_1d_index] - transfomed_mapping_matrix[vis_1d_index, pixel_1d_index] += ( - vis_real + 1j * vis_imag - ) + # Broadcasted multiplication and matrix multiplication over non-zero entries - return transfomed_mapping_matrix + vis_real = preloaded_reals.T @ mapping_matrix # (n_visibilities, n_source_pixels) + vis_imag = preloaded_imags.T @ mapping_matrix + transformed_matrix = vis_real + 1j * vis_imag + + return transformed_matrix + + +def transformed_mapping_matrix_from( + mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray +) -> np.ndarray: + """ + Computes the Fourier-transformed mapping matrix used in radio interferometric imaging. -@numba_util.jit() -def transformed_mapping_matrix_jit(mapping_matrix, grid_radians, uv_wavelengths): - transfomed_mapping_matrix = 0 + 0j * np.zeros( - (uv_wavelengths.shape[0], mapping_matrix.shape[1]) + This function applies a direct Fourier transform to each pixel column of the mapping matrix using the + uv-wavelength coordinates. The result is a matrix that maps source pixel intensities to complex visibilities, + which represent how a model image would appear to an interferometer. + + Parameters + ---------- + mapping_matrix : ndarray of shape (n_image_pixels, n_source_pixels) + The mapping matrix from image-plane pixels to source-plane pixels. + grid_radians : ndarray of shape (n_image_pixels, 2) + The (y,x) positions of each image pixel in radians. + uv_wavelengths : ndarray of shape (n_visibilities, 2) + The (u,v) coordinates of the sampled Fourier modes in units of wavelength. + + Returns + ------- + transformed_matrix : ndarray of shape (n_visibilities, n_source_pixels) + The transformed mapping matrix in the visibility domain (complex-valued). + """ + # Compute phase term: (n_image_pixels, n_visibilities) + phase = ( + -2.0 + * np.pi + * ( + np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u + + np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v + ) ) - for pixel_1d_index in range(mapping_matrix.shape[1]): - for image_1d_index in range(mapping_matrix.shape[0]): - value = mapping_matrix[image_1d_index, pixel_1d_index] - - if value > 0: - for vis_1d_index in range(uv_wavelengths.shape[0]): - vis_real = value * np.cos( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] - * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] - * uv_wavelengths[vis_1d_index, 1] - ) - ) - - vis_imag = value * np.sin( - -2.0 - * np.pi - * ( - grid_radians[image_1d_index, 1] - * uv_wavelengths[vis_1d_index, 0] - + grid_radians[image_1d_index, 0] - * uv_wavelengths[vis_1d_index, 1] - ) - ) - - transfomed_mapping_matrix[vis_1d_index, pixel_1d_index] += ( - vis_real + 1j * vis_imag - ) - - return transfomed_mapping_matrix + # Compute real and imaginary Fourier matrices + fourier_real = np.cos(phase) + fourier_imag = np.sin(phase) + + # Only compute contributions from non-zero mapping entries + # This matrix multiplication is: (n_visibilities x n_image_pixels) dot (n_image_pixels x n_source_pixels) + vis_real = fourier_real.T @ mapping_matrix # (n_vis, n_src) + vis_imag = fourier_imag.T @ mapping_matrix # (n_vis, n_src) + + transformed_matrix = vis_real + 1j * vis_imag + + return transformed_matrix diff --git a/autoarray/plot/__init__.py b/autoarray/plot/__init__.py index 71d7abb45..c45d31702 100644 --- a/autoarray/plot/__init__.py +++ b/autoarray/plot/__init__.py @@ -22,6 +22,7 @@ from autoarray.plot.wrap.two_d.array_overlay import ArrayOverlay from autoarray.plot.wrap.two_d.contour import Contour +from autoarray.plot.wrap.two_d.fill import Fill from autoarray.plot.wrap.two_d.grid_scatter import GridScatter from autoarray.plot.wrap.two_d.grid_plot import GridPlot from autoarray.plot.wrap.two_d.grid_errorbar import GridErrorbar @@ -43,12 +44,8 @@ from autoarray.plot.wrap.two_d.serial_prescan_plot import SerialPrescanPlot from autoarray.plot.wrap.two_d.serial_overscan_plot import SerialOverscanPlot -from autoarray.plot.get_visuals.one_d import GetVisuals1D -from autoarray.plot.get_visuals.two_d import GetVisuals2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D from autoarray.plot.auto_labels import AutoLabels diff --git a/autoarray/plot/abstract_plotters.py b/autoarray/plot/abstract_plotters.py index 07db07291..07ec41354 100644 --- a/autoarray/plot/abstract_plotters.py +++ b/autoarray/plot/abstract_plotters.py @@ -4,17 +4,12 @@ set_backend() -import matplotlib.pyplot as plt from typing import Optional, Tuple from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D -from autoarray.plot.get_visuals.one_d import GetVisuals1D -from autoarray.plot.get_visuals.two_d import GetVisuals2D class AbstractPlotter: @@ -22,18 +17,14 @@ def __init__( self, mat_plot_1d: MatPlot1D = None, visuals_1d: Visuals1D = None, - include_1d: Include1D = None, mat_plot_2d: MatPlot2D = None, visuals_2d: Visuals2D = None, - include_2d: Include2D = None, ): - self.visuals_1d = visuals_1d - self.include_1d = include_1d - self.mat_plot_1d = mat_plot_1d + self.visuals_1d = visuals_1d or Visuals1D() + self.mat_plot_1d = mat_plot_1d or MatPlot1D() - self.visuals_2d = visuals_2d - self.include_2d = include_2d - self.mat_plot_2d = mat_plot_2d + self.visuals_2d = visuals_2d or Visuals2D() + self.mat_plot_2d = mat_plot_2d or MatPlot2D() self.subplot_figsize = None @@ -112,6 +103,7 @@ def open_subplot_figure( If the figure is a subplot, the setup_figure function is omitted to ensure that each subplot does not create a \ new figure and so that it can be output using the *output.output_figure(structure=None)* function. """ + import matplotlib.pyplot as plt self.set_mat_plots_for_subplot( is_for_subplot=True, @@ -219,13 +211,3 @@ def subplot_of_plotters_figure(self, plotter_list, name): self.mat_plot_2d.output.subplot_to_figure(auto_filename=f"subplot_{name}") self.close_subplot_figure() - - -class Plotter(AbstractPlotter): - @property - def get_1d(self): - return GetVisuals1D(visuals=self.visuals_1d, include=self.include_1d) - - @property - def get_2d(self): - return GetVisuals2D(visuals=self.visuals_2d, include=self.include_2d) diff --git a/autoarray/plot/get_visuals/__init__.py b/autoarray/plot/get_visuals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autoarray/plot/get_visuals/abstract.py b/autoarray/plot/get_visuals/abstract.py deleted file mode 100644 index f149fcf19..000000000 --- a/autoarray/plot/get_visuals/abstract.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional, Union - -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.plot.visuals.two_d import Visuals2D - - -class AbstractGetVisuals: - def __init__( - self, include: Union[Include1D, Include2D], visuals: Union[Visuals1D, Visuals2D] - ): - """ - Class which gets attributes and adds them to a `Visuals` objects, such that they are plotted on figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include` object. If this entry is `False`, the `GetVisuals.get` method returns a None and the attribute - is omitted from the plot. - - The `GetVisuals` class adds new visuals to a pre-existing `Visuals` object that is passed to its `__init__` - method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals` class. - """ - self.include = include - self.visuals = visuals - - def get(self, name: str, value, include_name: Optional[str] = None): - """ - Get an attribute for plotting in a `Visuals1D` object based on the following criteria: - - 1) If `visuals_1d` already has a value for the attribute this is returned, over-riding the input `value` of - that attribute. - - 2) If `visuals_1d` do not contain the attribute, the input `value` is returned provided its corresponding - entry in the `Include1D` class is `True`. - - 3) If the `Include1D` entry is `False` a None is returned and the attribute is therefore not plotted. - - Parameters - ---------- - name - The name of the attribute which is to be extracted. - value - The `value` of the attribute, which is used when criteria 2 above is met. - - Returns - ------- - The collection of attributes that can be plotted by a `Plotter` object. - """ - - if include_name is None: - include_name = name - - if getattr(self.visuals, name) is not None: - return getattr(self.visuals, name) - elif getattr(self.include, include_name): - return value diff --git a/autoarray/plot/get_visuals/one_d.py b/autoarray/plot/get_visuals/one_d.py deleted file mode 100644 index 449120c2c..000000000 --- a/autoarray/plot/get_visuals/one_d.py +++ /dev/null @@ -1,54 +0,0 @@ -from autoarray.plot.get_visuals.abstract import AbstractGetVisuals -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.visuals.one_d import Visuals1D -from autoarray.structures.arrays.uniform_1d import Array1D - - -class GetVisuals1D(AbstractGetVisuals): - def __init__(self, include: Include1D, visuals: Visuals1D): - """ - Class which gets 1D attributes and adds them to a `Visuals1D` objects, such that they are plotted on 1D figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include1D` object. If this entry is `False`, the `GetVisuals1D.get` method returns a None and the attribute - is omitted from the plot. - - The `GetVisuals1D` class adds new visuals to a pre-existing `Visuals1D` object that is passed to its `__init__` - method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which 1D visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals1D` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals1D` class. - """ - super().__init__(include=include, visuals=visuals) - - def via_array_1d_from(self, array_1d: Array1D) -> Visuals1D: - """ - From an `Array1D` get its attributes that can be plotted and return them in a `Visuals1D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include1D` object are extracted - for plotting. - - From an `Array1D` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 1D array's coordinate system. - - mask: the mask of the 1D array. - - Parameters - ---------- - array - The 1D array whose attributes are extracted for plotting. - - Returns - ------- - Visuals1D - The collection of attributes that are plotted by a `Plotter` object. - """ - return self.visuals + self.visuals.__class__( - origin=self.get("origin", array_1d.origin), - mask=self.get("mask", array_1d.mask), - ) diff --git a/autoarray/plot/get_visuals/two_d.py b/autoarray/plot/get_visuals/two_d.py deleted file mode 100644 index c2b99a173..000000000 --- a/autoarray/plot/get_visuals/two_d.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import Union - -from autoarray.fit.fit_imaging import FitImaging -from autoarray.inversion.pixelization.mappers.rectangular import ( - MapperRectangular, -) -from autoarray.mask.mask_2d import Mask2D -from autoarray.plot.get_visuals.abstract import AbstractGetVisuals -from autoarray.plot.include.two_d import Include2D -from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.structures.grids.uniform_2d import Grid2D -from autoarray.structures.grids.irregular_2d import Grid2DIrregular - -from autoarray.type import Grid2DLike - - -class GetVisuals2D(AbstractGetVisuals): - def __init__(self, include: Include2D, visuals: Visuals2D): - """ - Class which gets 2D attributes and adds them to a `Visuals2D` objects, such that they are plotted on 2D figures. - - For a visual to be extracted and added for plotting, it must have a `True` value in its corresponding entry in - the `Include2D` object. If this entry is `False`, the `GetVisuals2D.get` method returns a None and the - attribute is omitted from the plot. - - The `GetVisuals2D` class adds new visuals to a pre-existing `Visuals2D` object that is passed to - its `__init__` method. This only adds a new entry if the visual are not already in this object. - - Parameters - ---------- - include - Sets which 2D visuals are included on the figure that is to be plotted (only entries which are `True` - are extracted via the `GetVisuals2D` object). - visuals - The pre-existing visuals of the plotter which new visuals are added too via the `GetVisuals2D` class. - """ - super().__init__(include=include, visuals=visuals) - - def origin_via_mask_from(self, mask: Mask2D) -> Grid2DIrregular: - """ - From a `Mask2D` get its origin for plotter, which is only extracted if an origin is not already - in `self.visuals` and with `True` entries in the `Include2D` object are extracted for plotting. - - Parameters - ---------- - mask - The 2D mask whose origin is extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object, which include the origin if it is - extracted. - """ - return self.get("origin", Grid2DIrregular(values=[mask.origin])) - - def via_mask_from(self, mask: Mask2D) -> Visuals2D: - """ - From a `Mask2D` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Mask2D` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 2D coordinate system. - - mask: the 2D mask. - - border: the border of the 2D mask, which are all of the mask's exterior edge pixels. - - Parameters - ---------- - mask - The 2D mask whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object. - """ - origin = self.origin_via_mask_from(mask=mask) - mask_visuals = self.get("mask", mask) - border = self.get("border", mask.derive_grid.border) - - return self.visuals + self.visuals.__class__( - origin=origin, mask=mask_visuals, border=border - ) - - def via_grid_from(self, grid: Grid2DLike) -> Visuals2D: - """ - From a `Grid2D` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Grid2D` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the grid's coordinate system. - - Parameters - ---------- - grid : Grid2D - The grid whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that can be plotted by a `Plotter` object. - """ - if not isinstance(grid, Grid2D): - return self.visuals - - origin = self.origin_via_mask_from(mask=grid.mask) - - return self.visuals + self.visuals.__class__(origin=origin) - - def via_mapper_for_data_from(self, mapper: MapperRectangular) -> Visuals2D: - """ - From a `Mapper` get its attributes that can be plotted in the mapper's data-plane (e.g. the reconstructed - data) and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Mapper` the following attributes can be extracted for plotting in the data-plane: - - - origin: the (y,x) origin of the `Array2D`'s coordinate system in the data plane. - - mask : the `Mask2D` defined in the data-plane containing the data that is used by the `Mapper`. - - mapper_image_plane_mesh_grid: the `Mapper`'s pixelization's mesh in the data-plane. - - mapper_border_grid: the border of the `Mapper`'s full grid in the data-plane. - - Parameters - ---------- - mapper - The mapper whose data-plane attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that can be plotted by a `Plotter` object. - """ - - visuals_via_mask = self.via_mask_from(mask=mapper.mapper_grids.mask) - - mesh_grid = self.get( - "mesh_grid", mapper.image_plane_mesh_grid, "mapper_image_plane_mesh_grid" - ) - - return ( - self.visuals - + visuals_via_mask - + self.visuals.__class__(mesh_grid=mesh_grid) - ) - - def via_mapper_for_source_from(self, mapper: MapperRectangular) -> Visuals2D: - """ - From a `Mapper` get its attributes that can be plotted in the mapper's source-plane (e.g. the reconstruction) - and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `Mapper` the following attributes can be extracted for plotting in the source-plane: - - - origin: the (y,x) origin of the coordinate system in the source plane. - - mapper_source_plane_data_grid: the (y,x) grid of coordinates in the mapper's source-plane which are paired with - the mapper's pixelization's mesh pixels. - - mapper_source_plane_mesh_grid: the `Mapper`'s pixelization's mesh grid in the source-plane. - - mapper_border_grid: the border of the `Mapper`'s full grid in the data-plane. - - Parameters - ---------- - mapper - The mapper whose source-plane attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that can be plotted by a `Plotter2D` object. - """ - - origin = self.get( - "origin", Grid2DIrregular(values=[mapper.source_plane_mesh_grid.origin]) - ) - - grid = self.get( - "grid", - mapper.source_plane_data_grid.over_sampled, - "mapper_source_plane_data_grid", - ) - - try: - border_grid = mapper.mapper_grids.source_plane_data_grid.over_sampled[ - mapper.border_relocator.sub_border_slim - ] - border = self.get("border", border_grid) - - except AttributeError: - border = None - - mesh_grid = self.get( - "mesh_grid", mapper.source_plane_mesh_grid, "mapper_source_plane_mesh_grid" - ) - - return self.visuals + self.visuals.__class__( - origin=origin, grid=grid, border=border, mesh_grid=mesh_grid - ) - - def via_fit_imaging_from(self, fit: FitImaging) -> Visuals2D: - """ - From a `FitImaging` get its attributes that can be plotted and return them in a `Visuals2D` object. - - Only attributes not already in `self.visuals` and with `True` entries in the `Include2D` object are extracted - for plotting. - - From a `FitImaging` the following attributes can be extracted for plotting: - - - origin: the (y,x) origin of the 2D coordinate system. - - mask: the 2D mask. - - border: the border of the 2D mask, which are all of the mask's exterior edge pixels. - - Parameters - ---------- - fit - The fit imaging object whose attributes are extracted for plotting. - - Returns - ------- - Visuals2D - The collection of attributes that are plotted by a `Plotter` object. - """ - return self.via_mask_from(mask=fit.mask) diff --git a/autoarray/plot/include/__init__.py b/autoarray/plot/include/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/autoarray/plot/include/abstract.py b/autoarray/plot/include/abstract.py deleted file mode 100644 index baaaa586b..000000000 --- a/autoarray/plot/include/abstract.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Optional - -from autoconf import conf - - -class AbstractInclude: - def __init__(self, origin: Optional[bool] = None, mask: Optional[bool] = None): - """ - Sets which `Visuals` are included on a figure that is plotted using a `Plotter`. - - The `Include` object is used to extract the visuals of the plotted data structure (e.g. `Array2D`, `Grid2D`) so - they can be used in plot functions. Only visuals with a `True` entry in the `Include` object are extracted and t - plotted. - - If an entry is not input into the class (e.g. it retains its default entry of `None`) then the bool is - loaded from the `config/visualize/include.ini` config file. This means the default visuals of a project - can be specified in a config file. - - Parameters - ---------- - origin - If `True`, the `origin` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - mask - if `True`, the `mask` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - """ - - self._origin = origin - self._mask = mask - - def load(self, value, name): - if value is True: - return True - elif value is False: - return False - elif value is None: - return conf.instance["visualize"]["include"][self.section][name] - - @property - def section(self): - raise NotImplementedError - - @property - def origin(self): - return self.load(value=self._origin, name="origin") - - @property - def mask(self): - return self.load(value=self._mask, name="mask") diff --git a/autoarray/plot/include/one_d.py b/autoarray/plot/include/one_d.py deleted file mode 100644 index 593471a74..000000000 --- a/autoarray/plot/include/one_d.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from autoarray.plot.include.abstract import AbstractInclude - - -class Include1D(AbstractInclude): - def __init__(self, origin: Optional[bool] = None, mask: Optional[bool] = None): - """ - Sets which `Visuals1D` are included on a figure plotting 1D data that is plotted using a `Plotter1D`. - - The `Include` object is used to extract the visuals of the plotted 1D data structures so they can be used in - plot functions. Only visuals with a `True` entry in the `Include` object are extracted and plotted. - - If an entry is not input into the class (e.g. it retains its default entry of `None`) then the bool is - loaded from the `config/visualize/include.ini` config file. This means the default visuals of a project - can be specified in a config file. - - Parameters - ---------- - origin - If `True`, the `origin` of the plotted data structure (e.g. `Line`) is included on the figure. - mask - if `True`, the `mask` of the plotted data structure (e.g. `Line`) is included on the figure. - """ - super().__init__(origin=origin, mask=mask) - - @property - def section(self): - return "include_1d" diff --git a/autoarray/plot/include/two_d.py b/autoarray/plot/include/two_d.py deleted file mode 100644 index ebf284e29..000000000 --- a/autoarray/plot/include/two_d.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Optional - -from autoarray.plot.include.abstract import AbstractInclude - - -class Include2D(AbstractInclude): - def __init__( - self, - origin: Optional[bool] = None, - mask: Optional[bool] = None, - border: Optional[bool] = None, - grid: Optional[bool] = None, - mapper_image_plane_mesh_grid: Optional[bool] = None, - mapper_source_plane_mesh_grid: Optional[bool] = None, - mapper_source_plane_data_grid: Optional[bool] = None, - parallel_overscan: Optional[bool] = None, - serial_prescan: Optional[bool] = None, - serial_overscan: Optional[bool] = None, - ): - """ - Sets which `Visuals2D` are included on a figure plotting 2D data that is plotted using a `Plotter`. - - The `Include` object is used to extract the visuals of the plotted 2D data structures so they can be used in - plot functions. Only visuals with a `True` entry in the `Include` object are extracted and plotted. - - If an entry is not input into the class (e.g. it retains its default entry of `None`) then the bool is - loaded from the `config/visualize/include.ini` config file. This means the default visuals of a project - can be specified in a config file. - - Parameters - ---------- - origin - If `True`, the `origin` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - mask - if `True`, the `mask` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - border - If `True`, the `border` of the plotted data structure (e.g. `Array2D`, `Grid2D`) is included on the figure. - mapper_image_plane_mesh_grid - If `True`, the pixelization grid in the data plane of a plotted `Mapper` is included on the figure. - mapper_source_plane_mesh_grid - If `True`, the pixelization grid in the source plane of a plotted `Mapper` is included on the figure. - parallel_overscan - If `True`, the parallel overscan of a plotted `Frame2D` is included on the figure. - serial_prescan - If `True`, the serial prescan of a plotted `Frame2D` is included on the figure. - serial_overscan - If `True`, the serial overscan of a plotted `Frame2D` is included on the figure. - """ - - super().__init__(origin=origin, mask=mask) - - self._border = border - self._grid = grid - self._mapper_image_plane_mesh_grid = mapper_image_plane_mesh_grid - self._mapper_source_plane_mesh_grid = mapper_source_plane_mesh_grid - self._mapper_source_plane_data_grid = mapper_source_plane_data_grid - self._parallel_overscan = parallel_overscan - self._serial_prescan = serial_prescan - self._serial_overscan = serial_overscan - - @property - def section(self): - return "include_2d" - - @property - def border(self): - return self.load(value=self._border, name="border") - - @property - def grid(self): - return self.load(value=self._grid, name="grid") - - @property - def mapper_image_plane_mesh_grid(self): - return self.load( - value=self._mapper_image_plane_mesh_grid, - name="mapper_image_plane_mesh_grid", - ) - - @property - def mapper_source_plane_mesh_grid(self): - return self.load( - value=self._mapper_source_plane_mesh_grid, - name="mapper_source_plane_mesh_grid", - ) - - @property - def mapper_source_plane_data_grid(self): - return self.load( - value=self._mapper_source_plane_data_grid, - name="mapper_source_plane_data_grid", - ) - - @property - def parallel_overscan(self): - return self.load(value=self._parallel_overscan, name="parallel_overscan") - - @property - def serial_prescan(self): - return self.load(value=self._serial_prescan, name="serial_prescan") - - @property - def serial_overscan(self): - return self.load(value=self._serial_overscan, name="serial_overscan") diff --git a/autoarray/plot/mat_plot/one_d.py b/autoarray/plot/mat_plot/one_d.py index 733f2ff83..79b721bba 100644 --- a/autoarray/plot/mat_plot/one_d.py +++ b/autoarray/plot/mat_plot/one_d.py @@ -166,6 +166,17 @@ def plot_yx( text_manual_dict_y=None, bypass: bool = False, ): + + try: + y = y.array + except AttributeError: + pass + + try: + x = x.array + except AttributeError: + pass + if (y is None) or np.count_nonzero(y) == 0 or np.isnan(y).all(): return @@ -185,7 +196,7 @@ def plot_yx( x = np.arange(len(y)) use_integers = True pixel_scales = (x[1] - x[0],) - x = Array1D.no_mask(values=x, pixel_scales=pixel_scales) + x = Array1D.no_mask(values=x, pixel_scales=pixel_scales).array if self.yx_plot.plot_axis_type is None: plot_axis_type = "linear" diff --git a/autoarray/plot/mat_plot/two_d.py b/autoarray/plot/mat_plot/two_d.py index 9fcec3657..b85caba1f 100644 --- a/autoarray/plot/mat_plot/two_d.py +++ b/autoarray/plot/mat_plot/two_d.py @@ -1,6 +1,6 @@ import matplotlib.pyplot as plt import numpy as np -from typing import Optional, List, Tuple, Union +from typing import Optional, List, Union from autoconf import conf @@ -9,10 +9,12 @@ ) from autoarray.inversion.pixelization.mappers.delaunay import MapperDelaunay from autoarray.inversion.pixelization.mappers.voronoi import MapperVoronoi +from autoarray.mask.derive.zoom_2d import Zoom2D from autoarray.plot.mat_plot.abstract import AbstractMatPlot from autoarray.plot.auto_labels import AutoLabels from autoarray.plot.visuals.two_d import Visuals2D from autoarray.structures.arrays.uniform_2d import Array2D +from autoarray.structures.arrays.rgb import Array2DRGB from autoarray.structures.arrays import array_2d_util @@ -41,6 +43,7 @@ def __init__( legend: Optional[wb.Legend] = None, output: Optional[wb.Output] = None, array_overlay: Optional[w2d.ArrayOverlay] = None, + fill: Optional[w2d.Fill] = None, contour: Optional[w2d.Contour] = None, grid_scatter: Optional[w2d.GridScatter] = None, grid_plot: Optional[w2d.GridPlot] = None, @@ -61,6 +64,7 @@ def __init__( serial_prescan_plot: Optional[w2d.SerialPrescanPlot] = None, serial_overscan_plot: Optional[w2d.SerialOverscanPlot] = None, use_log10: bool = False, + plot_mask: bool = True, ): """ Visualizes 2D data structures (e.g an `Array2D`, `Grid2D`, `VectorField`, etc.) using Matplotlib. @@ -119,6 +123,8 @@ def __init__( Sets if the figure is displayed on the user's screen or output to `.png` using `plt.show` and `plt.savefig` array_overlay Overlays an input `Array2D` over the figure using `plt.imshow`. + fill + Sets the fill of the figure using `plt.fill` and customizes its appearance, such as the color and alpha. contour Overlays contours of an input `Array2D` over the figure using `plt.contour`. grid_scatter @@ -177,6 +183,7 @@ def __init__( ) self.array_overlay = array_overlay or w2d.ArrayOverlay(is_default=True) + self.fill = fill or w2d.Fill(is_default=True) self.contour = contour or w2d.Contour(is_default=True) @@ -217,55 +224,10 @@ def __init__( ) self.use_log10 = use_log10 + self.plot_mask = plot_mask self.is_for_subplot = False - def zoomed_array_and_extent_from(self, array) -> Tuple[np.ndarray, Tuple]: - """ - Returns the array and extent of the array, zoomed around the mask of the array, if the config file is set to - do this. - - Many plots zoom in around the mask of an array, to emphasize the signal of the data and not waste - plotting space on empty pixels. This function computes the zoomed array and extent of this array, given the - array. - - If the mask is all false, the array is returned without zooming by disabling the buffer. - - Parameters - ---------- - array - - Returns - ------- - - """ - - if array.mask.is_all_false: - buffer = 0 - else: - buffer = 1 - - zoom_around_mask = conf.instance["visualize"]["general"]["general"][ - "zoom_around_mask" - ] - - if ( - self.output.format == "fits" - and conf.instance["visualize"]["general"]["general"][ - "disable_zoom_for_fits" - ] - ): - zoom_around_mask = False - - if zoom_around_mask: - extent = array.extent_of_zoomed_array(buffer=buffer) - array = array.zoomed_around_mask(buffer=buffer) - - else: - extent = array.geometry.extent - - return array, extent - def plot_array( self, array: Array2D, @@ -301,7 +263,15 @@ def plot_array( "a pixel scales attribute." ) - array, extent = self.zoomed_array_and_extent_from(array=array) + if conf.instance["visualize"]["general"]["general"]["zoom_around_mask"]: + + zoom = Zoom2D(mask=array.mask) + + buffer = 0 if array.mask.is_all_false else 1 + + array = zoom.array_2d_from(array=array, buffer=buffer) + + extent = array.geometry.extent ax = None @@ -313,18 +283,29 @@ def plot_array( aspect = self.figure.aspect_from(shape_native=array.shape_native) - norm = self.cmap.norm_from(array=array, use_log10=self.use_log10) + norm = self.cmap.norm_from(array=array.array, use_log10=self.use_log10) origin = conf.instance["visualize"]["general"]["general"]["imshow_origin"] - plt.imshow( - X=array.native.array, - aspect=aspect, - cmap=self.cmap.cmap, - norm=norm, - extent=extent, - origin=origin, - ) + if isinstance(array, Array2DRGB): + + plt.imshow( + X=array.native.array, + aspect=aspect, + extent=extent, + origin=origin, + ) + + else: + + plt.imshow( + X=array.native.array, + aspect=aspect, + cmap=self.cmap.cmap, + norm=norm, + extent=extent, + origin=origin, + ) if visuals_2d.array_overlay is not None: self.array_overlay.overlay_array( @@ -354,7 +335,12 @@ def plot_array( pixels=array.shape_native[1], ) - self.title.set(auto_title=auto_labels.title, use_log10=self.use_log10) + if isinstance(array, Array2DRGB): + title = "RGB" + else: + title = auto_labels.title + + self.title.set(auto_title=title, use_log10=self.use_log10) self.ylabel.set() self.xlabel.set() @@ -384,9 +370,13 @@ def plot_array( except ValueError: pass - visuals_2d.plot_via_plotter( - plotter=self, grid_indexes=grid_indexes, geometry=array.geometry - ) + if self.plot_mask and visuals_2d.mask is None: + + if not array.mask.is_all_false: + + self.mask_scatter.scatter_grid(grid=array.mask.derive_grid.edge.array) + + visuals_2d.plot_via_plotter(plotter=self, grid_indexes=grid_indexes) if not self.is_for_subplot and not bypass: self.output.to_figure(structure=array, auto_filename=auto_labels.filename) @@ -427,10 +417,10 @@ def plot_grid( if color_array is None: if y_errors is None and x_errors is None: - self.grid_scatter.scatter_grid(grid=grid_plot) + self.grid_scatter.scatter_grid(grid=grid_plot.array) else: self.grid_errorbar.errorbar_grid( - grid=grid_plot, y_errors=y_errors, x_errors=x_errors + grid=grid_plot.array, y_errors=y_errors, x_errors=x_errors ) elif color_array is not None: @@ -438,11 +428,11 @@ def plot_grid( if y_errors is None and x_errors is None: self.grid_scatter.scatter_grid_colored( - grid=grid, color_array=color_array, cmap=cmap + grid=grid.array, color_array=color_array, cmap=cmap ) else: self.grid_errorbar.errorbar_grid_colored( - grid=grid, + grid=grid.array, cmap=cmap, color_array=color_array, y_errors=y_errors, @@ -450,6 +440,7 @@ def plot_grid( ) if self.colorbar is not None: + colorbar = self.colorbar.set_with_color_values( units=self.units, cmap=self.cmap.cmap, @@ -495,9 +486,7 @@ def plot_grid( if self.contour is not False: self.contour.set(array=color_array, extent=extent, use_log10=self.use_log10) - visuals_2d.plot_via_plotter( - plotter=self, grid_indexes=grid, geometry=grid.geometry - ) + visuals_2d.plot_via_plotter(plotter=self, grid_indexes=grid.array) if not self.is_for_subplot: self.output.to_figure(structure=grid, auto_filename=auto_labels.filename) @@ -575,18 +564,79 @@ def _plot_rectangular_mapper( else: ax = self.setup_subplot(aspect=aspect_inv) + shape_native = mapper.source_plane_mesh_grid.shape_native + if pixel_values is not None: - self.plot_array( - array=pixel_values, - visuals_2d=visuals_2d, - auto_labels=auto_labels, - bypass=True, + + from autoarray.inversion.pixelization.mappers.rectangular_uniform import ( + MapperRectangularUniform, + ) + from autoarray.inversion.pixelization.mappers.rectangular import ( + MapperRectangular, ) - self.axis.set(extent=extent, grid=mapper.source_plane_mesh_grid) + if isinstance(mapper, MapperRectangularUniform): - self.yticks.set(min_value=extent[2], max_value=extent[3], units=self.units) - self.xticks.set(min_value=extent[0], max_value=extent[1], units=self.units) + self.plot_array( + array=pixel_values, + visuals_2d=visuals_2d, + auto_labels=auto_labels, + bypass=True, + ) + + else: + + norm = self.cmap.norm_from( + array=pixel_values.array, use_log10=self.use_log10 + ) + + edges_transformed = mapper.edges_transformed + + edges_transformed_dense = np.moveaxis( + np.stack(np.meshgrid(*edges_transformed.T)), 0, 2 + ) + + plt.pcolormesh( + edges_transformed_dense[..., 0], + edges_transformed_dense[..., 1], + pixel_values.array.reshape(shape_native), + shading="flat", + norm=norm, + cmap=self.cmap.cmap, + ) + + if self.colorbar is not False: + + cb = self.colorbar.set( + units=self.units, + ax=ax, + norm=norm, + cb_unit=auto_labels.cb_unit, + use_log10=self.use_log10, + ) + self.colorbar_tickparams.set(cb=cb) + + extent_axis = self.axis.config_dict.get("extent") + + if extent_axis is None: + extent_axis = extent + + self.axis.set(extent=extent_axis) + + self.tickparams.set() + self.yticks.set( + min_value=extent_axis[2], + max_value=extent_axis[3], + units=self.units, + pixels=shape_native[0], + ) + + self.xticks.set( + min_value=extent_axis[0], + max_value=extent_axis[1], + units=self.units, + pixels=shape_native[1], + ) if not isinstance(self.text, list): self.text.set() @@ -598,21 +648,17 @@ def _plot_rectangular_mapper( else: [annotate.set() for annotate in self.annotate] - self.grid_plot.plot_rectangular_grid_lines( - extent=mapper.source_plane_mesh_grid.geometry.extent, - shape_native=mapper.shape_native, - ) + # self.grid_plot.plot_rectangular_grid_lines( + # extent=mapper.source_plane_mesh_grid.geometry.extent, + # shape_native=mapper.shape_native, + # ) self.title.set(auto_title=auto_labels.title) - self.tickparams.set() self.ylabel.set() self.xlabel.set() visuals_2d.plot_via_plotter( - plotter=self, - grid_indexes=mapper.source_plane_data_grid.over_sampled, - mapper=mapper, - geometry=mapper.mapper_grids.mask.geometry, + plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled ) if not self.is_for_subplot: @@ -693,10 +739,7 @@ def _plot_delaunay_mapper( self.xlabel.set() visuals_2d.plot_via_plotter( - plotter=self, - grid_indexes=mapper.source_plane_data_grid.over_sampled, - mapper=mapper, - geometry=mapper.mapper_grids.mask.geometry, + plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled ) if not self.is_for_subplot: @@ -776,10 +819,7 @@ def _plot_voronoi_mapper( self.xlabel.set() visuals_2d.plot_via_plotter( - plotter=self, - grid_indexes=mapper.source_plane_data_grid.over_sampled, - mapper=mapper, - geometry=mapper.mapper_grids.mask.geometry, + plotter=self, grid_indexes=mapper.source_plane_data_grid.over_sampled ) if pixel_values is not None: diff --git a/autoarray/plot/visuals/abstract.py b/autoarray/plot/visuals/abstract.py index b9d4b0e26..35583e985 100644 --- a/autoarray/plot/visuals/abstract.py +++ b/autoarray/plot/visuals/abstract.py @@ -12,25 +12,19 @@ def __add__(self, other): mask = Mask2D.circular(shape_native=(100, 100), pixel_scales=0.1, radius=3.0) array = Array2D.ones(shape_native=(100, 100), pixel_scales=0.1) masked_array = al.Array2D(values=array, mask=mask) - include_2d = Include2D(mask=True) - array_plotter = aplt.Array2DPlotter(array=masked_array, include_2d=include_2d) + array_plotter = aplt.Array2DPlotter(array=masked_array) array_plotter.figure() - Because `mask=True` in `Include2D` the function `figure` extracts the `Mask2D` from the `masked_array` - and plots it. It does this by creating a new `Visuals2D` object. - If the user did not manually input a `Visuals2D` object, the one created in `function_array` is the one used to plot the image However, if the user specifies their own `Visuals2D` object and passed it to the plotter, e.g.: visuals_2d = Visuals2D(origin=(0.0, 0.0)) - include_2d = Include2D(mask=True) - array_plotter = aplt.Array2DPlotter(array=masked_array, include_2d=include_2d) + array_plotter = aplt.Array2DPlotter(array=masked_array) - We now wish for the `Plotter` to plot the `origin` in the user's input `Visuals2D` object and the `Mask2d` - extracted via the `Include2D`. To achieve this, two `Visuals2D` objects are created: (i) the user's input - instance (with an origin) and; (ii) the one created by the `Include2D` object (with a mask). + We now wish for the `Plotter` to plot the `origin` in the user's input `Visuals2D` object. To achieve this, + one `Visuals2D` object is created: (i) the user's input instance (with an origin). This `__add__` override means we can add the two together to make the final `Visuals2D` object that is plotted on the figure containing both the `origin` and `Mask2D`.: diff --git a/autoarray/plot/visuals/one_d.py b/autoarray/plot/visuals/one_d.py index 8e3e33584..b84a832b3 100644 --- a/autoarray/plot/visuals/one_d.py +++ b/autoarray/plot/visuals/one_d.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union from autoarray.mask.mask_1d import Mask1D -from autoarray.plot.include.one_d import Include1D from autoarray.plot.visuals.abstract import AbstractVisuals from autoarray.structures.arrays.uniform_1d import Array1D from autoarray.structures.grids.uniform_1d import Grid1D @@ -23,10 +22,6 @@ def __init__( self.vertical_line = vertical_line self.shaded_region = shaded_region - @property - def include(self): - return Include1D() - def plot_via_plotter(self, plotter): if self.points is not None: plotter.yx_scatter.scatter_yx(y=self.points, x=np.arange(len(self.points))) diff --git a/autoarray/plot/visuals/two_d.py b/autoarray/plot/visuals/two_d.py index d1da534ff..35242db57 100644 --- a/autoarray/plot/visuals/two_d.py +++ b/autoarray/plot/visuals/two_d.py @@ -22,13 +22,12 @@ def __init__( mesh_grid: Optional[Grid2D] = None, vectors: Optional[VectorYX2DIrregular] = None, patches: Optional[List[ptch.Patch]] = None, + fill_region: Optional[List] = None, array_overlay: Optional[Array2D] = None, parallel_overscan=None, serial_prescan=None, serial_overscan=None, indexes=None, - pix_indexes=None, - indexes_via_scatter=False, ): self.origin = origin self.mask = mask @@ -39,34 +38,43 @@ def __init__( self.mesh_grid = mesh_grid self.vectors = vectors self.patches = patches + self.fill_region = fill_region self.array_overlay = array_overlay self.parallel_overscan = parallel_overscan self.serial_prescan = serial_prescan self.serial_overscan = serial_overscan self.indexes = indexes - self.pix_indexes = pix_indexes - self.indexes_via_scatter = indexes_via_scatter - def plot_via_plotter(self, plotter, grid_indexes=None, mapper=None, geometry=None): + def plot_via_plotter(self, plotter, grid_indexes=None): + + if self.mask is not None: + plotter.mask_scatter.scatter_grid(grid=self.mask.derive_grid.edge.array) + if self.origin is not None: plotter.origin_scatter.scatter_grid( - grid=Grid2DIrregular(values=self.origin) + grid=Grid2DIrregular(values=self.origin).array ) - if self.mask is not None: - plotter.mask_scatter.scatter_grid(grid=self.mask.derive_grid.edge) - if self.border is not None: - plotter.border_scatter.scatter_grid(grid=self.border) + try: + plotter.border_scatter.scatter_grid(grid=self.border.array) + except AttributeError: + plotter.border_scatter.scatter_grid(grid=self.border) if self.grid is not None: - plotter.grid_scatter.scatter_grid(grid=self.grid) + try: + plotter.grid_scatter.scatter_grid(grid=self.grid.array) + except AttributeError: + plotter.grid_scatter.scatter_grid(grid=self.grid) if self.mesh_grid is not None: - plotter.mesh_grid_scatter.scatter_grid(grid=self.mesh_grid) + plotter.mesh_grid_scatter.scatter_grid(grid=self.mesh_grid.array) if self.positions is not None: - plotter.positions_scatter.scatter_grid(grid=self.positions) + try: + plotter.positions_scatter.scatter_grid(grid=self.positions.array) + except (AttributeError, ValueError): + plotter.positions_scatter.scatter_grid(grid=self.positions) if self.vectors is not None: plotter.vector_yx_quiver.quiver_vectors(vectors=self.vectors) @@ -74,32 +82,15 @@ def plot_via_plotter(self, plotter, grid_indexes=None, mapper=None, geometry=Non if self.patches is not None: plotter.patch_overlay.overlay_patches(patches=self.patches) + if self.fill_region is not None: + plotter.fill.plot_fill(fill_region=self.fill_region) + if self.lines is not None: plotter.grid_plot.plot_grid(grid=self.lines) if self.indexes is not None and grid_indexes is not None: - if not self.indexes_via_scatter: - plotter.index_plot.plot_grid_indexes_multi( - grid=grid_indexes, indexes=self.indexes, geometry=geometry - ) - - else: - plotter.index_scatter.scatter_grid_indexes( - grid=grid_indexes, - indexes=self.indexes, - ) - - if self.pix_indexes is not None and mapper is not None: - indexes = mapper.pix_indexes_for_slim_indexes(pix_indexes=self.pix_indexes) - - if not self.indexes_via_scatter: - plotter.index_plot.plot_grid_indexes_x1( - grid=grid_indexes, - indexes=indexes, - ) - - else: - plotter.index_scatter.scatter_grid_indexes( - grid=mapper.source_plane_data_grid.over_sampled, - indexes=indexes, - ) + + plotter.index_scatter.scatter_grid_indexes( + grid=grid_indexes, + indexes=self.indexes, + ) diff --git a/autoarray/plot/wrap/base/abstract.py b/autoarray/plot/wrap/base/abstract.py index 82ffd1fff..c246c1afa 100644 --- a/autoarray/plot/wrap/base/abstract.py +++ b/autoarray/plot/wrap/base/abstract.py @@ -1,5 +1,4 @@ import numpy as np -import matplotlib from autoconf import conf @@ -16,6 +15,8 @@ def set_backend(): It is also common for high perforamcne computers (HPCs) to not support visualization and raise an error when a graphical backend (e.g. TKAgg) is used. Setting the backend to `Agg` addresses this. """ + import matplotlib + backend = conf.get_matplotlib_backend() if backend not in "default": diff --git a/autoarray/plot/wrap/base/annotate.py b/autoarray/plot/wrap/base/annotate.py index 3ba0b60d9..e1f1e917f 100644 --- a/autoarray/plot/wrap/base/annotate.py +++ b/autoarray/plot/wrap/base/annotate.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt - from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -13,6 +11,9 @@ class Annotate(AbstractMatWrap): """ def set(self): + + import matplotlib.pyplot as plt + if "x" not in self.kwargs and "y" not in self.kwargs and "s" not in self.kwargs: return diff --git a/autoarray/plot/wrap/base/axis.py b/autoarray/plot/wrap/base/axis.py index 69d34d933..cd57c5d20 100644 --- a/autoarray/plot/wrap/base/axis.py +++ b/autoarray/plot/wrap/base/axis.py @@ -1,4 +1,3 @@ -import matplotlib.pyplot as plt import numpy as np from typing import List @@ -34,6 +33,7 @@ def set(self, extent: List[float] = None, grid=None): The extent of the figure which set the axis-limits on the figure the grid is plotted, following the format [xmin, xmax, ymin, ymax]. """ + import matplotlib.pyplot as plt config_dict = self.config_dict extent_dict = config_dict.get("extent") diff --git a/autoarray/plot/wrap/base/cmap.py b/autoarray/plot/wrap/base/cmap.py index 5dc2f514b..51144ba46 100644 --- a/autoarray/plot/wrap/base/cmap.py +++ b/autoarray/plot/wrap/base/cmap.py @@ -1,10 +1,7 @@ import copy import logging -from matplotlib.colors import LinearSegmentedColormap -import matplotlib.colors as colors import numpy as np -from autoconf import conf from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -60,6 +57,7 @@ def norm_from(self, array: np.ndarray, use_log10: bool = False) -> object: array The array of data which is to be plotted. """ + import matplotlib.colors as colors vmin = self.vmin_from(array=array, use_log10=use_log10) vmax = self.vmax_from(array=array, use_log10=use_log10) @@ -99,6 +97,8 @@ def norm_from(self, array: np.ndarray, use_log10: bool = False) -> object: @property def cmap(self): + from matplotlib.colors import LinearSegmentedColormap + if self.config_dict["cmap"] == "default": from autoarray.plot.wrap.segmentdata import segmentdata diff --git a/autoarray/plot/wrap/base/colorbar.py b/autoarray/plot/wrap/base/colorbar.py index b0650013a..71a8a174e 100644 --- a/autoarray/plot/wrap/base/colorbar.py +++ b/autoarray/plot/wrap/base/colorbar.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt -import matplotlib.cm as cm import numpy as np from typing import List, Optional @@ -142,6 +140,7 @@ def set( """ Set the figure's colorbar, optionally overriding the tick labels and values with manual inputs. """ + import matplotlib.pyplot as plt tick_values = self.tick_values_from(norm=norm, use_log10=use_log10) tick_labels = self.tick_labels_from( @@ -183,6 +182,8 @@ def set_with_color_values( color_values The values of the pixels on the Voronoi mesh which are used to create the colorbar. """ + import matplotlib.pyplot as plt + import matplotlib.cm as cm mappable = cm.ScalarMappable(norm=norm, cmap=cmap) mappable.set_array(color_values) @@ -194,6 +195,7 @@ def set_with_color_values( ) if tick_values is None and tick_labels is None: + cb = plt.colorbar( mappable=mappable, ax=ax, diff --git a/autoarray/plot/wrap/base/figure.py b/autoarray/plot/wrap/base/figure.py index 238264fd9..61a3347b6 100644 --- a/autoarray/plot/wrap/base/figure.py +++ b/autoarray/plot/wrap/base/figure.py @@ -1,6 +1,5 @@ from enum import Enum import gc -import matplotlib.pyplot as plt from typing import Union, Tuple from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -82,6 +81,8 @@ def open(self): """ Wraps the Matplotlib method 'plt.figure' for opening a figure. """ + import matplotlib.pyplot as plt + if not plt.fignum_exists(num=1): config_dict = self.config_dict config_dict.pop("aspect") @@ -93,5 +94,7 @@ def close(self): """ Wraps the Matplotlib method 'plt.close' for closing a figure. """ + import matplotlib.pyplot as plt + plt.close() gc.collect() diff --git a/autoarray/plot/wrap/base/label.py b/autoarray/plot/wrap/base/label.py index 377f9a6b9..a54d5c474 100644 --- a/autoarray/plot/wrap/base/label.py +++ b/autoarray/plot/wrap/base/label.py @@ -1,10 +1,6 @@ -import matplotlib.pyplot as plt from typing import Optional -from autoconf import conf - from autoarray.plot.wrap.base.abstract import AbstractMatWrap -from autoarray.plot.wrap.base.units import Units class AbstractLabel(AbstractMatWrap): @@ -48,6 +44,7 @@ def set( units The units of the image that is plotted which informs the appropriate y label text. """ + import matplotlib.pyplot as plt config_dict = self.config_dict @@ -77,6 +74,7 @@ def set( units The units of the image that is plotted which informs the appropriate x label text. """ + import matplotlib.pyplot as plt config_dict = self.config_dict diff --git a/autoarray/plot/wrap/base/legend.py b/autoarray/plot/wrap/base/legend.py index 04b64441b..09a6e9d4d 100644 --- a/autoarray/plot/wrap/base/legend.py +++ b/autoarray/plot/wrap/base/legend.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt - from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -19,6 +17,9 @@ def __init__(self, label=None, include=True, **kwargs): self.include = include def set(self): + + import matplotlib.pyplot as plt + if self.include: config_dict = self.config_dict config_dict.pop("include") if "include" in config_dict else None diff --git a/autoarray/plot/wrap/base/output.py b/autoarray/plot/wrap/base/output.py index 73a5f40e4..a85c2ebb6 100644 --- a/autoarray/plot/wrap/base/output.py +++ b/autoarray/plot/wrap/base/output.py @@ -1,5 +1,4 @@ import logging -import matplotlib.pyplot as plt from os import path import os from typing import Union, List, Optional @@ -102,11 +101,14 @@ def filename_from(self, auto_filename): return filename def savefig(self, filename: str, output_path: str, format: str): + + import matplotlib.pyplot as plt + try: plt.savefig( path.join(output_path, f"{filename}.{format}"), bbox_inches=self.bbox_inches, - pad_inches=0, + pad_inches=0.1, ) except ValueError as e: logger.info( @@ -130,6 +132,7 @@ def to_figure( auto_filename If the filename is not manually specified this name is used instead, which is defined in the parent plotter. """ + import matplotlib.pyplot as plt filename = self.filename_from(auto_filename=auto_filename) @@ -163,6 +166,7 @@ def subplot_to_figure(self, auto_filename: Optional[str] = None): auto_filename If the filename is not manually specified this name is used instead, which is defined in the parent plotter. """ + import matplotlib.pyplot as plt filename = self.filename_from(auto_filename=auto_filename) diff --git a/autoarray/plot/wrap/base/text.py b/autoarray/plot/wrap/base/text.py index 4dd94c0d6..4141bc0ac 100644 --- a/autoarray/plot/wrap/base/text.py +++ b/autoarray/plot/wrap/base/text.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt - from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -13,6 +11,9 @@ class Text(AbstractMatWrap): """ def set(self): + + import matplotlib.pyplot as plt + if "x" not in self.kwargs and "y" not in self.kwargs and "s" not in self.kwargs: return diff --git a/autoarray/plot/wrap/base/tickparams.py b/autoarray/plot/wrap/base/tickparams.py index ca96ad792..7369a29a0 100644 --- a/autoarray/plot/wrap/base/tickparams.py +++ b/autoarray/plot/wrap/base/tickparams.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt - from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -14,4 +12,7 @@ class TickParams(AbstractMatWrap): def set(self): """Set the tick_params of the figure using the method `plt.tick_params`.""" + + import matplotlib.pyplot as plt + plt.tick_params(**self.config_dict) diff --git a/autoarray/plot/wrap/base/ticks.py b/autoarray/plot/wrap/base/ticks.py index 95b376e8f..3b7c72e57 100644 --- a/autoarray/plot/wrap/base/ticks.py +++ b/autoarray/plot/wrap/base/ticks.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt -from matplotlib.ticker import FormatStrFormatter import numpy as np from typing import List, Tuple, Optional @@ -371,6 +369,8 @@ def set( units The units of the figure. """ + import matplotlib.pyplot as plt + from matplotlib.ticker import FormatStrFormatter if self.manual_min_max_value: min_value = self.manual_min_max_value[0] @@ -426,6 +426,8 @@ def set( units The units of the figure. """ + import matplotlib.pyplot as plt + from matplotlib.ticker import FormatStrFormatter if self.manual_min_max_value: min_value = self.manual_min_max_value[0] diff --git a/autoarray/plot/wrap/base/title.py b/autoarray/plot/wrap/base/title.py index 8185a7184..60aec1d30 100644 --- a/autoarray/plot/wrap/base/title.py +++ b/autoarray/plot/wrap/base/title.py @@ -1,5 +1,3 @@ -import matplotlib.pyplot as plt - from autoarray.plot.wrap.base.abstract import AbstractMatWrap @@ -29,6 +27,9 @@ def __init__(self, prefix: str = None, disable_log10_label: bool = False, **kwar self.manual_label = self.kwargs.get("label") def set(self, auto_title=None, use_log10: bool = False): + + import matplotlib.pyplot as plt + config_dict = self.config_dict label = auto_title if self.manual_label is None else self.manual_label diff --git a/autoarray/plot/wrap/base/units.py b/autoarray/plot/wrap/base/units.py index 7bf0f1d45..db1048b32 100644 --- a/autoarray/plot/wrap/base/units.py +++ b/autoarray/plot/wrap/base/units.py @@ -15,7 +15,7 @@ def __init__( ticks_label: Optional[str] = None, colorbar_convert_factor: Optional[float] = None, colorbar_label: Optional[str] = None, - **kwargs + **kwargs, ): """ This object controls the units of a plotted figure, and performs multiple tasks when making the plot: diff --git a/autoarray/plot/wrap/one_d/avxline.py b/autoarray/plot/wrap/one_d/avxline.py index 496d644c7..7d36672d8 100644 --- a/autoarray/plot/wrap/one_d/avxline.py +++ b/autoarray/plot/wrap/one_d/avxline.py @@ -1,4 +1,3 @@ -import matplotlib.pyplot as plt from typing import List, Optional from autoarray.plot.wrap.one_d.abstract import AbstractMatWrap1D @@ -42,6 +41,7 @@ def axvline_vertical_line( label Labels for each vertical line used by a `Legend`. """ + import matplotlib.pyplot as plt if vertical_line is [] or vertical_line is None: return diff --git a/autoarray/plot/wrap/one_d/fill_between.py b/autoarray/plot/wrap/one_d/fill_between.py index 4f446252b..8a91b9a73 100644 --- a/autoarray/plot/wrap/one_d/fill_between.py +++ b/autoarray/plot/wrap/one_d/fill_between.py @@ -1,4 +1,3 @@ -import matplotlib.pyplot as plt import numpy as np from typing import List, Union @@ -44,6 +43,7 @@ def fill_between_shaded_regions( y1 The second line of ydata that defines the region that is filled in. """ + import matplotlib.pyplot as plt config_dict = self.config_dict diff --git a/autoarray/plot/wrap/one_d/yx_plot.py b/autoarray/plot/wrap/one_d/yx_plot.py index 15d894b72..d679b91e3 100644 --- a/autoarray/plot/wrap/one_d/yx_plot.py +++ b/autoarray/plot/wrap/one_d/yx_plot.py @@ -1,4 +1,3 @@ -import matplotlib.pyplot as plt import numpy as np from typing import Union @@ -51,6 +50,7 @@ def plot_y_vs_x( label Optionally include a label on the plot for a `Legend` to display. """ + import matplotlib.pyplot as plt if self.label is not None: label = self.label @@ -72,7 +72,7 @@ def plot_y_vs_x( # marker="o", fmt="o", # ls=ls_errorbar, - **self.config_dict + **self.config_dict, ) if plot_axis_type == "errorbar_logy": plt.yscale("log") diff --git a/autoarray/plot/wrap/one_d/yx_scatter.py b/autoarray/plot/wrap/one_d/yx_scatter.py index ab77a0ea1..5e3c1d93a 100644 --- a/autoarray/plot/wrap/one_d/yx_scatter.py +++ b/autoarray/plot/wrap/one_d/yx_scatter.py @@ -1,4 +1,3 @@ -import matplotlib.pyplot as plt import numpy as np from typing import Union @@ -29,6 +28,7 @@ def scatter_yx(self, y: Union[np.ndarray, Grid1D], x: list): errors The error on every point of the grid that is plotted. """ + import matplotlib.pyplot as plt config_dict = self.config_dict diff --git a/autoarray/plot/wrap/two_d/__init__.py b/autoarray/plot/wrap/two_d/__init__.py index 5eb85eeab..5b438f4f8 100644 --- a/autoarray/plot/wrap/two_d/__init__.py +++ b/autoarray/plot/wrap/two_d/__init__.py @@ -1,5 +1,6 @@ from .array_overlay import ArrayOverlay from .contour import Contour +from .fill import Fill from .grid_scatter import GridScatter from .grid_plot import GridPlot from .grid_errorbar import GridErrorbar diff --git a/autoarray/plot/wrap/two_d/array_overlay.py b/autoarray/plot/wrap/two_d/array_overlay.py index 5de20b879..372bb5f6c 100644 --- a/autoarray/plot/wrap/two_d/array_overlay.py +++ b/autoarray/plot/wrap/two_d/array_overlay.py @@ -1,6 +1,7 @@ import matplotlib.pyplot as plt from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D +from autoarray.mask.derive.zoom_2d import Zoom2D class ArrayOverlay(AbstractMatWrap2D): @@ -17,8 +18,9 @@ class ArrayOverlay(AbstractMatWrap2D): def overlay_array(self, array, figure): aspect = figure.aspect_from(shape_native=array.shape_native) - extent = array.extent_of_zoomed_array(buffer=0) - print(type(array)) + zoom = Zoom2D(mask=array.mask) + array_zoom = zoom.array_2d_from(array=array, buffer=0) + extent = array_zoom.geometry.extent - plt.imshow(X=array.native._array, aspect=aspect, extent=extent, **self.config_dict) + plt.imshow(X=array.native, aspect=aspect, extent=extent, **self.config_dict) diff --git a/autoarray/plot/wrap/two_d/delaunay_drawer.py b/autoarray/plot/wrap/two_d/delaunay_drawer.py index 5d168213b..40e10a778 100644 --- a/autoarray/plot/wrap/two_d/delaunay_drawer.py +++ b/autoarray/plot/wrap/two_d/delaunay_drawer.py @@ -112,5 +112,5 @@ def draw_delaunay_pixels( cmap=cmap, vmin=vmin, vmax=vmax, - **self.config_dict + **self.config_dict, ) diff --git a/autoarray/plot/wrap/two_d/fill.py b/autoarray/plot/wrap/two_d/fill.py new file mode 100644 index 000000000..f580dde54 --- /dev/null +++ b/autoarray/plot/wrap/two_d/fill.py @@ -0,0 +1,38 @@ +import logging + +import matplotlib.pyplot as plt + +from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D + + +logger = logging.getLogger(__name__) + + +class Fill(AbstractMatWrap2D): + def __init__(self, **kwargs): + """ + The settings used to customize plots using fill on a figure + + This object wraps the following Matplotlib methods: + + - plt.fill https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.fill.html + + Parameters + ---------- + symmetric + If True, the colormap normalization (e.g. `vmin` and `vmax`) span the same absolute values producing a + symmetric color bar. + """ + + super().__init__(**kwargs) + + def plot_fill(self, fill_region): + + try: + y_fill = fill_region[:, 0] + x_fill = fill_region[:, 1] + except TypeError: + y_fill = fill_region[0] + x_fill = fill_region[1] + + plt.fill(x_fill, y_fill, **self.config_dict) diff --git a/autoarray/plot/wrap/two_d/grid_plot.py b/autoarray/plot/wrap/two_d/grid_plot.py index a727bec30..cce6ea336 100644 --- a/autoarray/plot/wrap/two_d/grid_plot.py +++ b/autoarray/plot/wrap/two_d/grid_plot.py @@ -1,4 +1,3 @@ -import matplotlib.pyplot as plt import numpy as np import itertools from typing import List, Union, Tuple @@ -43,6 +42,7 @@ def plot_rectangular_grid_lines( shape_native The 2D shape of the mask the array is paired with. """ + import matplotlib.pyplot as plt ys = np.linspace(extent[2], extent[3], shape_native[1] + 1) xs = np.linspace(extent[0], extent[1], shape_native[0] + 1) @@ -66,6 +66,7 @@ def plot_grid(self, grid: Union[np.ndarray, Grid2D]): grid The grid of (y,x) coordinates that is plotted. """ + import matplotlib.pyplot as plt try: color = self.config_dict["c"] @@ -94,6 +95,7 @@ def plot_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]): grid_list The list of grids of (y,x) coordinates that are plotted. """ + import matplotlib.pyplot as plt if len(grid_list) == 0: return None @@ -104,62 +106,11 @@ def plot_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]): try: for grid in grid_list: - plt.plot(grid[:, 1], grid[:, 0], c=next(color), **config_dict) + try: + plt.plot(grid[:, 1], grid[:, 0], c=next(color), **config_dict) + except ValueError: + plt.plot( + grid.array[:, 1], grid.array[:, 0], c=next(color), **config_dict + ) except IndexError: pass - - def plot_grid_indexes_x1( - self, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - indexes: np.ndarray, - ): - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - if isinstance(indexes[0], int): - indexes = [indexes] - - for index_list in indexes: - grid_contour = Grid2DContour( - grid=grid[index_list, :], - pixel_scales=None, - shape_native=None, - ) - - grid_hull = grid_contour.hull - - if grid_hull is not None: - plt.plot( - grid_hull[:, 1], grid_hull[:, 0], color=next(color), **config_dict - ) - - def plot_grid_indexes_multi( - self, - grid: Union[np.ndarray, Grid2D, Grid2DIrregular], - indexes: np.ndarray, - geometry: Geometry2D, - ): - color = itertools.cycle(self.config_dict["c"]) - config_dict = self.config_dict - config_dict.pop("c") - - if isinstance(indexes[0], int): - indexes = [indexes] - - for index_list in indexes: - grid_in = grid[index_list, :] - - if isinstance(index_list[0], tuple): - grid_in = grid_in[0] - - grid_contour = Grid2DContour( - grid=grid_in, - pixel_scales=geometry.pixel_scales, - shape_native=geometry.shape_native, - ) - - color_plot = next(color) - - for contour in grid_contour.contour_list: - plt.plot(contour[:, 1], contour[:, 0], color=color_plot, **config_dict) diff --git a/autoarray/plot/wrap/two_d/grid_scatter.py b/autoarray/plot/wrap/two_d/grid_scatter.py index 8399ae591..e9b9879d0 100644 --- a/autoarray/plot/wrap/two_d/grid_scatter.py +++ b/autoarray/plot/wrap/two_d/grid_scatter.py @@ -1,7 +1,6 @@ import matplotlib.pyplot as plt import numpy as np import itertools -from scipy.spatial import ConvexHull from typing import List, Union @@ -80,7 +79,17 @@ def scatter_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular] try: for grid in grid_list: - plt.scatter(y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict) + try: + plt.scatter( + y=grid[:, 0], x=grid[:, 1], c=next(color), **config_dict + ) + except ValueError: + plt.scatter( + y=grid.array[:, 0], + x=grid.array[:, 1], + c=next(color), + **config_dict, + ) except IndexError: return None diff --git a/autoarray/plot/wrap/two_d/voronoi_drawer.py b/autoarray/plot/wrap/two_d/voronoi_drawer.py index aec6661cc..46a8a2a13 100644 --- a/autoarray/plot/wrap/two_d/voronoi_drawer.py +++ b/autoarray/plot/wrap/two_d/voronoi_drawer.py @@ -6,7 +6,7 @@ from autoarray.plot.wrap.base.units import Units from autoarray.inversion.pixelization.mappers.voronoi import MapperVoronoi -from autoarray.inversion.pixelization.mesh import mesh_util +from autoarray.inversion.pixelization.mesh import mesh_numba_util from autoarray.plot.wrap import base as wb @@ -59,7 +59,7 @@ def draw_voronoi_pixels( if ax is None: ax = plt.gca() - regions, vertices = mesh_util.voronoi_revised_from(voronoi=mapper.voronoi) + regions, vertices = mesh_numba_util.voronoi_revised_from(voronoi=mapper.voronoi) if pixel_values is not None: norm = cmap.norm_from(array=pixel_values, use_log10=use_log10) diff --git a/autoarray/preloads.py b/autoarray/preloads.py index 6808f0f6c..340d85bdd 100644 --- a/autoarray/preloads.py +++ b/autoarray/preloads.py @@ -1,15 +1,7 @@ import logging -import numpy as np -import os -from typing import List - -from autoconf import conf - -from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList -from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper -from autoarray import exc -from autoarray.inversion.inversion.imaging import inversion_imaging_util +import jax.numpy as jnp +import numpy as np logger = logging.getLogger(__name__) @@ -17,552 +9,65 @@ class Preloads: + def __init__( self, - w_tilde=None, - use_w_tilde=None, - image_plane_mesh_grid_pg_list=None, - relocated_grid=None, - mapper_list=None, - operated_mapping_matrix=None, - linear_func_operated_mapping_matrix_dict=None, - data_linear_func_matrix_dict=None, - mapper_operated_mapping_matrix_dict=None, - curvature_matrix=None, - data_vector_mapper=None, - curvature_matrix_mapper_diag=None, - regularization_matrix=None, - log_det_regularization_matrix_term=None, - traced_mesh_grids_list_of_planes=None, - image_plane_mesh_grid_list=None, + mapper_indices: np.ndarray = None, + source_pixel_zeroed_indices: np.ndarray = None, + linear_light_profile_blurred_mapping_matrix=None, ): - self.w_tilde = w_tilde - self.use_w_tilde = use_w_tilde - - self.image_plane_mesh_grid_pg_list = image_plane_mesh_grid_pg_list - self.relocated_grid = relocated_grid - self.mapper_list = mapper_list - self.operated_mapping_matrix = operated_mapping_matrix - self.linear_func_operated_mapping_matrix_dict = ( - linear_func_operated_mapping_matrix_dict - ) - self.data_linear_func_matrix_dict = data_linear_func_matrix_dict - self.mapper_operated_mapping_matrix_dict = mapper_operated_mapping_matrix_dict - self.curvature_matrix = curvature_matrix - self.data_vector_mapper = data_vector_mapper - self.curvature_matrix_mapper_diag = curvature_matrix_mapper_diag - self.regularization_matrix = regularization_matrix - self.log_det_regularization_matrix_term = log_det_regularization_matrix_term - - self.traced_mesh_grids_list_of_planes = traced_mesh_grids_list_of_planes - self.image_plane_mesh_grid_list = image_plane_mesh_grid_list - - @property - def check_threshold(self): - return conf.instance["general"]["test"]["preloads_check_threshold"] - - def set_w_tilde_imaging(self, fit_0, fit_1): - """ - The w-tilde linear algebra formalism speeds up inversions by computing beforehand quantities that enable - efficiently construction of the curvature matrix. These quantities can only be used if the noise-map is - fixed, therefore this function preloads these w-tilde quantities if the noise-map does not change. - - This function compares the noise map of two fit's corresponding to two model instances, and preloads wtilde - if the noise maps of both fits are the same. - - The preload is typically used through search chaining pipelines, as it is uncommon for the noise map to be - scaled during the model-fit (although it is common for a fixed but scaled noise map to be used). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.w_tilde = None - self.use_w_tilde = False - - if fit_0.inversion is None: - return - - if not fit_0.inversion.has(cls=AbstractMapper): - return - - if np.max(abs(fit_0.noise_map - fit_1.noise_map)) < 1e-8: - logger.info("PRELOADS - Computing W-Tilde... May take a moment.") - - from autoarray.dataset.imaging.w_tilde import WTildeImaging - - ( - preload, - indexes, - lengths, - ) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from( - noise_map_native=np.array(fit_0.noise_map.native), - kernel_native=np.array(fit_0.dataset.psf.native), - native_index_for_slim_index=np.array( - fit_0.dataset.mask.derive_indexes.native_for_slim - ), - ) - - self.w_tilde = WTildeImaging( - curvature_preload=preload, - indexes=indexes.astype("int"), - lengths=lengths.astype("int"), - noise_map_value=fit_0.noise_map[0], - ) - - self.use_w_tilde = True - - logger.info("PRELOADS - W-Tilde preloaded for this model-fit.") - - def set_relocated_grid(self, fit_0, fit_1): - """ - If the `MassProfile`'s in a model are fixed their traced grids (which may have had coordinates relocated at - the border) does not change during the model=fit and can therefore be preloaded. - - This function compares the relocated grids of the mappers of two fit corresponding to two model instances, and - preloads the grid if the grids of both fits are the same. - - The preload is typically used in adapt searches, where the mass model is fixed and the parameters are - varied. - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.relocated_grid = None - - if fit_0.inversion is None: - return - - if ( - fit_0.inversion.total(cls=AbstractMapper) > 1 - or fit_0.inversion.total(cls=AbstractMapper) == 0 - ): - return - - mapper_0 = fit_0.inversion.cls_list_from(cls=AbstractMapper)[0] - mapper_1 = fit_1.inversion.cls_list_from(cls=AbstractMapper)[0] - - if ( - mapper_0.source_plane_data_grid.shape[0] - == mapper_1.source_plane_data_grid.shape[0] - ): - if ( - np.max( - abs( - mapper_0.source_plane_data_grid - - mapper_1.source_plane_data_grid - ) - ) - < 1.0e-8 - ): - self.relocated_grid = mapper_0.source_plane_data_grid - - logger.info( - "PRELOADS - Relocated grid of pxielization preloaded for this model-fit." - ) - - def set_mapper_list(self, fit_0, fit_1): """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and the list of `Mapper`'s containing this information can - be preloaded. This includes preloading the `mapping_matrix`. + Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance + and compatibility with JAX. - This function compares the mapping matrix of two fit's corresponding to two model instances, and preloads the - list of mappers if the mapping matrix of both fits are the same. + Some arrays (e.g. `mapper_indices`) are required to be defined before sampling begins, because JAX demands + that input shapes remain static. These are used during each inversion to ensure consistent matrix shapes + for all likelihood evaluations. - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). + Other arrays (e.g. parts of the curvature matrix) are preloaded purely to improve performance. In cases where + the source model is fixed (e.g. when fitting only the lens light), sections of the curvature matrix do not + change and can be reused, avoiding redundant computation. Parameters ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. + mapper_indices + The integer indices of mapper pixels in the inversion. Used to extract reduced matrices (e.g. + `curvature_matrix_reduced`) that compute the pixelized inversion's log evidence term, where the indicies + are requirred to separate the rows and columns of matrices from linear light profiles. + source_pixel_zeroed_indices + Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to + outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping + separate the lens light from the pixelized source model. + linear_light_profile_blurred_mapping_matrix + The evaluated images of the linear light profiles that make up the blurred mapping matrix component of the + inversion, with the other component being the pixelization's pixels. These are fixed when the lens light + is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but + the intensity values will still be solved for during the inversion. """ - self.mapper_list = None + self.mapper_indices = None + self.source_pixel_zeroed_indices = None + self.source_pixel_zeroed_indices_to_keep = None + self.linear_light_profile_blurred_mapping_matrix = None - if fit_0.inversion is None: - return + if mapper_indices is not None: - if fit_0.inversion.total(cls=AbstractMapper) == 0: - return + self.mapper_indices = jnp.array(mapper_indices) - from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, - ) + if source_pixel_zeroed_indices is not None: - if isinstance(fit_0.inversion, InversionInterferometerMappingPyLops): - return + self.source_pixel_zeroed_indices = jnp.array(source_pixel_zeroed_indices) - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion + ids_zeros = jnp.array(source_pixel_zeroed_indices, dtype=int) - if inversion_0.mapping_matrix.shape[1] == inversion_1.mapping_matrix.shape[1]: - if np.allclose(inversion_0.mapping_matrix, inversion_1.mapping_matrix): - self.mapper_list = inversion_0.cls_list_from(cls=AbstractMapper) - - logger.info( - "PRELOADS - Mappers of planes preloaded for this model-fit." - ) - - def set_operated_mapping_matrix_with_preloads(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and matrices used to perform the linear algebra in an - inversion can be preloaded, which help efficiently construct the curvature matrix. - - This function compares the operated mapping matrix of two fit's corresponding to two model instances, and - preloads the mapper if the mapping matrix of both fits are the same. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.operated_mapping_matrix = None - - from autoarray.inversion.inversion.interferometer.lop import ( - InversionInterferometerMappingPyLops, - ) - - if isinstance(fit_0.inversion, InversionInterferometerMappingPyLops): - return - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - if ( - inversion_0.operated_mapping_matrix.shape[1] - == inversion_1.operated_mapping_matrix.shape[1] - ): - if ( - np.max( - abs( - inversion_0.operated_mapping_matrix - - inversion_1.operated_mapping_matrix - ) - ) - < 1e-8 - ): - self.operated_mapping_matrix = inversion_0.operated_mapping_matrix - - logger.info( - "PRELOADS - Inversion linear algebra quantities preloaded for this model-fit." - ) - - def set_linear_func_inversion_dicts(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and matrices used to perform the linear algebra in an - inversion can be preloaded, which help efficiently construct the curvature matrix. - - This function compares the operated mapping matrix of two fit's corresponding to two model instances, and - preloads the mapper if the mapping matrix of both fits are the same. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ + values_to_solve = jnp.ones(np.max(mapper_indices), dtype=bool) + values_to_solve = values_to_solve.at[ids_zeros].set(False) - from autoarray.inversion.pixelization.pixelization import Pixelization + # Get the indices where values_to_solve is True + self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0] - self.linear_func_operated_mapping_matrix_dict = None + if linear_light_profile_blurred_mapping_matrix is not None: - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - if not inversion_0.has(cls=AbstractMapper): - return - - if not inversion_0.has(cls=AbstractLinearObjFuncList): - return - - try: - inversion_0.linear_func_operated_mapping_matrix_dict - except NotImplementedError: - return - - if not hasattr(inversion_0, "linear_func_operated_mapping_matrix_dict"): - return - - should_preload = False - - for operated_mapping_matrix_0, operated_mapping_matrix_1 in zip( - inversion_0.linear_func_operated_mapping_matrix_dict.values(), - inversion_1.linear_func_operated_mapping_matrix_dict.values(), - ): - if ( - np.max(abs(operated_mapping_matrix_0 - operated_mapping_matrix_1)) - < 1e-8 - ): - should_preload = True - else: - should_preload = False - break - - if should_preload: - self.linear_func_operated_mapping_matrix_dict = ( - inversion_0.linear_func_operated_mapping_matrix_dict + self.linear_light_profile_blurred_mapping_matrix = jnp.array( + linear_light_profile_blurred_mapping_matrix ) - self.data_linear_func_matrix_dict = inversion_0.data_linear_func_matrix_dict - - logger.info( - "PRELOADS - Inversion linear light profile operated mapping matrix / data linear func matrix preloaded for this model-fit." - ) - - def set_curvature_matrix(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and therefore its associated curvature matrix is also - fixed, meaning the curvature matrix preloaded. - - If linear ``LightProfiles``'s are included, the regions of the curvature matrix associatd with these - objects vary but the diagonals of the mapper do not change. In this case, the `curvature_matrix_mapper_diag` - is preloaded. - - This function compares the curvature matrix of two fit's corresponding to two model instances, and preloads - this value if it is the same for both fits. - - The preload is typically used in **PyAutoGalaxy** inversions using a `Rectangular` pixelization. - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - - self.curvature_matrix = None - self.data_vector_mapper = None - self.curvature_matrix_mapper_diag = None - self.mapper_operated_mapping_matrix_dict = None - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - try: - inversion_0._curvature_matrix_mapper_diag - except NotImplementedError: - return - - if inversion_0.curvature_matrix.shape == inversion_1.curvature_matrix.shape: - if ( - np.max(abs(inversion_0.curvature_matrix - inversion_1.curvature_matrix)) - < 1e-8 - ): - self.curvature_matrix = inversion_0.curvature_matrix - - logger.info( - "PRELOADS - Inversion Curvature Matrix preloaded for this model-fit." - ) - - return - - if inversion_0._curvature_matrix_mapper_diag is not None: - if ( - np.max( - abs( - inversion_0._curvature_matrix_mapper_diag - - inversion_1._curvature_matrix_mapper_diag - ) - ) - < 1e-8 - ): - self.mapper_operated_mapping_matrix_dict = ( - inversion_0.mapper_operated_mapping_matrix_dict - ) - self.data_vector_mapper = inversion_0._data_vector_mapper - self.curvature_matrix_mapper_diag = ( - inversion_0._curvature_matrix_mapper_diag - ) - - logger.info( - "PRELOADS - Inversion Curvature Matrix Mapper Diag preloaded for this model-fit." - ) - - def set_regularization_matrix_and_term(self, fit_0, fit_1): - """ - If the `MassProfile`'s and `Mesh`'s in a model are fixed, the mapping of image-pixels to the - source-pixels does not change during the model-fit and therefore its associated regularization matrices are - also fixed, meaning the log determinant of the regularization matrix term of the Bayesian evidence can be - preloaded. - - This function compares the value of the log determinant of the regularization matrix of two fit's corresponding - to two model instances, and preloads this value if it is the same for both fits. - - The preload is typically used in searches where only light profiles vary (e.g. when only the lens's light is - being fitted for). - - Parameters - ---------- - fit_0 - The first fit corresponding to a model with a specific set of unit-values. - fit_1 - The second fit corresponding to a model with a different set of unit-values. - """ - self.regularization_matrix = None - self.log_det_regularization_matrix_term = None - - inversion_0 = fit_0.inversion - inversion_1 = fit_1.inversion - - if inversion_0 is None: - return - - if inversion_0.total(cls=AbstractMapper) == 0: - return - - if ( - abs( - inversion_0.log_det_regularization_matrix_term - - inversion_1.log_det_regularization_matrix_term - ) - < 1e-8 - ): - self.regularization_matrix = inversion_0.regularization_matrix - self.log_det_regularization_matrix_term = ( - inversion_0.log_det_regularization_matrix_term - ) - - logger.info( - "PRELOADS - Inversion Log Det Regularization Matrix Term preloaded for this model-fit." - ) - - def check_via_fit(self, fit): - import copy - - settings_inversion = copy.deepcopy(fit.settings_inversion) - - fit_with_preloads = fit.refit_with_new_preloads( - preloads=self, settings_inversion=settings_inversion - ) - - fit_without_preloads = fit.refit_with_new_preloads( - preloads=self.__class__(use_w_tilde=False), - settings_inversion=settings_inversion, - ) - - if os.environ.get("PYAUTOFIT_TEST_MODE") == "1": - return - - try: - if ( - abs( - fit_with_preloads.figure_of_merit - - fit_without_preloads.figure_of_merit - ) - > self.check_threshold - ): - raise exc.PreloadsException( - f""" - The log likelihood of fits using and not using preloads are not consistent by a value larger than - the preloads check threshold of {self.check_threshold}, indicating preloading has gone wrong. - - The likelihood values are: - - With Preloads: {fit_with_preloads.figure_of_merit} - Without Preloads: {fit_without_preloads.figure_of_merit} - - Double check that the model-fit is set up correctly and that the preloads are being used correctly. - - This exception can be turned off by setting the general.yaml -> test -> check_preloads to False - in the config files. However, care should be taken when doing this. - """ - ) - - except exc.InversionException: - data_vector_difference = np.max( - np.abs( - fit_with_preloads.inversion.data_vector - - fit_without_preloads.inversion.data_vector - ) - ) - - if data_vector_difference > 1.0e-4: - raise exc.PreloadsException( - f""" - The data vectors of fits using and not using preloads are not consistent, indicating - preloading has gone wrong. - - The maximum value a data vector absolute value difference is: {data_vector_difference} - """ - ) - - curvature_reg_matrix_difference = np.max( - np.abs( - fit_with_preloads.inversion.curvature_reg_matrix - - fit_without_preloads.inversion.curvature_reg_matrix - ) - ) - - if curvature_reg_matrix_difference > 1.0e-4: - raise exc.PreloadsException( - f""" - The curvature matrices of fits using and not using preloads are not consistent, indicating - preloading has gone wrong. - - The maximum value of a curvature matrix absolute value difference is: {curvature_reg_matrix_difference} - """ - ) - - @property - def info(self) -> List[str]: - """ - The information on what has or has not been preloaded, which is written to the file `preloads.summary`. - - Returns - ------- - A list of strings containing statements on what has or has not been preloaded. - """ - line = [f"W Tilde = {self.w_tilde is not None}\n"] - line += [f"Relocated Grid = {self.relocated_grid is not None}\n"] - line += [f"Mapper = {self.mapper_list is not None}\n"] - line += [ - f"Blurred Mapping Matrix = {self.operated_mapping_matrix is not None}\n" - ] - line += [ - f"Inversion Linear Func (Linear Light Profile) Dicts = {self.linear_func_operated_mapping_matrix_dict is not None}\n" - ] - line += [f"Curvature Matrix = {self.curvature_matrix is not None}\n"] - line += [ - f"Curvature Matrix Mapper Diag = {self.curvature_matrix_mapper_diag is not None}\n" - ] - line += [f"Regularization Matrix = {self.regularization_matrix is not None}\n"] - line += [ - f"Log Det Regularization Matrix Term = {self.log_det_regularization_matrix_term is not None}\n" - ] - - return line diff --git a/autoarray/structures/arrays/array_1d_util.py b/autoarray/structures/arrays/array_1d_util.py index 5e6565326..8ed9b83a6 100644 --- a/autoarray/structures/arrays/array_1d_util.py +++ b/autoarray/structures/arrays/array_1d_util.py @@ -1,11 +1,11 @@ from __future__ import annotations +import jax.numpy as jnp import numpy as np from typing import TYPE_CHECKING, List, Union if TYPE_CHECKING: from autoarray.mask.mask_1d import Mask1D -from autoarray import numba_util from autoarray.mask import mask_1d_util from autoarray.structures.arrays import array_2d_util @@ -38,26 +38,27 @@ def convert_array_1d( If True, the ndarray is stored in its native format [total_y_pixels, total_x_pixels]. This avoids mapping large data arrays to and from the slim / native formats, which can be a computational bottleneck. """ - array_1d = array_2d_util.convert_array(array=array_1d) + is_numpy = True if isinstance(array_1d, np.ndarray) else False + is_native = array_1d.shape[0] == mask_1d.shape_native[0] if is_native == store_native: - return array_1d + array_1d = array_1d elif not store_native: - return array_1d_slim_from( - array_1d_native=np.array(array_1d), - mask_1d=np.array(mask_1d), + array_1d = array_1d_slim_from( + array_1d_native=array_1d, + mask_1d=mask_1d, ) - - return array_1d_native_from( - array_1d_slim=array_1d, - mask_1d=np.array(mask_1d), - ) + else: + array_1d = array_1d_native_from( + array_1d_slim=array_1d, + mask_1d=mask_1d, + ) + return np.array(array_1d) if is_numpy else jnp.array(array_1d) -@numba_util.jit() def array_1d_slim_from( array_1d_native: np.ndarray, mask_1d: np.ndarray, @@ -105,19 +106,8 @@ def array_1d_slim_from( array_1d_slim = array_1d_slim_from(array_1d_native, array_2d=array_2d) """ - - total_pixels = mask_1d_util.total_pixels_1d_from( - mask_1d=mask_1d, - ) - - line_1d_slim = np.zeros(shape=total_pixels) - index = 0 - - for x in range(mask_1d.shape[0]): - if not mask_1d[x]: - line_1d_slim[index] = array_1d_native[x] - index += 1 - + unmasked_indices = ~mask_1d + line_1d_slim = array_1d_native[unmasked_indices] return line_1d_slim @@ -132,13 +122,12 @@ def array_1d_native_from( ).astype("int") return array_1d_via_indexes_1d_from( - array_1d_slim=np.array(array_1d_slim), + array_1d_slim=array_1d_slim, shape=shape, native_index_for_slim_index_1d=native_index_for_slim_index_1d, ) -@numba_util.jit() def array_1d_via_indexes_1d_from( array_1d_slim: np.ndarray, shape: int, @@ -177,11 +166,9 @@ def array_1d_via_indexes_1d_from( ndarray The native 1D array of values mapped from the slimmed array with dimensions (total_x_pixels). """ - array_1d_native = np.zeros(shape) - - for slim_index in range(len(native_index_for_slim_index_1d)): - array_1d_native[native_index_for_slim_index_1d[slim_index]] = array_1d_slim[ - slim_index - ] - - return array_1d_native + if isinstance(array_1d_slim, np.ndarray): + array_1d_native = np.zeros(shape) + array_1d_native[native_index_for_slim_index_1d] = array_1d_slim + return array_1d_native + array_1d_native = jnp.zeros(shape) + return array_1d_native.at[native_index_for_slim_index_1d].set(array_1d_slim) diff --git a/autoarray/structures/arrays/array_2d_util.py b/autoarray/structures/arrays/array_2d_util.py index d8a0f7e10..a1534e480 100644 --- a/autoarray/structures/arrays/array_2d_util.py +++ b/autoarray/structures/arrays/array_2d_util.py @@ -1,5 +1,4 @@ from __future__ import annotations -import jax import jax.numpy as jnp import numpy as np from typing import TYPE_CHECKING, List, Tuple, Union @@ -7,11 +6,9 @@ if TYPE_CHECKING: from autoarray.mask.mask_2d import Mask2D -from autoarray import numba_util from autoarray.mask import mask_2d_util from autoarray import exc -from functools import partial def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: @@ -23,12 +20,15 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray: array : list or ndarray The array which may be converted to an ndarray """ - if isinstance(array, np.ndarray) or isinstance(array, list): + + try: + array = array.array + except AttributeError: + pass + + if isinstance(array, list): array = np.asarray(array) - elif isinstance(array, jnp.ndarray): - array = jax.lax.cond( - type(array) is list, lambda _: jnp.asarray(array), lambda _: array, None - ) + return array @@ -118,27 +118,30 @@ def convert_array_2d( If True, the ndarray is stored in its native format [total_y_pixels, total_x_pixels]. This avoids mapping large data arrays to and from the slim / native formats, which can be a computational bottleneck. """ - array_2d = convert_array(array=array_2d).copy() + array_2d = convert_array(array=array_2d) + + is_numpy = True if isinstance(array_2d, np.ndarray) else False check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d) is_native = len(array_2d.shape) == 2 if is_native and not skip_mask: - array_2d *= np.invert(mask_2d) + array_2d *= ~mask_2d if is_native == store_native: - return array_2d + array_2d = array_2d elif not store_native: - return array_2d_slim_from( - array_2d_native=np.array(array_2d), - mask_2d=np.array(mask_2d), + array_2d = array_2d_slim_from( + array_2d_native=array_2d, + mask_2d=mask_2d, ) - array_2d = array_2d_native_from( - array_2d_slim=array_2d, - mask_2d=np.array(mask_2d), - ) - return array_2d + else: + array_2d = array_2d_native_from( + array_2d_slim=array_2d, + mask_2d=mask_2d, + ) + return np.array(array_2d) if is_numpy else jnp.array(array_2d) def convert_array_2d_to_slim(array_2d: np.ndarray, mask_2d: Mask2D) -> np.ndarray: @@ -209,7 +212,6 @@ def convert_array_2d_to_native(array_2d: np.ndarray, mask_2d: Mask2D) -> np.ndar ) -@numba_util.jit() def extracted_array_2d_from( array_2d: np.ndarray, y0: int, y1: int, x0: int, x1: int ) -> np.ndarray: @@ -222,8 +224,6 @@ def extracted_array_2d_from( In the example below, an array of size (5,5) is extracted using the coordinates y0=1, y1=4, x0=1, x1=4. This extracts an array of dimensions (3,3) and is equivalent to array_2d[1:4, 1:4]. - This function is necessary work with numba jit tags and is why a standard Numpy array extraction is not used. - Parameters ---------- array_2d @@ -247,29 +247,32 @@ def extracted_array_2d_from( array_2d = np.ones((5,5)) extracted_array = extract_array_2d(array_2d=array_2d, y0=1, y1=4, x0=1, x1=4) """ - new_shape = (y1 - y0, x1 - x0) + resized_array = np.zeros(new_shape, dtype=array_2d.dtype) + + # Compute valid slice ranges + y_start = max(y0, 0) + y_end = min(y1, array_2d.shape[0]) + x_start = max(x0, 0) + x_end = min(x1, array_2d.shape[1]) - resized_array = np.zeros(shape=new_shape) + # Target insertion indices + y_insert_start = y_start - y0 + y_insert_end = y_insert_start + (y_end - y_start) + x_insert_start = x_start - x0 + x_insert_end = x_insert_start + (x_end - x_start) - for y_resized, y in enumerate(range(y0, y1)): - for x_resized, x in enumerate(range(x0, x1)): - if ( - y >= 0 - and x >= 0 - and y <= array_2d.shape[0] - 1 - and x <= array_2d.shape[1] - 1 - ): - resized_array[y_resized, x_resized] = array_2d[y, x] + resized_array[y_insert_start:y_insert_end, x_insert_start:x_insert_end] = array_2d[ + y_start:y_end, x_start:x_end + ] return resized_array -@numba_util.jit() def resized_array_2d_from( array_2d: np.ndarray, resized_shape: Tuple[int, int], - origin: Tuple[int, int] = (-1, -1), + origin: Tuple[int, int] = None, pad_value: int = 0.0, ) -> np.ndarray: """ @@ -279,9 +282,6 @@ def resized_array_2d_from( calculated automatically. For example, a (5,5) array's central pixel is (2,2). For even dimensions the central pixel is assumed to be the lower indexed value, e.g. a (6,4) array's central pixel is calculated as (2,1). - The default origin is (-1, -1) because numba requires that the function input is the same type throughout the - function, thus a default 'None' value cannot be used. - Parameters ---------- array_2d @@ -305,106 +305,40 @@ def resized_array_2d_from( resize_array = resize_array_2d(array_2d=array_2d, new_shape=(2,2), origin=(2, 2)) """ - y_is_even = int(array_2d.shape[0]) % 2 == 0 - x_is_even = int(array_2d.shape[1]) % 2 == 0 - - if origin == (-1, -1): - if y_is_even: - y_centre = int(array_2d.shape[0] / 2) - elif not y_is_even: - y_centre = int(array_2d.shape[0] / 2) - - if x_is_even: - x_centre = int(array_2d.shape[1] / 2) - elif not x_is_even: - x_centre = int(array_2d.shape[1] / 2) - + if origin is None: + y_centre = array_2d.shape[0] // 2 + x_centre = array_2d.shape[1] // 2 origin = (y_centre, x_centre) - resized_array = np.zeros(shape=resized_shape) - - if y_is_even: - y_min = origin[0] - int(resized_shape[0] / 2) - y_max = origin[0] + int((resized_shape[0] / 2)) + 1 - elif not y_is_even: - y_min = origin[0] - int(resized_shape[0] / 2) - y_max = origin[0] + int((resized_shape[0] / 2)) + 1 - - if x_is_even: - x_min = origin[1] - int(resized_shape[1] / 2) - x_max = origin[1] + int((resized_shape[1] / 2)) + 1 - elif not x_is_even: - x_min = origin[1] - int(resized_shape[1] / 2) - x_max = origin[1] + int((resized_shape[1] / 2)) + 1 - - for y_resized, y in enumerate(range(y_min, y_max)): - for x_resized, x in enumerate(range(x_min, x_max)): - if y >= 0 and y < array_2d.shape[0] and x >= 0 and x < array_2d.shape[1]: - if ( - y_resized >= 0 - and y_resized < resized_shape[0] - and x_resized >= 0 - and x_resized < resized_shape[1] - ): - resized_array[y_resized, x_resized] = array_2d[y, x] - else: - if ( - y_resized >= 0 - and y_resized < resized_shape[0] - and x_resized >= 0 - and x_resized < resized_shape[1] - ): - resized_array[y_resized, x_resized] = pad_value + # Define window edges so that length == resized_shape dimension exactly + y_min = origin[0] - resized_shape[0] // 2 + y_max = y_min + resized_shape[0] - return resized_array + x_min = origin[1] - resized_shape[1] // 2 + x_max = x_min + resized_shape[1] + resized_array = np.full(resized_shape, pad_value, dtype=array_2d.dtype) -@numba_util.jit() -def replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d: np.ndarray, noise_map_2d: np.ndarray, target_signal_to_noise: float = 2.0 -) -> np.ndarray: - """ - If the values of a 2D image array are negative, this function replaces the corresponding 2D noise-map array - values to meet a specified target to noise value. + # Calculate source indices clipped to array bounds + src_y_start = max(y_min, 0) + src_y_end = min(y_max, array_2d.shape[0]) + src_x_start = max(x_min, 0) + src_x_end = min(x_max, array_2d.shape[1]) - This routine is necessary because of anomolous values in images which come from our HST ACS data_type-reduction - pipeline, where image-pixels with negative values (e.g. due to the background sky subtraction) have extremely - small noise values, which inflate their signal-to-noise values and chi-squared contributions in the modeling. + # Calculate destination indices corresponding to source indices + dst_y_start = max(0, -y_min) + dst_y_end = dst_y_start + (src_y_end - src_y_start) + dst_x_start = max(0, -x_min) + dst_x_end = dst_x_start + (src_x_end - src_x_start) - Parameters - ---------- - image_2d - The 2D image array used to locate the pixel indexes in the noise-map which are replaced. - noise_map_2d - The 2D noise-map array whose values are replaced. - target_signal_to_noise - The target signal-to-noise the noise-map valueus are changed to. - - Returns - ------- - ndarray - The 2D noise-map with values changed. - - Examples - -------- - image_2d = np.ones((5,5)) - image_2d[2,2] = -1.0 - noise_map_2d = np.ones((5,5)) + # Copy overlapping region from source to destination + resized_array[dst_y_start:dst_y_end, dst_x_start:dst_x_end] = array_2d[ + src_y_start:src_y_end, src_x_start:src_x_end + ] - noise_map_2d_replaced = replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=2.0): - """ - for y in range(image_2d.shape[0]): - for x in range(image_2d.shape[1]): - if image_2d[y, x] < 0.0: - absolute_signal_to_noise = np.abs(image_2d[y, x]) / noise_map_2d[y, x] - if absolute_signal_to_noise >= target_signal_to_noise: - noise_map_2d[y, x] = np.abs(image_2d[y, x]) / target_signal_to_noise - - return noise_map_2d + return resized_array -@numba_util.jit() def index_2d_for_index_slim_from(indexes_slim: np.ndarray, shape_native) -> np.ndarray: """ For pixels on a native 2D array of shape (total_y_pixels, total_x_pixels), this array maps the slimmed 1D pixel @@ -437,16 +371,18 @@ def index_2d_for_index_slim_from(indexes_slim: np.ndarray, shape_native) -> np.n indexes_slim = np.array([0, 1, 2, 5]) indexes_2d = index_2d_for_index_slim_from(indexes_slim=indexes_slim, shape=(3,3)) """ - index_2d_for_index_slim = np.zeros((indexes_slim.shape[0], 2)) + # Calculate row indices by integer division by number of columns + rows = indexes_slim // shape_native[1] - for i, index_slim in enumerate(indexes_slim): - index_2d_for_index_slim[i, 0] = int(index_slim / shape_native[1]) - index_2d_for_index_slim[i, 1] = int(index_slim % shape_native[1]) + # Calculate column indices by modulo number of columns + cols = indexes_slim % shape_native[1] + + # Stack rows and cols horizontally into shape (N, 2) + index_2d_for_index_slim = np.vstack((rows, cols)).T return index_2d_for_index_slim -@numba_util.jit() def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.ndarray: """ For pixels on a native 2D array of shape (total_y_pixels, total_x_pixels), this array maps the 2D pixel indexes to @@ -479,12 +415,10 @@ def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.nda indexes_2d = np.array([[0,0], [1,0], [2,0], [2,2]]) indexes_flat = index_flat_for_index_2d_from(indexes_2d=indexes_2d, shape=(3,3)) """ - index_slim_for_index_native_2d = np.zeros(indexes_2d.shape[0]) - - for i in range(indexes_2d.shape[0]): - index_slim_for_index_native_2d[i] = int( - (indexes_2d[i, 0]) * shape_native[1] + indexes_2d[i, 1] - ) + # Calculate 1D indexes as row_index * number_of_columns + col_index + index_slim_for_index_native_2d = ( + indexes_2d[:, 0] * shape_native[1] + indexes_2d[:, 1] + ) return index_slim_for_index_native_2d @@ -587,7 +521,6 @@ def array_2d_native_from( ) -@partial(jax.jit, static_argnums=(1,)) def array_2d_via_indexes_from( array_2d_slim: np.ndarray, shape: Tuple[int, int], @@ -620,69 +553,10 @@ def array_2d_via_indexes_from( ndarray The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values). """ + if isinstance(array_2d_slim, np.ndarray): + array = np.zeros(shape) + array[tuple(native_index_for_slim_index_2d.T)] = array_2d_slim + return array return ( jnp.zeros(shape).at[tuple(native_index_for_slim_index_2d.T)].set(array_2d_slim) ) - - -@numba_util.jit() -def array_2d_slim_complex_from( - array_2d_native: np.ndarray, - mask: np.ndarray, -) -> np.ndarray: - """ - For a 2D array and mask, map the values of all unmasked pixels to a 1D array. - - The pixel coordinate origin is at the top left corner of the 2D array and goes right-wards and downwards. - - For example, for an array of shape (3,3) and where all pixels are unmasked: - - - pixel [0,0] of the 2D array will correspond to index 0 of the 1D array. - - pixel [0,1] of the 2D array will correspond to index 1 of the 1D array. - - pixel [1,0] of the 2D array will correspond to index 3 of the 1D array. - - pixel [2,0] of the 2D array will correspond to index 6 of the 1D array. - - Parameters - ---------- - array_2d_native - A 2D array of values on the dimensions of the grid. - mask - A 2D array of bools, where `False` values mean unmasked and are included in the mapping. - array_2d - The 2D array of values which are mapped to a 1D array. - - Returns - ------- - ndarray - A 1D array of values mapped from the 2D array with dimensions (total_unmasked_pixels). - """ - - total_pixels = np.sum(~mask) - - array_1d = 0 + 0j * np.zeros(shape=total_pixels) - index = 0 - - for y in range(mask.shape[0]): - for x in range(mask.shape[1]): - if not mask[y, x]: - array_1d[index] = array_2d_native[y, x] - index += 1 - - return array_1d - - -@numba_util.jit() -def array_2d_native_complex_via_indexes_from( - array_2d_slim: np.ndarray, - shape_native: Tuple[int, int], - native_index_for_slim_index_2d: np.ndarray, -) -> np.ndarray: - array_2d = 0 + 0j * np.zeros(shape_native) - - for slim_index in range(len(native_index_for_slim_index_2d)): - array_2d[ - native_index_for_slim_index_2d[slim_index, 0], - native_index_for_slim_index_2d[slim_index, 1], - ] = array_2d_slim[slim_index] - - return array_2d diff --git a/autoarray/structures/arrays/irregular.py b/autoarray/structures/arrays/irregular.py index 667ddfd43..f4b54ee11 100644 --- a/autoarray/structures/arrays/irregular.py +++ b/autoarray/structures/arrays/irregular.py @@ -38,12 +38,6 @@ def __init__(self, values: Union[List, np.ndarray]): A collection of values. """ - # if len(values) == 0: - # return [] - - # if isinstance(values, ArrayIrregular): - # return values - if type(values) is list: values = np.asarray(values) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index bb62656a9..b2992278a 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -1,8 +1,7 @@ -from astropy import units import jax import jax.numpy as jnp +from jax import lax import numpy as np -import scipy.signal from pathlib import Path from typing import List, Tuple, Union @@ -16,7 +15,6 @@ from autoarray import exc from autoarray import type as ty from autoarray.structures.arrays import array_2d_util -from autoarray.mask.mask_2d import mask_2d_util class Kernel2D(AbstractArray2D): @@ -27,8 +25,10 @@ def __init__( header=None, normalize: bool = False, store_native: bool = False, + image_mask=None, + blurring_mask=None, *args, - **kwargs + **kwargs, ): """ An array of values, which are paired to a uniform 2D mask of pixels. Each entry @@ -57,6 +57,27 @@ def __init__( if normalize: self._array = np.divide(self._array, np.sum(self._array)) + self.stored_native = self.native + + self.slim_to_native_tuple = None + + if image_mask is not None: + + slim_to_native = image_mask.derive_indexes.native_for_slim.astype("int32") + self.slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1]) + + self.slim_to_native_blurring_tuple = None + + if blurring_mask is not None: + + slim_to_native_blurring = ( + blurring_mask.derive_indexes.native_for_slim.astype("int32") + ) + self.slim_to_native_blurring_tuple = ( + slim_to_native_blurring[:, 0], + slim_to_native_blurring[:, 1], + ) + @classmethod def no_mask( cls, @@ -65,6 +86,8 @@ def no_mask( shape_native: Tuple[int, int] = None, origin: Tuple[float, float] = (0.0, 0.0), normalize: bool = False, + image_mask=None, + blurring_mask=None, ): """ Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically @@ -92,7 +115,13 @@ def no_mask( pixel_scales=pixel_scales, origin=origin, ) - return Kernel2D(values=values, mask=values.mask, normalize=normalize) + return Kernel2D( + values=values, + mask=values.mask, + normalize=normalize, + image_mask=image_mask, + blurring_mask=blurring_mask, + ) @classmethod def full( @@ -255,7 +284,7 @@ def from_gaussian( """ grid = Grid2D.uniform(shape_native=shape_native, pixel_scales=pixel_scales) - grid_shifted = np.subtract(grid, centre) + grid_shifted = np.subtract(grid.array, centre) grid_radius = np.sqrt(np.sum(grid_shifted**2.0, 1)) theta_coordinate_to_profile = np.arctan2( grid_shifted[:, 0], grid_shifted[:, 1] @@ -297,6 +326,9 @@ def from_as_gaussian_via_alma_fits_header_parameters( centre: Tuple[float, float] = (0.0, 0.0), normalize: bool = False, ) -> "Kernel2D": + + from astropy import units + x_stddev = ( x_stddev * (units.deg).to(units.arcsec) / (2.0 * np.sqrt(2.0 * np.log(2.0))) ) @@ -385,7 +417,7 @@ def rescaled_with_odd_dimensions_from( try: kernel_rescaled = rescale( - np.array(self.native._array), + self.native.array, rescale_factor, anti_aliasing=False, mode="constant", @@ -393,7 +425,7 @@ def rescaled_with_odd_dimensions_from( ) except TypeError: kernel_rescaled = rescale( - np.array(self.native._array), + self.native.array, rescale_factor, anti_aliasing=False, mode="constant", @@ -469,23 +501,59 @@ def convolved_array_from(self, array: Array2D) -> Array2D: ------ KernelException if either Kernel2D psf dimension is odd """ + import scipy.signal + if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: raise exc.KernelException("Kernel2D Kernel2D must be odd") array_2d = array.native convolved_array_2d = scipy.signal.convolve2d( - array_2d._array, np.array(self.native._array), mode="same" + array_2d.array, self.native.array, mode="same" ) convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=np.array(array_2d.mask), + mask_2d=array_2d.mask, array_2d_native=convolved_array_2d, ) return Array2D(values=convolved_array_1d, mask=array_2d.mask) - def convolve_image(self, image, blurring_image, jax_method="fft"): + def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D: + """ + Convolve an array with this Kernel2D + + Parameters + ---------- + image + An array representing the image the Kernel2D is convolved with. + + Returns + ------- + convolved_image + An array representing the image after convolution. + + Raises + ------ + KernelException if either Kernel2D psf dimension is odd + """ + import scipy.signal + + if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") + + convolved_array_2d = scipy.signal.convolve2d( + array.array, self.native.array, mode="same" + ) + + convolved_array_1d = array_2d_util.array_2d_slim_from( + mask_2d=mask, + array_2d_native=convolved_array_2d, + ) + + return Array2D(values=convolved_array_1d, mask=mask) + + def convolve_image(self, image, blurring_image, jax_method="direct"): """ For a given 1D array and blurring array, convolve the two using this psf. @@ -502,33 +570,93 @@ def convolve_image(self, image, blurring_image, jax_method="fft"): kernels that are more than about 5x5. Default is `fft`. """ - slim_to_native = jnp.nonzero( - jnp.logical_not(image.mask.array), size=image.shape[0] - ) - slim_to_native_blurring = jnp.nonzero( - jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] - ) + slim_to_native_tuple = self.slim_to_native_tuple + slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple + + if slim_to_native_tuple is None: + + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(image.mask.array), size=image.shape[0] + ) + + if slim_to_native_blurring_tuple is None: - expanded_array_native = jnp.zeros(image.mask.shape) + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] + ) - expanded_array_native = expanded_array_native.at[slim_to_native].set( - image.array + # make sure dtype matches what you want + expanded_array_native = jnp.zeros( + image.mask.shape, dtype=jnp.asarray(image.array).dtype ) - expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set( - blurring_image.array + + # set using a tuple of index arrays + expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + jnp.asarray(image.array) ) + expanded_array_native = expanded_array_native.at[ + slim_to_native_blurring_tuple + ].set(jnp.asarray(blurring_image.array)) - kernel = np.array(self.native.array) + kernel = self.stored_native.array convolve_native = jax.scipy.signal.convolve( expanded_array_native, kernel, mode="same", method=jax_method ) - convolved_array_1d = convolve_native[slim_to_native] + convolved_array_1d = convolve_native[slim_to_native_tuple] return Array2D(values=convolved_array_1d, mask=image.mask) - def convolve_image_no_blurring(self, image, mask, jax_method="fft"): + def convolve_image_no_blurring(self, image, mask, jax_method="direct"): + """ + For a given 1D array and blurring array, convolve the two using this psf. + + Parameters + ---------- + image + 1D array of the values which are to be blurred with the psf's PSF. + blurring_image + 1D array of the blurring values which blur into the array after PSF convolution. + jax_method + If JAX is enabled this keyword will indicate what method is used for the PSF + convolution. Can be either `direct` to calculate it in real space or `fft` + to calculated it via a fast Fourier transform. `fft` is typically faster for + kernels that are more than about 5x5. Default is `fft`. + """ + + slim_to_native_tuple = self.slim_to_native_tuple + + if slim_to_native_tuple is None: + + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=image.shape[0] + ) + + # make sure dtype matches what you want + expanded_array_native = jnp.zeros(mask.shape) + + # set using a tuple of index arrays + if isinstance(image, np.ndarray) or isinstance(image, jnp.ndarray): + expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + image + ) + else: + expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + jnp.asarray(image.array) + ) + + kernel = self.stored_native.array + + convolve_native = jax.scipy.signal.convolve( + expanded_array_native, kernel, mode="same", method=jax_method + ) + + convolved_array_1d = convolve_native[slim_to_native_tuple] + + return Array2D(values=convolved_array_1d, mask=mask) + + def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct"): """ For a given 1D array and blurring array, convolve the two using this psf. @@ -545,23 +673,33 @@ def convolve_image_no_blurring(self, image, mask, jax_method="fft"): kernels that are more than about 5x5. Default is `fft`. """ - slim_to_native = jnp.nonzero(jnp.logical_not(mask.array), size=image.shape[0]) + slim_to_native_tuple = self.slim_to_native_tuple + if slim_to_native_tuple is None: + + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=image.shape[0] + ) + + # make sure dtype matches what you want expanded_array_native = jnp.zeros(mask.shape) - expanded_array_native = expanded_array_native.at[slim_to_native].set(image) + # set using a tuple of index arrays + expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + image + ) - kernel = np.array(self.native.array) + kernel = self.stored_native.array convolve_native = jax.scipy.signal.convolve( expanded_array_native, kernel, mode="same", method=jax_method ) - convolved_array_1d = convolve_native[slim_to_native] + convolved_array_1d = convolve_native[slim_to_native_tuple] return Array2D(values=convolved_array_1d, mask=mask) - def convolve_mapping_matrix(self, mapping_matrix, mask): + def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"): """For a given 1D array and blurring array, convolve the two using this psf. Parameters @@ -569,6 +707,6 @@ def convolve_mapping_matrix(self, mapping_matrix, mask): image 1D array of the values which are to be blurred with the psf's PSF. """ - return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))( - mapping_matrix, mask - ).T + return jax.vmap( + self.convolve_image_no_blurring_for_mapping, in_axes=(1, None, None) + )(mapping_matrix, mask, jax_method).T diff --git a/autoarray/structures/arrays/rgb.py b/autoarray/structures/arrays/rgb.py new file mode 100644 index 000000000..8e2171f23 --- /dev/null +++ b/autoarray/structures/arrays/rgb.py @@ -0,0 +1,45 @@ +from autoarray.abstract_ndarray import AbstractNDArray +from autoarray.structures.arrays.uniform_2d import Array2D + + +class Array2DRGB(Array2D): + + def __init__(self, values, mask): + """ + A container for RGB images which have a final dimension of 3, which allows them to be visualized using + the same functionality as `Array2D` objects. + + By passing an RGB image to this class, the following visualization functionality is used when the RGB + image is used in `Plotter` objects: + + - The RGB image is plotted using the `imshow` function of Matplotlib. + - Functionality which sets the scale of the axis, zooms the image, and sets the axis limits is used. + - The colorbar is set to the RGB image, which is a 3D array with a final dimension of 3. + - The formatting of the image is identical to that of `Array2D` objects, which means the image is plotted + with the same aspect ratio as the original image making for easy subplot formatting. + + This class always assumes the array is in its `native` representation, but with a final dimension of 3. + + Parameters + ---------- + values + The values of the RGB image, which is a 3D array with a final dimension of 3. + mask + The 2D mask associated with the array, defining the pixels each array value in its ``slim`` representation + is paired with. + """ + + array = values + + while isinstance(array, AbstractNDArray): + array = array.array + + self._array = array + self.mask = mask + + @property + def native(self) -> "Array2D": + """ + Returns the RGB ndarray of shape [total_y_pixels, total_x_pixels, 3] in its `native` representation. + """ + return self diff --git a/autoarray/structures/arrays/uniform_1d.py b/autoarray/structures/arrays/uniform_1d.py index 4dbf1692f..d708a68ad 100644 --- a/autoarray/structures/arrays/uniform_1d.py +++ b/autoarray/structures/arrays/uniform_1d.py @@ -24,6 +24,7 @@ def __init__( header: Optional[Header] = None, store_native: bool = False, ): + values = array_1d_util.convert_array_1d( array_1d=values, mask_1d=mask, @@ -87,7 +88,7 @@ def no_mask( origin=origin, ) - return Array1D(values=values, mask=mask, header=header) + return Array1D(values=np.array(values), mask=mask, header=header) @classmethod def full( diff --git a/autoarray/structures/arrays/uniform_2d.py b/autoarray/structures/arrays/uniform_2d.py index 11c478ad5..7c955e2b1 100644 --- a/autoarray/structures/arrays/uniform_2d.py +++ b/autoarray/structures/arrays/uniform_2d.py @@ -7,6 +7,7 @@ from autoconf.fitsable import ndarray_via_fits_from, header_obj_from from autoarray.mask.mask_2d import Mask2D +from autoarray.mask.derive.zoom_2d import Zoom2D from autoarray.structures.abstract_structure import Structure from autoarray.structures.header import Header from autoarray.structures.arrays.uniform_1d import Array1D @@ -232,11 +233,6 @@ def __init__( if conf.instance["general"]["structures"]["native_binned_only"]: store_native = True - try: - values = values._array - except AttributeError: - values = values - values = array_2d_util.convert_array_2d( array_2d=values, mask_2d=mask, @@ -292,6 +288,31 @@ def native(self) -> "Array2D": values=self, mask=self.mask, header=self.header, store_native=True ) + @property + def native_for_fits(self) -> "Array2D": + """ + Return a `Array2D` for output to a .fits file, where the data is stored in its `native` representation, + which is an ``ndarray`` of shape [total_y_pixels, total_x_pixels]. + + Depending on configuration files, this array could be zoomed in on such that only the unmasked region + of the image is included in the .fits file, to save hard-disk space. Alternatively, the original `shape_native` + of the data can be retained. + + If it is already stored in its `native` representation it is return as it is. If not, it is mapped from + `slim` to `native` and returned as a new `Array2D`. + """ + if conf.instance["visualize"]["plots"]["fits_are_zoomed"]: + + zoom = Zoom2D(mask=self.mask) + + buffer = 0 if self.mask.is_all_false else 1 + + return zoom.array_2d_from(array=self, buffer=buffer) + + return Array2D( + values=self, mask=self.mask, header=self.header, store_native=True + ) + @property def native_skip_mask(self) -> "Array2D": """ @@ -322,7 +343,7 @@ def in_counts_per_second(self) -> "Array2D": @property def original_orientation(self) -> Union[np.ndarray, "Array2D"]: return layout_util.rotate_array_via_roe_corner_from( - array=np.array(self), roe_corner=self.header.original_roe_corner + array=self, roe_corner=self.header.original_roe_corner ) @property @@ -451,68 +472,6 @@ def brightest_sub_pixel_coordinate_in_region_from( pixel_coordinates_2d=(subpixel_y, subpixel_x) ) - def zoomed_around_mask(self, buffer: int = 1) -> "Array2D": - """ - Extract the 2D region of an array corresponding to the rectangle encompassing all unmasked values. - - This is used to extract and visualize only the region of an image that is used in an analysis. - - Parameters - ---------- - buffer - The number pixels around the extracted array used as a buffer. - """ - - extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native._array), - y0=self.mask.zoom_region[0] - buffer, - y1=self.mask.zoom_region[1] + buffer, - x0=self.mask.zoom_region[2] - buffer, - x1=self.mask.zoom_region[3] + buffer, - ) - - mask = Mask2D.all_false( - shape_native=extracted_array_2d.shape, - pixel_scales=self.pixel_scales, - origin=self.mask.mask_centre, - ) - - array = array_2d_util.convert_array_2d( - array_2d=extracted_array_2d, mask_2d=mask - ) - - return Array2D(values=array, mask=mask, header=self.header) - - def extent_of_zoomed_array(self, buffer: int = 1) -> np.ndarray: - """ - For an extracted zoomed array computed from the method *zoomed_around_mask* compute its extent in scaled - coordinates. - - The extent of the grid in scaled units returned as an ``ndarray`` of the form [x_min, x_max, y_min, y_max]. - - This is used visualize zoomed and extracted arrays via the imshow() method. - - Parameters - ---------- - buffer - The number pixels around the extracted array used as a buffer. - """ - extracted_array_2d = array_2d_util.extracted_array_2d_from( - array_2d=np.array(self.native._array), - y0=self.mask.zoom_region[0] - buffer, - y1=self.mask.zoom_region[1] + buffer, - x0=self.mask.zoom_region[2] - buffer, - x1=self.mask.zoom_region[3] + buffer, - ) - - mask = Mask2D.all_false( - shape_native=extracted_array_2d.shape, - pixel_scales=self.pixel_scales, - origin=self.mask.mask_centre, - ) - - return mask.geometry.extent - def resized_from( self, new_shape: Tuple[int, int], mask_pad_value: int = 0.0 ) -> "Array2D": @@ -532,7 +491,7 @@ def resized_from( """ resized_array_2d = array_2d_util.resized_array_2d_from( - array_2d=np.array(self.native._array), resized_shape=new_shape + array_2d=self.native.array, resized_shape=new_shape ) resized_mask = self.mask.resized_from( @@ -592,14 +551,14 @@ def trimmed_after_convolution_from( psf_cut_x = int(np.ceil(kernel_shape[1] / 2)) - 1 array_y = int(self.mask.shape[0]) array_x = int(self.mask.shape[1]) - trimmed_array_2d = self.native[ + trimmed_array_2d = self.native.array[ psf_cut_y : array_y - psf_cut_y, psf_cut_x : array_x - psf_cut_x ] resized_mask = self.mask.resized_from(new_shape=trimmed_array_2d.shape) array = array_2d_util.convert_array_2d( - array_2d=trimmed_array_2d._array, mask_2d=resized_mask + array_2d=trimmed_array_2d, mask_2d=resized_mask ) return Array2D( @@ -962,7 +921,7 @@ def from_yx_and_values( ) grid_pixels = geometry_util.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(grid.slim), + grid_scaled_2d_slim=grid.slim, shape_native=shape_native, pixel_scales=pixel_scales, ) diff --git a/autoarray/structures/grids/grid_1d_util.py b/autoarray/structures/grids/grid_1d_util.py index aa0592c47..82aa4514e 100644 --- a/autoarray/structures/grids/grid_1d_util.py +++ b/autoarray/structures/grids/grid_1d_util.py @@ -1,15 +1,14 @@ from __future__ import annotations import numpy as np +import jax.numpy as jnp from typing import TYPE_CHECKING, List, Union, Tuple if TYPE_CHECKING: from autoarray.mask.mask_1d import Mask1D from autoarray.structures.arrays import array_1d_util -from autoarray import numba_util from autoarray.geometry import geometry_util from autoarray.structures.grids import grid_2d_util -from autoarray.mask import mask_1d_util from autoarray import type as ty @@ -41,19 +40,23 @@ def convert_grid_1d( grid_1d = grid_2d_util.convert_grid(grid=grid_1d) + is_numpy = True if isinstance(grid_1d, np.ndarray) else False + is_native = grid_1d.shape[0] == mask_1d.shape_native[0] if is_native == store_native: - return grid_1d + grid_1d = grid_1d elif not store_native: - return grid_1d_slim_from( + grid_1d = grid_1d_slim_from( grid_1d_native=grid_1d, - mask_1d=np.array(mask_1d), + mask_1d=mask_1d, ) - return grid_1d_native_from( - grid_1d_slim=grid_1d, - mask_1d=np.array(mask_1d), - ) + else: + grid_1d = grid_1d_native_from( + grid_1d_slim=grid_1d, + mask_1d=mask_1d, + ) + return np.array(grid_1d) if is_numpy else jnp.array(grid_1d) def grid_1d_slim_via_shape_slim_from( @@ -95,7 +98,6 @@ def grid_1d_slim_via_shape_slim_from( ) -@numba_util.jit() def grid_1d_slim_via_mask_from( mask_1d: np.ndarray, pixel_scales: ty.PixelScales, @@ -131,23 +133,13 @@ def grid_1d_slim_via_mask_from( mask = np.array([True, False, True, False, False, False]) grid_slim = grid_1d_via_mask_from(mask_1d=mask_1d, pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) """ - - total_pixels = mask_1d_util.total_pixels_1d_from(mask_1d) - - grid_1d = np.zeros(shape=(total_pixels,)) - centres_scaled = geometry_util.central_scaled_coordinate_1d_from( shape_slim=mask_1d.shape, pixel_scales=pixel_scales, origin=origin ) - - index = 0 - - for x in range(mask_1d.shape[0]): - if not mask_1d[x]: - grid_1d[index] = (x - centres_scaled[0]) * pixel_scales[0] - index += 1 - - return grid_1d + indices = jnp.arange(mask_1d.shape[0]) + unmasked = jnp.logical_not(mask_1d) + coords = (indices - centres_scaled[0]) * pixel_scales[0] + return coords[unmasked] def grid_1d_slim_from( @@ -179,7 +171,7 @@ def grid_1d_slim_from( """ return array_1d_util.array_1d_slim_from( - array_1d_native=np.array(grid_1d_native), + array_1d_native=grid_1d_native, mask_1d=mask_1d, ) diff --git a/autoarray/structures/grids/grid_2d_util.py b/autoarray/structures/grids/grid_2d_util.py index 6c72c00ef..db8f92fe0 100644 --- a/autoarray/structures/grids/grid_2d_util.py +++ b/autoarray/structures/grids/grid_2d_util.py @@ -1,7 +1,6 @@ from __future__ import annotations import numpy as np import jax.numpy as jnp -import jax from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -11,10 +10,22 @@ from autoarray import exc from autoarray.structures.arrays import array_2d_util from autoarray.geometry import geometry_util -from autoarray import numba_util from autoarray import type as ty +def convert_grid(grid: Union[np.ndarray, List]) -> np.ndarray: + + try: + grid = grid.array + except AttributeError: + pass + + if isinstance(grid, list): + grid = np.asarray(grid) + + return grid + + def check_grid_slim(grid, shape_native): if shape_native is None: raise exc.GridException( @@ -36,13 +47,6 @@ def check_grid_slim(grid, shape_native): ) -def convert_grid(grid: Union[np.ndarray, List]) -> np.ndarray: - if type(grid) is list: - grid = np.asarray(grid) - - return grid - - def check_grid_2d(grid_2d: np.ndarray): if grid_2d.shape[-1] != 2: raise exc.GridException( @@ -108,25 +112,33 @@ def convert_grid_2d( grid_2d = convert_grid(grid=grid_2d) + is_numpy = True if isinstance(grid_2d, np.ndarray) else False + check_grid_2d_and_mask_2d(grid_2d=grid_2d, mask_2d=mask_2d) is_native = len(grid_2d.shape) == 3 if is_native: - grid_2d[:, :, 0] *= np.invert(mask_2d) - grid_2d[:, :, 1] *= np.invert(mask_2d) + if not is_numpy: + grid_2d = grid_2d.at[:, :, 0].multiply(~mask_2d) + grid_2d = grid_2d.at[:, :, 1].multiply(~mask_2d) + else: + grid_2d[:, :, 0] *= ~mask_2d + grid_2d[:, :, 1] *= ~mask_2d if is_native == store_native: - return grid_2d + grid_2d = grid_2d elif not store_native: - return grid_2d_slim_from( - grid_2d_native=np.array(grid_2d), - mask=np.array(mask_2d), + grid_2d = grid_2d_slim_from( + grid_2d_native=grid_2d, + mask=mask_2d, ) - return grid_2d_native_from( - grid_2d_slim=np.array(grid_2d), - mask_2d=np.array(mask_2d), - ) + else: + grid_2d = grid_2d_native_from( + grid_2d_slim=grid_2d, + mask_2d=mask_2d, + ) + return np.array(grid_2d) if is_numpy else jnp.array(grid_2d) def convert_grid_2d_to_slim( @@ -241,14 +253,28 @@ def grid_2d_slim_via_mask_from( shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin ) - centres_scaled = jnp.array(centres_scaled) - pixel_scales = jnp.array(pixel_scales) - sign = jnp.array([-1.0, 1.0]) - return ( - (jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled) - * sign - * pixel_scales - ) + # JAX branch + if isinstance(mask_2d, jnp.ndarray): + centres_scaled = jnp.asarray(centres_scaled) + pixel_scales = jnp.asarray(pixel_scales) + sign = jnp.array([-1.0, 1.0]) + + # use jnp.where instead of jnp.nonzero + rows, cols = jnp.where(~mask_2d.astype(bool)) + indices = jnp.stack([rows, cols], axis=1) # shape (N_unmasked, 2) + + # (indices - centre) -> pixel offsets; apply sign and scale to get physical coords + return (indices - centres_scaled) * sign * pixel_scales + + # NumPy branch (kept consistent) + centres_scaled = np.asarray(centres_scaled) + pixel_scales = np.asarray(pixel_scales) + sign = np.array([-1.0, 1.0]) + + rows, cols = np.where(~mask_2d.astype(bool)) + indices = np.stack([rows, cols], axis=1) + + return (indices - centres_scaled) * sign * pixel_scales def grid_2d_via_mask_from( @@ -383,7 +409,6 @@ def grid_2d_via_shape_native_from( ) -@numba_util.jit() def _radial_projected_shape_slim_from( extent: np.ndarray, centre: Tuple[float, float], @@ -458,7 +483,6 @@ def _radial_projected_shape_slim_from( return int((scaled_distance / pixel_scale)) + 1 -@numba_util.jit() def grid_scaled_2d_slim_radial_projected_from( extent: np.ndarray, centre: Tuple[float, float], @@ -510,7 +534,7 @@ def grid_scaled_2d_slim_radial_projected_from( The (y,x) scaled units to pixel units conversion factor of the 2D mask array. shape_slim Manually choose the shape of the 1D projected grid that is returned. If 0, the border based on the 2D grid is - used (due to numba None cannot be used as a default value). + used. Returns ------- @@ -549,97 +573,13 @@ def grid_scaled_2d_slim_radial_projected_from( radii = centre[1] - for slim_index in range(shape_slim): - grid_scaled_2d_slim_radii[slim_index, 1] = radii - radii += pixel_scale - - return grid_scaled_2d_slim_radii - - -@numba_util.jit() -def relocated_grid_via_jit_from(grid, border_grid): - """ - Relocate the coordinates of a grid to its border if they are outside the border, where the border is - defined as all pixels at the edge of the grid's mask (see *mask._border_1d_indexes*). - - This is performed as follows: - - 1: Use the mean value of the grid's y and x coordinates to determine the origin of the grid. - 2: Compute the radial distance of every grid coordinate from the origin. - 3: For every coordinate, find its nearest pixel in the border. - 4: Determine if it is outside the border, by comparing its radial distance from the origin to its paired - border pixel's radial distance. - 5: If its radial distance is larger, use the ratio of radial distances to move the coordinate to the - border (if its inside the border, do nothing). - - The method can be used on uniform or irregular grids, however for irregular grids the border of the - 'image-plane' mask is used to define border pixels. - - Parameters - ---------- - grid - The grid (uniform or irregular) whose pixels are to be relocated to the border edge if outside it. - border_grid : Grid2D - The grid of border (y,x) coordinates. - """ + # Create an array of radii values spaced by pixel_scale + radii_array = radii + pixel_scale * np.arange(shape_slim) - grid_relocated = np.zeros(grid.shape) - grid_relocated[:, :] = grid[:, :] + # Assign all values at once to the second column (index 1) + grid_scaled_2d_slim_radii[:, 1] = radii_array - border_origin = np.zeros(2) - border_origin[0] = np.mean(border_grid[:, 0]) - border_origin[1] = np.mean(border_grid[:, 1]) - border_grid_radii = np.sqrt( - np.add( - np.square(np.subtract(border_grid[:, 0], border_origin[0])), - np.square(np.subtract(border_grid[:, 1], border_origin[1])), - ) - ) - border_min_radii = np.min(border_grid_radii) - - grid_radii = np.sqrt( - np.add( - np.square(np.subtract(grid[:, 0], border_origin[0])), - np.square(np.subtract(grid[:, 1], border_origin[1])), - ) - ) - - for pixel_index in range(grid.shape[0]): - if grid_radii[pixel_index] > border_min_radii: - closest_pixel_index = np.argmin( - np.square(grid[pixel_index, 0] - border_grid[:, 0]) - + np.square(grid[pixel_index, 1] - border_grid[:, 1]) - ) - - move_factor = ( - border_grid_radii[closest_pixel_index] / grid_radii[pixel_index] - ) - - if move_factor < 1.0: - grid_relocated[pixel_index, :] = ( - move_factor * (grid[pixel_index, :] - border_origin[:]) - + border_origin[:] - ) - - return grid_relocated - - -@numba_util.jit() -def furthest_grid_2d_slim_index_from( - grid_2d_slim: np.ndarray, slim_indexes: np.ndarray, coordinate: Tuple[float, float] -) -> int: - distance_to_centre = 0.0 - - for slim_index in slim_indexes: - y = grid_2d_slim[slim_index, 0] - x = grid_2d_slim[slim_index, 1] - distance_to_centre_new = (x - coordinate[1]) ** 2 + (y - coordinate[0]) ** 2 - - if distance_to_centre_new >= distance_to_centre: - distance_to_centre = distance_to_centre_new - furthest_grid_2d_slim_index = slim_index - - return furthest_grid_2d_slim_index + return grid_scaled_2d_slim_radii + 1e-6 def grid_2d_slim_from( @@ -671,16 +611,17 @@ def grid_2d_slim_from( """ grid_1d_slim_y = array_2d_util.array_2d_slim_from( - array_2d_native=np.array(grid_2d_native[:, :, 0]), - mask_2d=np.array(mask), + array_2d_native=grid_2d_native[:, :, 0], + mask_2d=mask, ) grid_1d_slim_x = array_2d_util.array_2d_slim_from( - array_2d_native=np.array(grid_2d_native[:, :, 1]), - mask_2d=np.array(mask), + array_2d_native=grid_2d_native[:, :, 1], + mask_2d=mask, ) - - return np.stack((grid_1d_slim_y, grid_1d_slim_x), axis=-1) + if isinstance(grid_2d_native, np.ndarray): + return np.stack((grid_1d_slim_y, grid_1d_slim_x), axis=-1) + return jnp.stack((grid_1d_slim_y, grid_1d_slim_x), axis=-1) def grid_2d_native_from( @@ -723,59 +664,9 @@ def grid_2d_native_from( mask_2d=mask_2d, ) - return np.stack((grid_2d_native_y, grid_2d_native_x), axis=-1) - - -@numba_util.jit() -def grid_2d_slim_upscaled_from( - grid_slim: np.ndarray, upscale_factor: int, pixel_scales: ty.PixelScales -) -> np.ndarray: - """ - From an input slimmed 2D grid, return an upscaled slimmed 2D grid where (y,x) coordinates are added at an - upscaled resolution to each grid coordinate. - - Parameters - ---------- - grid_slim - The slimmed grid of (y,x) coordinates over which a square uniform grid is overlaid. - upscale_factor - The upscaled resolution at which the new grid coordinates are computed. - pixel_scales - The pixel scale of the uniform grid that laid over the irregular grid of (y,x) coordinates. - """ - - grid_2d_slim_upscaled = np.zeros(shape=(grid_slim.shape[0] * upscale_factor**2, 2)) - - upscale_index = 0 - - y_upscale_half = pixel_scales[0] / 2 - y_upscale_step = pixel_scales[0] / upscale_factor - - x_upscale_half = pixel_scales[1] / 2 - x_upscale_step = pixel_scales[1] / upscale_factor - - for slim_index in range(grid_slim.shape[0]): - y_grid = grid_slim[slim_index, 0] - x_grid = grid_slim[slim_index, 1] - - for y in range(upscale_factor): - for x in range(upscale_factor): - grid_2d_slim_upscaled[upscale_index, 0] = ( - y_grid - + y_upscale_half - - y * y_upscale_step - - (y_upscale_step / 2.0) - ) - grid_2d_slim_upscaled[upscale_index, 1] = ( - x_grid - - x_upscale_half - + x * x_upscale_step - + (x_upscale_step / 2.0) - ) - - upscale_index += 1 - - return grid_2d_slim_upscaled + if isinstance(grid_2d_slim, np.ndarray): + return np.stack((grid_2d_native_y, grid_2d_native_x), axis=-1) + return jnp.stack((grid_2d_native_y, grid_2d_native_x), axis=-1) def grid_2d_of_points_within_radius( @@ -801,7 +692,6 @@ def compute_polygon_area(points): return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) -@numba_util.jit() def grid_pixels_in_mask_pixels_from( grid, shape_native, pixel_scales, origin ) -> np.ndarray: @@ -832,10 +722,63 @@ def grid_pixels_in_mask_pixels_from( mesh_pixels_per_image_pixel = np.zeros(shape=shape_native) - for i in range(grid_pixel_centres.shape[0]): - y = grid_pixel_centres[i, 0] - x = grid_pixel_centres[i, 1] + # Assuming grid_pixel_centres is a 2D array where each row contains (y, x) indices. + y_indices = grid_pixel_centres[:, 0] + x_indices = grid_pixel_centres[:, 1] - mesh_pixels_per_image_pixel[y, x] += 1 + # Use np.add.at to increment the specific indices in a safe and efficient manner + np.add.at(mesh_pixels_per_image_pixel, (y_indices, x_indices), 1) return mesh_pixels_per_image_pixel + + +def grid_2d_slim_via_shape_native_not_mask_from( + shape_native: Tuple[int, int], + pixel_scales: Tuple[float, float], + origin: Tuple[float, float] = (0.0, 0.0), +) -> np.ndarray: + """ + Build the slim (flattened) grid of all (y, x) pixel centres for a rectangular grid + of shape `shape_native`, scaled by `pixel_scales` and shifted by `origin`. + + This is equivalent to taking an unmasked mask of shape `shape_native` and calling + grid_2d_slim_via_mask_from on it. + + Parameters + ---------- + shape_native + A pair (Ny, Nx) giving the number of pixels in y and x. + pixel_scales + A pair (sy, sx) giving the physical size of each pixel in y and x. + origin + A 2-tuple (y0, x0) around which the grid is centred. + + Returns + ------- + grid_slim : ndarray, shape (Ny*Nx, 2) + Each row is the (y, x) coordinate of one pixel centre, in row-major order, + shifted so that `origin` ↔ physical pixel-centre average, and scaled by + `pixel_scales`, with y increasing “up” and x increasing “right”. + """ + Ny, Nx = shape_native + sy, sx = pixel_scales + y0, x0 = origin + + # compute the integer pixel‐centre coordinates in array index space + # row indices 0..Ny-1, col indices 0..Nx-1 + arange = jnp.arange + meshy, meshx = jnp.meshgrid(arange(Ny), arange(Nx), indexing="ij") + coords = jnp.stack([meshy, meshx], axis=-1).reshape(-1, 2) + + # convert to physical coordinates: subtract array‐centre, flip y, scale, then add origin + # array‐centre in index space is at ((Ny-1)/2, (Nx-1)/2) + cy, cx = (Ny - 1) / 2.0, (Nx - 1) / 2.0 + # row index i → physical y = (cy - i) * sy + y0 + # col index j → physical x = (j - cx) * sx + x0 + idx_y = coords[:, 0] + idx_x = coords[:, 1] + + phys_y = (cy - idx_y) * sy + y0 + phys_x = (idx_x - cx) * sx + x0 + + return jnp.stack([phys_y, phys_x], axis=1) diff --git a/autoarray/structures/grids/irregular_2d.py b/autoarray/structures/grids/irregular_2d.py index 57d4c4a6f..ecd2aa831 100644 --- a/autoarray/structures/grids/irregular_2d.py +++ b/autoarray/structures/grids/irregular_2d.py @@ -1,4 +1,6 @@ import logging +import jax +import jax.numpy as jnp import numpy as np from typing import List, Tuple, Union @@ -43,8 +45,13 @@ def __init__(self, values: Union[np.ndarray, List]): if type(values) is list: if isinstance(values[0], Grid2DIrregular): values = values + elif isinstance(values[0], jnp.ndarray): + values = jnp.asarray(values) else: - values = np.asarray(values) + try: + values = np.asarray(values) + except ValueError: + pass super().__init__(values) @@ -185,8 +192,8 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every *Coordinate* is computed. """ - squared_distances = np.square(self[:, 0] - coordinate[0]) + np.square( - self[:, 1] - coordinate[1] + squared_distances = jnp.square(self.array[:, 0] - coordinate[0]) + jnp.square( + self.array[:, 1] - coordinate[1] ) return ArrayIrregular(values=squared_distances) @@ -201,8 +208,8 @@ def distances_to_coordinate_from( coordinate The (y,x) coordinate from which the distance of every coordinate is computed. """ - distances = np.sqrt( - self.squared_distances_to_coordinate_from(coordinate=coordinate) + distances = jnp.sqrt( + self.squared_distances_to_coordinate_from(coordinate=coordinate).array ) return ArrayIrregular(values=distances) @@ -230,15 +237,12 @@ def furthest_distances_to_other_coordinates(self) -> ArrayIrregular: The further distances of every coordinate to every other coordinate on the irregular grid. """ - radial_distances_max = np.zeros((self.shape[0])) + def max_radial_distance(point): + x_distances = jnp.square(point[0] - self.array[:, 0]) + y_distances = jnp.square(point[1] - self.array[:, 1]) + return jnp.sqrt(jnp.nanmax(x_distances + y_distances)) - for i in range(self.shape[0]): - x_distances = np.square(np.subtract(self[i, 0], self[:, 0])) - y_distances = np.square(np.subtract(self[i, 1], self[:, 1])) - - radial_distances_max[i] = np.sqrt(np.max(np.add(x_distances, y_distances))) - - return ArrayIrregular(values=radial_distances_max) + return ArrayIrregular(values=jax.vmap(max_radial_distance)(self.array)) def grid_of_closest_from(self, grid_pair: "Grid2DIrregular") -> "Grid2DIrregular": """ @@ -256,14 +260,12 @@ def grid_of_closest_from(self, grid_pair: "Grid2DIrregular") -> "Grid2DIrregular the `Grid2DIrregular` to the input grid. """ - grid_of_closest = np.zeros((grid_pair.shape[0], 2)) - - for i in range(grid_pair.shape[0]): - x_distances = np.square(np.subtract(grid_pair[i, 0], self[:, 0])) - y_distances = np.square(np.subtract(grid_pair[i, 1], self[:, 1])) - - radial_distances = np.add(x_distances, y_distances) + jax_array = jnp.asarray(self.array) - grid_of_closest[i, :] = self[np.argmin(radial_distances), :] + def closest_point(point): + x_distances = jnp.square(point[0] - jax_array[:, 0]) + y_distances = jnp.square(point[1] - jax_array[:, 1]) + radial_distances = x_distances + y_distances + return jax_array[jnp.argmin(radial_distances)] - return Grid2DIrregular(values=grid_of_closest) + return jax.vmap(closest_point)(grid_pair.array) diff --git a/autoarray/structures/grids/sparse_2d_util.py b/autoarray/structures/grids/sparse_2d_util.py index 207c23f7e..2f648bec8 100644 --- a/autoarray/structures/grids/sparse_2d_util.py +++ b/autoarray/structures/grids/sparse_2d_util.py @@ -1,6 +1,5 @@ import logging import numpy as np -from scipy.interpolate import interp1d, griddata logger = logging.getLogger(__name__) logger.level = logging.DEBUG @@ -46,6 +45,7 @@ def create_img_and_grid_hb_order(img_2d, mask, mask_radius, pixel_scales, length image associated to that grid. """ + from scipy.interpolate import griddata from autoarray.structures.grids.uniform_2d import Grid2D shape_nnn = np.shape(mask)[0] @@ -76,6 +76,7 @@ def inverse_transform_sampling_interpolated(probabilities, n_samples, gridx, gri probabilities: 1D normalized cumulative probablity curve. n_samples: the number of points to draw. """ + from scipy.interpolate import interp1d cdf = np.cumsum(probabilities) npixels = len(probabilities) diff --git a/autoarray/structures/grids/uniform_1d.py b/autoarray/structures/grids/uniform_1d.py index 53a9ec756..ee7e78a72 100644 --- a/autoarray/structures/grids/uniform_1d.py +++ b/autoarray/structures/grids/uniform_1d.py @@ -178,7 +178,7 @@ def no_mask( origin=origin, ) - return Grid1D(values=values, mask=mask) + return Grid1D(values=np.array(values), mask=mask) @classmethod def from_mask(cls, mask: Mask1D) -> "Grid1D": @@ -195,12 +195,12 @@ def from_mask(cls, mask: Mask1D) -> "Grid1D": """ grid_1d = grid_1d_util.grid_1d_slim_via_mask_from( - mask_1d=np.array(mask), + mask_1d=mask.array, pixel_scales=mask.pixel_scales, origin=mask.origin, ) - return Grid1D(values=grid_1d, mask=mask) + return Grid1D(values=np.array(grid_1d), mask=mask) @classmethod def uniform( @@ -312,10 +312,10 @@ def grid_2d_radial_projected_from(self, angle: float = 0.0) -> Grid2DIrregular: """ grid = np.zeros((self.mask.pixels_in_mask, 2)) - grid[:, 1] = self.slim + grid[:, 1] = self.slim.array grid = geometry_util.transform_grid_2d_to_reference_frame( grid_2d=grid, centre=(0.0, 0.0), angle=angle ) - return Grid2DIrregular(values=grid) + return Grid2DIrregular(values=grid + 1e-6) diff --git a/autoarray/structures/grids/uniform_2d.py b/autoarray/structures/grids/uniform_2d.py index 670cbfcbe..8d7d91fa5 100644 --- a/autoarray/structures/grids/uniform_2d.py +++ b/autoarray/structures/grids/uniform_2d.py @@ -159,8 +159,9 @@ def __init__( not uniform (e.g. due to gravitational lensing) is cannot be computed internally in this function. If the over sampled grid is not passed in it is computed assuming uniformity. """ + values = grid_2d_util.convert_grid_2d( - grid_2d=np.array(values), + grid_2d=values, mask_2d=mask, store_native=store_native, ) @@ -179,18 +180,24 @@ def __init__( self.over_sampler = OverSampler(sub_size=over_sample_size, mask=mask) - if over_sampled is None: - self.over_sampled = ( - over_sample_util.grid_2d_slim_over_sampled_via_mask_from( - mask_2d=np.array(self.mask), - pixel_scales=self.mask.pixel_scales, - sub_size=np.array(self.over_sampler.sub_size._array).astype("int"), - origin=self.mask.origin, - ) - ) + self._over_sampled = over_sampled - else: - self.over_sampled = over_sampled + @property + def over_sampled(self): + + if self._over_sampled is not None: + return self._over_sampled + + over_sampled = over_sample_util.grid_2d_slim_over_sampled_via_mask_from( + mask_2d=np.array(self.mask), + pixel_scales=self.mask.pixel_scales, + sub_size=self.over_sampler.sub_size.array.astype("int"), + origin=self.mask.origin, + ) + + self._over_sampled = Grid2DIrregular(values=over_sampled) + + return self._over_sampled @classmethod def no_mask( @@ -244,7 +251,7 @@ def no_mask( ) return Grid2D( - values=values, + values=np.array(values), mask=mask, over_sample_size=over_sample_size, ) @@ -544,14 +551,14 @@ def from_mask( The mask whose masked pixels are used to setup the grid. """ - grid_1d = grid_2d_util.grid_2d_slim_via_mask_from( - mask_2d=mask._array, + grid_2d = grid_2d_util.grid_2d_slim_via_mask_from( + mask_2d=mask.array, pixel_scales=mask.pixel_scales, origin=mask.origin, ) return Grid2D( - values=grid_1d, + values=np.array(grid_2d), mask=mask, over_sample_size=over_sample_size, ) @@ -682,8 +689,6 @@ def blurring_grid_from( ) def subtracted_from(self, offset: Tuple[(float, float), np.ndarray]) -> "Grid2D": - if offset[0] == 0.0 and offset[1] == 0.0: - return self mask = Mask2D( mask=self.mask, @@ -692,9 +697,10 @@ def subtracted_from(self, offset: Tuple[(float, float), np.ndarray]) -> "Grid2D" ) return Grid2D( - values=self - np.array(offset), + values=self - jnp.array(offset), mask=mask, over_sample_size=self.over_sample_size, + over_sampled=self.over_sampled - jnp.array(offset), ) @property @@ -839,14 +845,10 @@ def squared_distances_to_coordinate_from( coordinate The (y,x) coordinate from which the squared distance of every grid (y,x) coordinate is computed. """ - if isinstance(self, jnp.ndarray): - squared_distances = jnp.square( - self.array[:, 0] - coordinate[0] - ) + jnp.square(self.array[:, 1] - coordinate[1]) - else: - squared_distances = np.square(self[:, 0] - coordinate[0]) + np.square( - self[:, 1] - coordinate[1] - ) + squared_distances = jnp.square(self.array[:, 0] - coordinate[0]) + jnp.square( + self.array[:, 1] - coordinate[1] + ) + return Array2D(values=squared_distances, mask=self.mask) def distances_to_coordinate_from( @@ -863,7 +865,7 @@ def distances_to_coordinate_from( squared_distance = self.squared_distances_to_coordinate_from( coordinate=coordinate ) - distances = np.sqrt(squared_distance.array) + distances = jnp.sqrt(squared_distance.array) return Array2D(values=distances, mask=self.mask) def grid_2d_radial_projected_shape_slim_from( @@ -1099,6 +1101,7 @@ def padded_grid_from(self, kernel_shape_native: Tuple[int, int]) -> "Grid2D": padded_mask = Mask2D.all_false( shape_native=padded_shape, pixel_scales=self.mask.pixel_scales, + origin=self.origin, ) pad_width = ( @@ -1107,7 +1110,7 @@ def padded_grid_from(self, kernel_shape_native: Tuple[int, int]) -> "Grid2D": ) over_sample_size = np.pad( - self.over_sample_size.native._array, + self.over_sample_size.native.array, pad_width, mode="constant", constant_values=1, diff --git a/autoarray/structures/header.py b/autoarray/structures/header.py index d628bf1d2..dbafd1f14 100644 --- a/autoarray/structures/header.py +++ b/autoarray/structures/header.py @@ -1,5 +1,4 @@ import logging -from astropy import time from typing import Dict, Tuple, Optional from autoarray.dataset import preprocess @@ -36,6 +35,8 @@ def exposure_time(self) -> str: @property def modified_julian_date(self) -> Optional[str]: + from astropy import time + if ( self.date_of_observation is not None and self.time_of_observation is not None diff --git a/autoarray/structures/mesh/abstract_2d.py b/autoarray/structures/mesh/abstract_2d.py index 910a82ef9..cf630443e 100644 --- a/autoarray/structures/mesh/abstract_2d.py +++ b/autoarray/structures/mesh/abstract_2d.py @@ -1,4 +1,3 @@ -import numpy as np from typing import Optional, Tuple from autoarray.structures.abstract_structure import Structure @@ -6,6 +5,15 @@ class Abstract2DMesh(Structure): + + @property + def slim(self) -> "Structure": + raise NotImplementedError() + + @property + def native(self) -> Structure: + raise NotImplementedError() + @property def parameters(self) -> int: return self.pixels diff --git a/autoarray/structures/mesh/delaunay_2d.py b/autoarray/structures/mesh/delaunay_2d.py index a6b7f9012..11c1707ae 100644 --- a/autoarray/structures/mesh/delaunay_2d.py +++ b/autoarray/structures/mesh/delaunay_2d.py @@ -1,12 +1,12 @@ import numpy as np -from typing import List, Optional, Tuple +from typing import Optional, Tuple from autoconf import cached_property from autoarray.inversion.linear_obj.neighbors import Neighbors from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.mesh.triangulation_2d import Abstract2DMeshTriangulation -from autoarray.inversion.pixelization.mesh import mesh_util +from autoarray.inversion.pixelization.mesh import mesh_numba_util class Mesh2DDelaunay(Abstract2DMeshTriangulation): @@ -69,9 +69,9 @@ def interpolated_array_from( shape_native=shape_native, extent=extent ) - interpolated_array = mesh_util.delaunay_interpolated_array_from( + interpolated_array = mesh_numba_util.delaunay_interpolated_array_from( shape_native=shape_native, - interpolation_grid_slim=interpolation_grid.slim, + interpolation_grid_slim=np.array(interpolation_grid.slim.array), delaunay=self.delaunay, pixel_values=values, ) diff --git a/autoarray/structures/mesh/rectangular_2d.py b/autoarray/structures/mesh/rectangular_2d.py index 0e4eab108..8e447b74b 100644 --- a/autoarray/structures/mesh/rectangular_2d.py +++ b/autoarray/structures/mesh/rectangular_2d.py @@ -1,26 +1,22 @@ +import jax.numpy as jnp import numpy as np -from scipy.interpolate import griddata + from typing import List, Optional, Tuple +from autoconf import cached_property + from autoarray import type as ty from autoarray.inversion.linear_obj.neighbors import Neighbors -from autoarray.inversion.pixelization.mesh import mesh_util from autoarray.mask.mask_2d import Mask2D -from autoarray.structures.abstract_structure import Structure from autoarray.structures.arrays.uniform_2d import Array2D -from autoarray.structures.grids import grid_2d_util + from autoarray.structures.mesh.abstract_2d import Abstract2DMesh -from autoconf import cached_property +from autoarray.inversion.pixelization.mesh import mesh_util +from autoarray.structures.grids import grid_2d_util -class Mesh2DRectangular(Abstract2DMesh): - @property - def slim(self) -> "Structure": - raise NotImplementedError() - @property - def native(self) -> Structure: - raise NotImplementedError() +class Mesh2DRectangular(Abstract2DMesh): def __init__( self, @@ -91,26 +87,28 @@ def overlay_grid( buffer The size of the extra spacing placed between the edges of the rectangular pixelization and input grid. """ - - y_min = np.min(grid[:, 0]) - buffer - y_max = np.max(grid[:, 0]) + buffer - x_min = np.min(grid[:, 1]) - buffer - x_max = np.max(grid[:, 1]) + buffer - - pixel_scales = ( - float((y_max - y_min) / shape_native[0]), - float((x_max - x_min) / shape_native[1]), + grid = grid.array + + y_min = jnp.min(grid[:, 0]) - buffer + y_max = jnp.max(grid[:, 0]) + buffer + x_min = jnp.min(grid[:, 1]) - buffer + x_max = jnp.max(grid[:, 1]) + buffer + + pixel_scales = jnp.array( + ( + (y_max - y_min) / shape_native[0], + (x_max - x_min) / shape_native[1], + ) ) + origin = jnp.array(((y_max + y_min) / 2.0, (x_max + x_min) / 2.0)) - origin = ((y_max + y_min) / 2.0, (x_max + x_min) / 2.0) - - grid_slim = grid_2d_util.grid_2d_slim_via_shape_native_from( + grid_slim = grid_2d_util.grid_2d_slim_via_shape_native_not_mask_from( shape_native=shape_native, pixel_scales=pixel_scales, origin=origin, ) - return Mesh2DRectangular( + return cls( values=grid_slim, shape_native=shape_native, pixel_scales=pixel_scales, @@ -134,7 +132,9 @@ def neighbors(self) -> Neighbors: @cached_property def edge_pixel_list(self) -> List: - return mesh_util.rectangular_edge_pixel_list_from(neighbors=self.neighbors) + return mesh_util.rectangular_edge_pixel_list_from( + shape_native=self.shape_native + ) @property def pixels(self) -> int: @@ -174,11 +174,15 @@ def interpolated_array_from( The (x0, x1, y0, y1) extent of the grid in scaled coordinates over which the grid is created if it is input. """ + from scipy.interpolate import griddata + interpolation_grid = self.interpolation_grid_from( shape_native=shape_native, extent=extent ) - interpolated_array = griddata(points=self, values=values, xi=interpolation_grid) + interpolated_array = griddata( + points=self.array, values=values, xi=interpolation_grid + ) interpolated_array = interpolated_array.reshape(shape_native) diff --git a/autoarray/structures/mesh/rectangular_2d_uniform.py b/autoarray/structures/mesh/rectangular_2d_uniform.py new file mode 100644 index 000000000..688b7c53b --- /dev/null +++ b/autoarray/structures/mesh/rectangular_2d_uniform.py @@ -0,0 +1,6 @@ +from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular + + +class Mesh2DRectangularUniform(Mesh2DRectangular): + + pass diff --git a/autoarray/structures/mesh/triangulation_2d.py b/autoarray/structures/mesh/triangulation_2d.py index b5c9fedc9..1ffc915ce 100644 --- a/autoarray/structures/mesh/triangulation_2d.py +++ b/autoarray/structures/mesh/triangulation_2d.py @@ -1,27 +1,18 @@ import numpy as np -import scipy.spatial + from typing import List, Union, Tuple -from autoarray.structures.abstract_structure import Structure from autoconf import cached_property from autoarray.geometry.geometry_2d_irregular import Geometry2DIrregular from autoarray.structures.mesh.abstract_2d import Abstract2DMesh from autoarray import exc -from autoarray.inversion.pixelization.mesh import mesh_util +from autoarray.inversion.pixelization.mesh import mesh_numba_util from autoarray.structures.grids import grid_2d_util class Abstract2DMeshTriangulation(Abstract2DMesh): - @property - def slim(self) -> "Structure": - raise NotImplementedError() - - @property - def native(self) -> Structure: - raise NotImplementedError() - def __init__( self, values: Union[np.ndarray, List], @@ -82,7 +73,7 @@ def geometry(self): ) @cached_property - def delaunay(self) -> scipy.spatial.Delaunay: + def delaunay(self) -> "scipy.spatial.Delaunay": """ Returns a `scipy.spatial.Delaunay` object from the 2D (y,x) grid of irregular coordinates, which correspond to the corner of every triangle of a Delaunay triangulation. @@ -95,13 +86,18 @@ def delaunay(self) -> scipy.spatial.Delaunay: to compute the Voronoi mesh are ill posed. These exceptions are caught and combined into a single `MeshException`, which helps exception handling in the `inversion` package. """ + + import scipy.spatial + try: - return scipy.spatial.Delaunay(np.asarray([self[:, 0], self[:, 1]]).T) + return scipy.spatial.Delaunay( + np.asarray([self.array[:, 0], self.array[:, 1]]).T + ) except (ValueError, OverflowError, scipy.spatial.qhull.QhullError) as e: raise exc.MeshException() from e @cached_property - def voronoi(self) -> scipy.spatial.Voronoi: + def voronoi(self) -> "scipy.spatial.Voronoi": """ Returns a `scipy.spatial.Voronoi` object from the 2D (y,x) grid of irregular coordinates, which correspond to the centre of every Voronoi pixel. @@ -113,6 +109,7 @@ def voronoi(self) -> scipy.spatial.Voronoi: to compute the Delaunay triangulation are ill posed. These exceptions are caught and combined into a single `MeshException`, which helps exception handling in the `inversion` package. """ + import scipy.spatial from scipy.spatial import QhullError try: @@ -128,7 +125,7 @@ def edge_pixel_list(self) -> List: Returns a list of the Voronoi pixel indexes that are on the edge of the mesh. """ - return mesh_util.voronoi_edge_pixels_from( + return mesh_numba_util.voronoi_edge_pixels_from( regions=self.voronoi.regions, point_region=self.voronoi.point_region ) diff --git a/autoarray/structures/mesh/voronoi_2d.py b/autoarray/structures/mesh/voronoi_2d.py index 72eeda54d..b4135610d 100644 --- a/autoarray/structures/mesh/voronoi_2d.py +++ b/autoarray/structures/mesh/voronoi_2d.py @@ -1,6 +1,6 @@ import numpy as np -from scipy.interpolate import griddata -from typing import List, Optional, Tuple + +from typing import Optional, Tuple from autoconf import cached_property @@ -8,7 +8,7 @@ from autoarray.structures.arrays.uniform_2d import Array2D from autoarray.structures.mesh.triangulation_2d import Abstract2DMeshTriangulation -from autoarray.inversion.pixelization.mesh import mesh_util +from autoarray.inversion.pixelization.mesh import mesh_numba_util class Mesh2DVoronoi(Abstract2DMeshTriangulation): @@ -36,10 +36,10 @@ def neighbors(self) -> Neighbors: see `Neighbors` for a complete description of the neighboring scheme. The neighbors of a Voronoi mesh are computed using the `ridge_points` attribute of the scipy `Voronoi` - object, as described in the method `mesh_util.voronoi_neighbors_from`. + object, as described in the method `mesh_numba_util.voronoi_neighbors_from`. """ - neighbors, sizes = mesh_util.voronoi_neighbors_from( + neighbors, sizes = mesh_numba_util.voronoi_neighbors_from( pixels=self.pixels, ridge_points=np.asarray(self.voronoi.ridge_points) ) @@ -76,12 +76,14 @@ def interpolated_array_from( The 2D shape in scaled coordinates (e.g. arc-seconds in PyAutoGalaxy / PyAutoLens) that the interpolated reconstructed source is returned on. """ + from scipy.interpolate import griddata + interpolation_grid = self.interpolation_grid_from( shape_native=shape_native, extent=extent ) if use_nn: - interpolated_array = mesh_util.voronoi_nn_interpolated_array_from( + interpolated_array = mesh_numba_util.voronoi_nn_interpolated_array_from( shape_native=shape_native, interpolation_grid_slim=interpolation_grid.slim, pixel_values=values, diff --git a/autoarray/structures/mock/mock_decorators.py b/autoarray/structures/mock/mock_decorators.py index c02ebc0b8..28cf9eaec 100644 --- a/autoarray/structures/mock/mock_decorators.py +++ b/autoarray/structures/mock/mock_decorators.py @@ -157,11 +157,3 @@ def ndarray_yx_2d_list_from(self, grid, *args, **kwargs): Such functions are common in **PyAutoGalaxy** for light and mass profile objects. """ return [np.multiply(1.0, grid.array), np.multiply(2.0, grid.array)] - - -class MockGridRadialMinimum: - def __init__(self): - pass - - def radial_grid_from(self, grid): - return np.sqrt(np.add(np.square(grid[:, 0]), np.square(grid[:, 1]))) diff --git a/autoarray/structures/mock/mock_grid.py b/autoarray/structures/mock/mock_grid.py index 2352c2027..b1639a02c 100644 --- a/autoarray/structures/mock/mock_grid.py +++ b/autoarray/structures/mock/mock_grid.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Tuple, List +from typing import Tuple from autoarray.geometry.abstract_2d import AbstractGeometry2D from autoarray.inversion.linear_obj.neighbors import Neighbors diff --git a/autoarray/structures/plot/structure_plotters.py b/autoarray/structures/plot/structure_plotters.py index a44d44adc..7e7cf655e 100644 --- a/autoarray/structures/plot/structure_plotters.py +++ b/autoarray/structures/plot/structure_plotters.py @@ -1,11 +1,9 @@ import numpy as np from typing import List, Optional, Union -from autoarray.plot.abstract_plotters import Plotter +from autoarray.plot.abstract_plotters import AbstractPlotter from autoarray.plot.visuals.one_d import Visuals1D from autoarray.plot.visuals.two_d import Visuals2D -from autoarray.plot.include.one_d import Include1D -from autoarray.plot.include.two_d import Include2D from autoarray.plot.mat_plot.one_d import MatPlot1D from autoarray.plot.mat_plot.two_d import MatPlot2D from autoarray.plot.auto_labels import AutoLabels @@ -15,13 +13,12 @@ from autoarray.structures.grids.uniform_2d import Grid2D -class Array2DPlotter(Plotter): +class Array2DPlotter(AbstractPlotter): def __init__( self, array: Array2D, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots `Array2D` objects using the matplotlib method `imshow()` and many other matplotlib functions which @@ -32,8 +29,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Array2D` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Array2D` and plotted via the visuals object. Parameters ---------- @@ -43,36 +39,28 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Array2D` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - visuals_2d=visuals_2d, include_2d=include_2d, mat_plot_2d=mat_plot_2d - ) + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) self.array = array - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_mask_from(mask=self.array.mask) - def figure_2d(self): """ Plots the plotter's `Array2D` object in 2D. """ self.mat_plot_2d.plot_array( array=self.array, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Array2D", filename="array"), ) -class Grid2DPlotter(Plotter): +class Grid2DPlotter(AbstractPlotter): def __init__( self, grid: Grid2D, - mat_plot_2d: MatPlot2D = MatPlot2D(), - visuals_2d: Visuals2D = Visuals2D(), - include_2d: Include2D = Include2D(), + mat_plot_2d: MatPlot2D = None, + visuals_2d: Visuals2D = None, ): """ Plots `Grid2D` objects using the matplotlib method `scatter()` and many other matplotlib functions which @@ -83,8 +71,7 @@ def __init__( but a user can manually input values into `MatPlot2d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from - the `Grid2D` and plotted via the visuals object, if the corresponding entry is `True` in the `Include2D` - object or the `config/visualize/include.ini` file. + the `Grid2D` and plotted via the visuals object. Parameters ---------- @@ -94,18 +81,11 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 2D plots. visuals_2d Contains 2D visuals that can be overlaid on 2D plots. - include_2d - Specifies which attributes of the `Grid2D` are extracted and plotted as visuals for 2D plots. """ - super().__init__( - visuals_2d=visuals_2d, include_2d=include_2d, mat_plot_2d=mat_plot_2d - ) + super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d) self.grid = grid - def get_visuals_2d(self) -> Visuals2D: - return self.get_2d.via_grid_from(grid=self.grid) - def figure_2d( self, color_array: np.ndarray = None, @@ -128,7 +108,7 @@ def figure_2d( """ self.mat_plot_2d.plot_grid( grid=self.grid, - visuals_2d=self.get_visuals_2d(), + visuals_2d=self.visuals_2d, auto_labels=AutoLabels(title="Grid2D", filename="grid"), color_array=color_array, plot_grid_lines=plot_grid_lines, @@ -136,14 +116,13 @@ def figure_2d( ) -class YX1DPlotter(Plotter): +class YX1DPlotter(AbstractPlotter): def __init__( self, y: Union[Array1D, List], x: Optional[Union[Array1D, Grid1D, List]] = None, - mat_plot_1d: MatPlot1D = MatPlot1D(), - visuals_1d: Visuals1D = Visuals1D(), - include_1d: Include1D = Include1D(), + mat_plot_1d: MatPlot1D = None, + visuals_1d: Visuals1D = None, should_plot_grid: bool = False, should_plot_zero: bool = False, plot_axis_type: Optional[str] = None, @@ -159,8 +138,7 @@ def __init__( but a user can manually input values into `MatPlot1d` to customize the figure's appearance. Overlaid on the figure are visuals, contained in the `Visuals1D` object. Attributes may be extracted from - the `Array1D` and plotted via the visuals object, if the corresponding entry is `True` in the `Include1D` - object or the `config/visualize/include.ini` file. + the `Array1D` and plotted via the visuals object. Parameters ---------- @@ -172,8 +150,6 @@ def __init__( Contains objects which wrap the matplotlib function calls that make 1D plots. visuals_1d Contains 1D visuals that can be overlaid on 1D plots. - include_1d - Specifies which attributes of the `Array1D` are extracted and plotted as visuals for 1D plots. """ if isinstance(y, list): @@ -182,9 +158,7 @@ def __init__( if isinstance(x, list): x = Array1D.no_mask(values=x, pixel_scales=1.0) - super().__init__( - visuals_1d=visuals_1d, include_1d=include_1d, mat_plot_1d=mat_plot_1d - ) + super().__init__(visuals_1d=visuals_1d, mat_plot_1d=mat_plot_1d) self.y = y self.x = y.grid_radial if x is None else x @@ -194,9 +168,6 @@ def __init__( self.plot_yx_dict = plot_yx_dict or {} self.auto_labels = auto_labels - def get_visuals_1d(self) -> Visuals1D: - return self.get_1d.via_array_1d_from(array_1d=self.x) - def figure_1d(self): """ Plots the plotter's y and x values in 1D. @@ -205,10 +176,10 @@ def figure_1d(self): self.mat_plot_1d.plot_yx( y=self.y, x=self.x, - visuals_1d=self.get_visuals_1d(), + visuals_1d=self.visuals_1d, auto_labels=self.auto_labels, should_plot_grid=self.should_plot_grid, should_plot_zero=self.should_plot_zero, plot_axis_type_override=self.plot_axis_type, - **self.plot_yx_dict + **self.plot_yx_dict, ) diff --git a/autoarray/structures/triangles/abstract.py b/autoarray/structures/triangles/abstract.py index 880eea2f7..3ae5e4718 100644 --- a/autoarray/structures/triangles/abstract.py +++ b/autoarray/structures/triangles/abstract.py @@ -3,7 +3,6 @@ import numpy as np from autoarray import Grid2D -from autoarray.structures.triangles.shape import Shape HEIGHT_FACTOR = 3**0.5 / 2 @@ -122,21 +121,6 @@ def for_indexes(self, indexes: np.ndarray) -> "AbstractTriangles": The new ArrayTriangles instance. """ - @abstractmethod - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The indices of triangles that intersect the shape. - """ - @abstractmethod def neighborhood(self) -> "AbstractTriangles": """ diff --git a/autoarray/structures/triangles/array.py b/autoarray/structures/triangles/array.py new file mode 100644 index 000000000..353163a00 --- /dev/null +++ b/autoarray/structures/triangles/array.py @@ -0,0 +1,415 @@ +import numpy as np +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + +from autoarray.structures.triangles.abstract import HEIGHT_FACTOR + +from autoarray.structures.grids.uniform_2d import Grid2D +from autoarray.structures.triangles.abstract import AbstractTriangles +from autoarray.structures.triangles.shape import Shape + +MAX_CONTAINING_SIZE = 15 + + +@register_pytree_node_class +class ArrayTriangles(AbstractTriangles): + def __init__( + self, + indices, + vertices, + max_containing_size=MAX_CONTAINING_SIZE, + **kwargs, + ): + """ + Represents a set of triangles in efficient NumPy arrays. + + Parameters + ---------- + indices + The indices of the vertices of the triangles. This is a 2D array where each row is a triangle + with the three indices of the vertices. + vertices + The vertices of the triangles. + """ + self._indices = indices + self._vertices = vertices + self.max_containing_size = max_containing_size + + def __len__(self): + return len(self.triangles) + + def __iter__(self): + return iter(self.triangles) + + def __str__(self): + return f"{self.__class__.__name__} with {len(self.indices)} triangles" + + def __repr__(self): + return str(self) + + @classmethod + def for_limits_and_scale( + cls, + y_min: float, + y_max: float, + x_min: float, + x_max: float, + scale: float, + max_containing_size=MAX_CONTAINING_SIZE, + ) -> "AbstractTriangles": + height = scale * HEIGHT_FACTOR + + vertices = [] + indices = [] + vertex_dict = {} + + def add_vertex(v): + if v not in vertex_dict: + vertex_dict[v] = len(vertices) + vertices.append(v) + return vertex_dict[v] + + rows = [] + for row_y in np.arange(y_min, y_max + height, height): + row = [] + offset = (len(rows) % 2) * scale / 2 + for col_x in np.arange(x_min - offset, x_max + scale, scale): + row.append((row_y, col_x)) + rows.append(row) + + for i in range(len(rows) - 1): + row = rows[i] + next_row = rows[i + 1] + for j in range(len(row)): + if i % 2 == 0 and j < len(next_row) - 1: + t1 = [ + add_vertex(row[j]), + add_vertex(next_row[j]), + add_vertex(next_row[j + 1]), + ] + if j < len(row) - 1: + t2 = [ + add_vertex(row[j]), + add_vertex(row[j + 1]), + add_vertex(next_row[j + 1]), + ] + indices.append(t2) + elif i % 2 == 1 and j < len(next_row) - 1: + t1 = [ + add_vertex(row[j]), + add_vertex(next_row[j]), + add_vertex(row[j + 1]), + ] + indices.append(t1) + if j < len(next_row) - 1: + t2 = [ + add_vertex(next_row[j]), + add_vertex(next_row[j + 1]), + add_vertex(row[j + 1]), + ] + indices.append(t2) + else: + continue + indices.append(t1) + + return cls( + indices=jnp.array(indices), + vertices=jnp.array(vertices), + max_containing_size=max_containing_size, + ) + + @property + def indices(self): + return self._indices + + @property + def vertices(self): + return self._vertices + + @property + def triangles(self) -> jnp.ndarray: + """ + The triangles as a 3x2 array of vertices. + """ + + invalid_mask = jnp.any(self.indices == -1, axis=1) + nan_array = jnp.full( + (self.indices.shape[0], 3, 2), + jnp.nan, + dtype=jnp.float32, + ) + safe_indices = jnp.where(self.indices == -1, 0, self.indices) + triangle_vertices = self.vertices[safe_indices] + return jnp.where(invalid_mask[:, None, None], nan_array, triangle_vertices) + + @property + def means(self) -> jnp.ndarray: + """ + The mean of each triangle. + """ + return jnp.mean(self.triangles, axis=1) + + def containing_indices(self, shape: Shape) -> jnp.ndarray: + """ + Find the triangles that insect with a given shape. + + Parameters + ---------- + shape + The shape + + Returns + ------- + The triangles that intersect the shape. + """ + inside = shape.mask(self.triangles) + + return jnp.where( + inside, + size=self.max_containing_size, + fill_value=-1, + )[0] + + def for_indexes(self, indexes: jnp.ndarray) -> "ArrayTriangles": + """ + Create a new ArrayTriangles containing indices and vertices corresponding to the given indexes + but without duplicate vertices. + + Parameters + ---------- + indexes + The indexes of the triangles to include in the new ArrayTriangles. + + Returns + ------- + The new ArrayTriangles instance. + """ + selected_indices = select_and_handle_invalid( + data=self.indices, + indices=indexes, + invalid_value=-1, + invalid_replacement=jnp.array([-1, -1, -1], dtype=jnp.int32), + ) + + flat_indices = selected_indices.flatten() + + selected_vertices = select_and_handle_invalid( + data=self.vertices, + indices=flat_indices, + invalid_value=-1, + invalid_replacement=jnp.array([jnp.nan, jnp.nan], dtype=jnp.float32), + ) + + unique_vertices, inv_indices = jnp.unique( + selected_vertices, + axis=0, + return_inverse=True, + equal_nan=True, + size=selected_indices.shape[0] * 3, + fill_value=jnp.nan, + ) + + nan_mask = jnp.isnan(unique_vertices).any(axis=1) + inv_indices = jnp.where(nan_mask[inv_indices], -1, inv_indices) + + new_indices = inv_indices.reshape(selected_indices.shape) + + new_indices_sorted = jnp.sort(new_indices, axis=1) + + unique_triangles_indices = jnp.unique( + new_indices_sorted, + axis=0, + size=new_indices_sorted.shape[0], + fill_value=-1, + ) + + return ArrayTriangles( + indices=unique_triangles_indices, + vertices=unique_vertices, + max_containing_size=self.max_containing_size, + ) + + def _up_sample_triangle(self): + triangles = self.triangles + + m01 = (triangles[:, 0] + triangles[:, 1]) / 2 + m12 = (triangles[:, 1] + triangles[:, 2]) / 2 + m20 = (triangles[:, 2] + triangles[:, 0]) / 2 + + return jnp.concatenate( + [ + jnp.stack([triangles[:, 1], m12, m01], axis=1), + jnp.stack([triangles[:, 2], m20, m12], axis=1), + jnp.stack([m01, m12, m20], axis=1), + jnp.stack([triangles[:, 0], m01, m20], axis=1), + ], + axis=0, + ) + + def up_sample(self) -> "ArrayTriangles": + """ + Up-sample the triangles by adding a new vertex at the midpoint of each edge. + + This means each triangle becomes four smaller triangles. + """ + new_indices, unique_vertices = remove_duplicates(self._up_sample_triangle()) + + return ArrayTriangles( + indices=new_indices, + vertices=unique_vertices, + max_containing_size=self.max_containing_size, + ) + + def _neighborhood_triangles(self): + triangles = self.triangles + + new_v0 = triangles[:, 1] + triangles[:, 2] - triangles[:, 0] + new_v1 = triangles[:, 0] + triangles[:, 2] - triangles[:, 1] + new_v2 = triangles[:, 0] + triangles[:, 1] - triangles[:, 2] + + return jnp.concatenate( + [ + jnp.stack([new_v0, triangles[:, 1], triangles[:, 2]], axis=1), + jnp.stack([triangles[:, 0], new_v1, triangles[:, 2]], axis=1), + jnp.stack([triangles[:, 0], triangles[:, 1], new_v2], axis=1), + triangles, + ], + axis=0, + ) + + def neighborhood(self) -> "ArrayTriangles": + """ + Create a new set of triangles that are the neighborhood of the current triangles. + + Includes the current triangles and the triangles that share an edge with the current triangles. + """ + new_indices, unique_vertices = remove_duplicates(self._neighborhood_triangles()) + + return ArrayTriangles( + indices=new_indices, + vertices=unique_vertices, + max_containing_size=self.max_containing_size, + ) + + def with_vertices(self, vertices: jnp.ndarray) -> "ArrayTriangles": + """ + Create a new set of triangles with the vertices replaced. + + Parameters + ---------- + vertices + The new vertices to use. + + Returns + ------- + The new set of triangles with the new vertices. + """ + return ArrayTriangles( + indices=self.indices, + vertices=vertices, + max_containing_size=self.max_containing_size, + ) + + @property + def area(self) -> float: + """ + The total area covered by the triangles. + """ + triangles = self.triangles + return ( + 0.5 + * np.abs( + (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) + + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) + + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) + ).sum() + ) + + def tree_flatten(self): + """ + Flatten this model as a PyTree. + """ + return ( + self.indices, + self.vertices, + ), (self.max_containing_size,) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """ + Unflatten a PyTree into a model. + """ + return cls( + indices=children[0], + vertices=children[1], + max_containing_size=aux_data[0], + ) + + +def select_and_handle_invalid( + data: jnp.ndarray, + indices: jnp.ndarray, + invalid_value, + invalid_replacement, +): + """ + Select data based on indices, handling invalid indices by replacing them with a specified value. + + Parameters + ---------- + data + The array from which to select data. + indices + The indices used to select data from the array. + invalid_value + The value representing invalid indices. + invalid_replacement + The value to use for invalid entries in the result. + + Returns + ------- + An array with selected data, where invalid indices are replaced with `invalid_replacement`. + """ + invalid_mask = indices == invalid_value + safe_indices = jnp.where(invalid_mask, 0, indices) + selected_data = data[safe_indices] + selected_data = jnp.where( + invalid_mask[..., None], + invalid_replacement, + selected_data, + ) + + return selected_data + + +def remove_duplicates(new_triangles): + unique_vertices, inverse_indices = jnp.unique( + new_triangles.reshape(-1, 2), + axis=0, + return_inverse=True, + size=2 * new_triangles.shape[0], + fill_value=jnp.nan, + equal_nan=True, + ) + + inverse_indices_flat = inverse_indices.reshape(-1) + selected_vertices = unique_vertices[inverse_indices_flat] + mask = jnp.any(jnp.isnan(selected_vertices), axis=1) + inverse_indices_flat = jnp.where(mask, -1, inverse_indices_flat) + inverse_indices = inverse_indices_flat.reshape(inverse_indices.shape) + + new_indices = inverse_indices.reshape(-1, 3) + + new_indices_sorted = jnp.sort(new_indices, axis=1) + + unique_triangles_indices = jnp.unique( + new_indices_sorted, + axis=0, + size=new_indices_sorted.shape[0], + fill_value=jnp.array( + [-1, -1, -1], + dtype=jnp.int32, + ), + ) + + return unique_triangles_indices, unique_vertices diff --git a/autoarray/structures/triangles/array/__init__.py b/autoarray/structures/triangles/array/__init__.py deleted file mode 100644 index 0fade4b81..000000000 --- a/autoarray/structures/triangles/array/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .array import ArrayTriangles - -try: - from .jax_array import ArrayTriangles as JAXArrayTriangles -except ImportError: - pass diff --git a/autoarray/structures/triangles/array/abstract_array.py b/autoarray/structures/triangles/array/abstract_array.py deleted file mode 100644 index d0f8620ee..000000000 --- a/autoarray/structures/triangles/array/abstract_array.py +++ /dev/null @@ -1,211 +0,0 @@ -from abc import abstractmethod - -import numpy as np - -from autoarray import Grid2D, AbstractTriangles -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR - - -class AbstractArrayTriangles(AbstractTriangles): - def __init__( - self, - indices, - vertices, - **kwargs, - ): - """ - Represents a set of triangles in efficient NumPy arrays. - - Parameters - ---------- - indices - The indices of the vertices of the triangles. This is a 2D array where each row is a triangle - with the three indices of the vertices. - vertices - The vertices of the triangles. - """ - self._indices = indices - self._vertices = vertices - - @property - def indices(self): - return self._indices - - @property - def vertices(self): - return self._vertices - - def __len__(self): - return len(self.triangles) - - @property - def area(self) -> float: - """ - The total area covered by the triangles. - """ - triangles = self.triangles - return ( - 0.5 - * np.abs( - (triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1])) - + (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1])) - + (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1])) - ).sum() - ) - - @property - @abstractmethod - def numpy(self): - pass - - def _up_sample_triangle(self): - triangles = self.triangles - - m01 = (triangles[:, 0] + triangles[:, 1]) / 2 - m12 = (triangles[:, 1] + triangles[:, 2]) / 2 - m20 = (triangles[:, 2] + triangles[:, 0]) / 2 - - return self.numpy.concatenate( - [ - self.numpy.stack([triangles[:, 1], m12, m01], axis=1), - self.numpy.stack([triangles[:, 2], m20, m12], axis=1), - self.numpy.stack([m01, m12, m20], axis=1), - self.numpy.stack([triangles[:, 0], m01, m20], axis=1), - ], - axis=0, - ) - - def _neighborhood_triangles(self): - triangles = self.triangles - - new_v0 = triangles[:, 1] + triangles[:, 2] - triangles[:, 0] - new_v1 = triangles[:, 0] + triangles[:, 2] - triangles[:, 1] - new_v2 = triangles[:, 0] + triangles[:, 1] - triangles[:, 2] - - return self.numpy.concatenate( - [ - self.numpy.stack([new_v0, triangles[:, 1], triangles[:, 2]], axis=1), - self.numpy.stack([triangles[:, 0], new_v1, triangles[:, 2]], axis=1), - self.numpy.stack([triangles[:, 0], triangles[:, 1], new_v2], axis=1), - triangles, - ], - axis=0, - ) - - def __str__(self): - return f"{self.__class__.__name__} with {len(self.indices)} triangles" - - def __repr__(self): - return str(self) - - @classmethod - def for_limits_and_scale( - cls, - y_min: float, - y_max: float, - x_min: float, - x_max: float, - scale: float, - **kwargs, - ) -> "AbstractTriangles": - height = scale * HEIGHT_FACTOR - - vertices = [] - indices = [] - vertex_dict = {} - - def add_vertex(v): - if v not in vertex_dict: - vertex_dict[v] = len(vertices) - vertices.append(v) - return vertex_dict[v] - - rows = [] - for row_y in np.arange(y_min, y_max + height, height): - row = [] - offset = (len(rows) % 2) * scale / 2 - for col_x in np.arange(x_min - offset, x_max + scale, scale): - row.append((row_y, col_x)) - rows.append(row) - - for i in range(len(rows) - 1): - row = rows[i] - next_row = rows[i + 1] - for j in range(len(row)): - if i % 2 == 0 and j < len(next_row) - 1: - t1 = [ - add_vertex(row[j]), - add_vertex(next_row[j]), - add_vertex(next_row[j + 1]), - ] - if j < len(row) - 1: - t2 = [ - add_vertex(row[j]), - add_vertex(row[j + 1]), - add_vertex(next_row[j + 1]), - ] - indices.append(t2) - elif i % 2 == 1 and j < len(next_row) - 1: - t1 = [ - add_vertex(row[j]), - add_vertex(next_row[j]), - add_vertex(row[j + 1]), - ] - indices.append(t1) - if j < len(next_row) - 1: - t2 = [ - add_vertex(next_row[j]), - add_vertex(next_row[j + 1]), - add_vertex(row[j + 1]), - ] - indices.append(t2) - else: - continue - indices.append(t1) - - vertices = np.array(vertices) - indices = np.array(indices) - - return cls( - indices=indices, - vertices=vertices, - **kwargs, - ) - - @classmethod - def for_grid( - cls, - grid: Grid2D, - **kwargs, - ) -> "AbstractTriangles": - """ - Create a grid of equilateral triangles from a regular grid. - - Parameters - ---------- - grid - The regular grid to convert to a grid of triangles. - - Returns - ------- - The grid of triangles. - """ - - scale = grid.pixel_scale - - y = grid[:, 0] - x = grid[:, 1] - - y_min = y.min() - y_max = y.max() - x_min = x.min() - x_max = x.max() - - return cls.for_limits_and_scale( - y_min, - y_max, - x_min, - x_max, - scale, - **kwargs, - ) diff --git a/autoarray/structures/triangles/array/array.py b/autoarray/structures/triangles/array/array.py deleted file mode 100644 index 06bb5dc89..000000000 --- a/autoarray/structures/triangles/array/array.py +++ /dev/null @@ -1,123 +0,0 @@ -from abc import ABC - -import numpy as np - -from autoarray.structures.triangles.array.abstract_array import AbstractArrayTriangles -from autoarray.structures.triangles.shape import Shape - - -class ArrayTriangles(AbstractArrayTriangles, ABC): - @property - def triangles(self): - return self.vertices[self.indices] - - @property - def numpy(self): - return np - - @property - def means(self): - return np.mean(self.triangles, axis=1) - - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The triangles that intersect the shape. - """ - inside = shape.mask(self.triangles) - - return np.where(inside)[0] - - def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": - """ - Create a new ArrayTriangles containing indices and vertices corresponding to the given indexes - but without duplicate vertices. - - Parameters - ---------- - indexes - The indexes of the triangles to include in the new ArrayTriangles. - - Returns - ------- - The new ArrayTriangles instance. - """ - selected_indices = self.indices[indexes] - - flat_indices = selected_indices.flatten() - unique_vertices, inverse_indices = np.unique( - self.vertices[flat_indices], axis=0, return_inverse=True - ) - - new_indices = inverse_indices.reshape(selected_indices.shape) - - return ArrayTriangles(indices=new_indices, vertices=unique_vertices) - - def up_sample(self) -> "ArrayTriangles": - """ - Up-sample the triangles by adding a new vertex at the midpoint of each edge. - - This means each triangle becomes four smaller triangles. - """ - unique_vertices, inverse_indices = np.unique( - self._up_sample_triangle().reshape(-1, 2), axis=0, return_inverse=True - ) - new_indices = inverse_indices.reshape(-1, 3) - - return ArrayTriangles( - indices=new_indices, - vertices=unique_vertices, - ) - - def neighborhood(self) -> "ArrayTriangles": - """ - Create a new set of triangles that are the neighborhood of the current triangles. - - Includes the current triangles and the triangles that share an edge with the current triangles. - """ - unique_vertices, inverse_indices = np.unique( - self._neighborhood_triangles().reshape(-1, 2), - axis=0, - return_inverse=True, - ) - new_indices = inverse_indices.reshape(-1, 3) - - new_indices_sorted = np.sort(new_indices, axis=1) - - unique_triangles_indices, unique_index_positions = np.unique( - new_indices_sorted, axis=0, return_index=True - ) - - return ArrayTriangles( - indices=unique_triangles_indices, - vertices=unique_vertices, - ) - - def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - return ArrayTriangles( - indices=self.indices, - vertices=vertices, - ) - - def __iter__(self): - return iter(self.triangles) diff --git a/autoarray/structures/triangles/array/jax_array.py b/autoarray/structures/triangles/array/jax_array.py deleted file mode 100644 index 23b9ad3b5..000000000 --- a/autoarray/structures/triangles/array/jax_array.py +++ /dev/null @@ -1,289 +0,0 @@ -from jax import numpy as np -from jax.tree_util import register_pytree_node_class - -from autoarray.structures.triangles.abstract import AbstractTriangles -from autoarray.structures.triangles.array.abstract_array import AbstractArrayTriangles -from autoarray.structures.triangles.shape import Shape - -MAX_CONTAINING_SIZE = 15 - - -@register_pytree_node_class -class ArrayTriangles(AbstractArrayTriangles): - def __init__( - self, - indices, - vertices, - max_containing_size=MAX_CONTAINING_SIZE, - ): - super().__init__(indices, vertices) - self.max_containing_size = max_containing_size - - @property - def numpy(self): - return np - - @property - def triangles(self) -> np.ndarray: - """ - The triangles as a 3x2 array of vertices. - """ - - invalid_mask = np.any(self.indices == -1, axis=1) - nan_array = np.full( - (self.indices.shape[0], 3, 2), - np.nan, - dtype=np.float32, - ) - safe_indices = np.where(self.indices == -1, 0, self.indices) - triangle_vertices = self.vertices[safe_indices] - return np.where(invalid_mask[:, None, None], nan_array, triangle_vertices) - - @property - def means(self) -> np.ndarray: - """ - The mean of each triangle. - """ - return np.mean(self.triangles, axis=1) - - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The triangles that intersect the shape. - """ - inside = shape.mask(self.triangles) - - return np.where( - inside, - size=self.max_containing_size, - fill_value=-1, - )[0] - - def for_indexes(self, indexes: np.ndarray) -> "ArrayTriangles": - """ - Create a new ArrayTriangles containing indices and vertices corresponding to the given indexes - but without duplicate vertices. - - Parameters - ---------- - indexes - The indexes of the triangles to include in the new ArrayTriangles. - - Returns - ------- - The new ArrayTriangles instance. - """ - selected_indices = select_and_handle_invalid( - data=self.indices, - indices=indexes, - invalid_value=-1, - invalid_replacement=np.array([-1, -1, -1], dtype=np.int32), - ) - - flat_indices = selected_indices.flatten() - - selected_vertices = select_and_handle_invalid( - data=self.vertices, - indices=flat_indices, - invalid_value=-1, - invalid_replacement=np.array([np.nan, np.nan], dtype=np.float32), - ) - - unique_vertices, inv_indices = np.unique( - selected_vertices, - axis=0, - return_inverse=True, - equal_nan=True, - size=selected_indices.shape[0] * 3, - fill_value=np.nan, - ) - - nan_mask = np.isnan(unique_vertices).any(axis=1) - inv_indices = np.where(nan_mask[inv_indices], -1, inv_indices) - - new_indices = inv_indices.reshape(selected_indices.shape) - - new_indices_sorted = np.sort(new_indices, axis=1) - - unique_triangles_indices = np.unique( - new_indices_sorted, - axis=0, - size=new_indices_sorted.shape[0], - fill_value=-1, - ) - - return ArrayTriangles( - indices=unique_triangles_indices, - vertices=unique_vertices, - max_containing_size=self.max_containing_size, - ) - - def up_sample(self) -> "ArrayTriangles": - """ - Up-sample the triangles by adding a new vertex at the midpoint of each edge. - - This means each triangle becomes four smaller triangles. - """ - new_indices, unique_vertices = remove_duplicates(self._up_sample_triangle()) - - return ArrayTriangles( - indices=new_indices, - vertices=unique_vertices, - max_containing_size=self.max_containing_size, - ) - - def neighborhood(self) -> "ArrayTriangles": - """ - Create a new set of triangles that are the neighborhood of the current triangles. - - Includes the current triangles and the triangles that share an edge with the current triangles. - """ - new_indices, unique_vertices = remove_duplicates(self._neighborhood_triangles()) - - return ArrayTriangles( - indices=new_indices, - vertices=unique_vertices, - max_containing_size=self.max_containing_size, - ) - - def with_vertices(self, vertices: np.ndarray) -> "ArrayTriangles": - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - return ArrayTriangles( - indices=self.indices, - vertices=vertices, - max_containing_size=self.max_containing_size, - ) - - def __iter__(self): - return iter(self.triangles) - - def tree_flatten(self): - """ - Flatten this model as a PyTree. - """ - return ( - self.indices, - self.vertices, - ), (self.max_containing_size,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """ - Unflatten a PyTree into a model. - """ - return cls( - indices=children[0], - vertices=children[1], - max_containing_size=aux_data[0], - ) - - @classmethod - def for_limits_and_scale( - cls, - y_min: float, - y_max: float, - x_min: float, - x_max: float, - scale: float, - max_containing_size=MAX_CONTAINING_SIZE, - ) -> "AbstractTriangles": - triangles = super().for_limits_and_scale( - y_min, - y_max, - x_min, - x_max, - scale, - ) - return cls( - indices=np.array(triangles.indices), - vertices=np.array(triangles.vertices), - max_containing_size=max_containing_size, - ) - - -def select_and_handle_invalid( - data: np.ndarray, - indices: np.ndarray, - invalid_value, - invalid_replacement, -): - """ - Select data based on indices, handling invalid indices by replacing them with a specified value. - - Parameters - ---------- - data - The array from which to select data. - indices - The indices used to select data from the array. - invalid_value - The value representing invalid indices. - invalid_replacement - The value to use for invalid entries in the result. - - Returns - ------- - An array with selected data, where invalid indices are replaced with `invalid_replacement`. - """ - invalid_mask = indices == invalid_value - safe_indices = np.where(invalid_mask, 0, indices) - selected_data = data[safe_indices] - selected_data = np.where( - invalid_mask[..., None], - invalid_replacement, - selected_data, - ) - - return selected_data - - -def remove_duplicates(new_triangles): - unique_vertices, inverse_indices = np.unique( - new_triangles.reshape(-1, 2), - axis=0, - return_inverse=True, - size=2 * new_triangles.shape[0], - fill_value=np.nan, - equal_nan=True, - ) - - inverse_indices_flat = inverse_indices.reshape(-1) - selected_vertices = unique_vertices[inverse_indices_flat] - mask = np.any(np.isnan(selected_vertices), axis=1) - inverse_indices_flat = np.where(mask, -1, inverse_indices_flat) - inverse_indices = inverse_indices_flat.reshape(inverse_indices.shape) - - new_indices = inverse_indices.reshape(-1, 3) - - new_indices_sorted = np.sort(new_indices, axis=1) - - unique_triangles_indices = np.unique( - new_indices_sorted, - axis=0, - size=new_indices_sorted.shape[0], - fill_value=np.array( - [-1, -1, -1], - dtype=np.int32, - ), - ) - - return unique_triangles_indices, unique_vertices diff --git a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py b/autoarray/structures/triangles/coordinate_array.py similarity index 52% rename from autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py rename to autoarray/structures/triangles/coordinate_array.py index e80facd1f..c919ffc86 100644 --- a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array.py @@ -1,11 +1,12 @@ -from jax import numpy as np +from abc import ABC + +import numpy as np +import jax.numpy as jnp import jax from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array.abstract_coordinate_array import ( - AbstractCoordinateArray, -) -from autoarray.structures.triangles.array.jax_array import ArrayTriangles +from autoarray.structures.triangles.abstract import AbstractTriangles +from autoarray.structures.triangles.array import ArrayTriangles from autoarray.numpy_wrapper import register_pytree_node_class from autoconf import cached_property @@ -13,10 +14,39 @@ @register_pytree_node_class -class CoordinateArrayTriangles(AbstractCoordinateArray): - @property - def numpy(self): - return jax.numpy +class CoordinateArrayTriangles(AbstractTriangles, ABC): + + def __init__( + self, + coordinates: np.ndarray, + side_length: float = 1.0, + x_offset: float = 0.0, + y_offset: float = 0.0, + flipped: bool = False, + ): + """ + Represents a set of triangles by integer coordinates. + + Parameters + ---------- + coordinates + Integer x y coordinates for each triangle. + side_length + The side length of the triangles. + flipped + Whether the triangles are flipped upside down. + y_offset + An y_offset to apply to the y coordinates so that up-sampled triangles align. + """ + self.coordinates = coordinates + self.side_length = side_length + self.flipped = flipped + + self.scaling_factors = jnp.array( + [0.5 * side_length, HEIGHT_FACTOR * side_length] + ) + self.x_offset = x_offset + self.y_offset = y_offset @classmethod def for_limits_and_scale( @@ -38,7 +68,7 @@ def for_limits_and_scale( coordinates.append([x, y]) return cls( - coordinates=np.array(coordinates), + coordinates=jnp.array(coordinates), side_length=scale, ) @@ -69,18 +99,66 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, flipped=aux_data[0]) + def __len__(self): + return jnp.count_nonzero(~jnp.isnan(self.coordinates).any(axis=1)) + + def __iter__(self): + return iter(self.triangles) + @property - def centres(self) -> np.ndarray: + def centres(self) -> jnp.ndarray: """ The centres of the triangles. """ - centres = self.scaling_factors * self.coordinates + np.array( + centres = self.scaling_factors * self.coordinates + jnp.array( [self.x_offset, self.y_offset] ) return centres @cached_property - def flip_mask(self) -> np.ndarray: + def vertex_coordinates(self) -> np.ndarray: + """ + The vertices of the triangles as an Nx3x2 array. + """ + coordinates = self.coordinates + return jnp.concatenate( + [ + coordinates + self.flip_array * np.array([0, 1], dtype=np.int32), + coordinates + self.flip_array * np.array([1, -1], dtype=np.int32), + coordinates + self.flip_array * np.array([-1, -1], dtype=np.int32), + ], + dtype=np.int32, + ) + + @cached_property + def triangles(self) -> np.ndarray: + """ + The vertices of the triangles as an Nx3x2 array. + """ + centres = self.centres + return jnp.stack( + ( + centres + + self.flip_array + * jnp.array( + [0.0, 0.5 * self.side_length * HEIGHT_FACTOR], + ), + centres + + self.flip_array + * jnp.array( + [0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] + ), + centres + + self.flip_array + * jnp.array( + [-0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] + ), + ), + axis=1, + ) + + @cached_property + def flip_mask(self) -> jnp.ndarray: """ A mask for the triangles that are flipped. @@ -92,16 +170,13 @@ def flip_mask(self) -> np.ndarray: return mask @cached_property - def flip_array(self) -> np.ndarray: + def flip_array(self) -> jnp.ndarray: """ An array of 1s and -1s to flip the triangles. """ - array = np.where(self.flip_mask, -1, 1) + array = jnp.where(self.flip_mask, -1, 1) return array[:, None] - def __iter__(self): - return iter(self.triangles) - def up_sample(self) -> "CoordinateArrayTriangles": """ Up-sample the triangles by adding a new vertex at the midpoint of each edge. @@ -113,11 +188,11 @@ def up_sample(self) -> "CoordinateArrayTriangles": n = coordinates.shape[0] - shift0 = np.zeros((n, 2)) - shift3 = np.tile(np.array([0, 1]), (n, 1)) - shift1 = np.stack([np.ones(n), np.where(flip_mask, 1, 0)], axis=1) - shift2 = np.stack([-np.ones(n), np.where(flip_mask, 1, 0)], axis=1) - shifts = np.stack([shift0, shift1, shift2, shift3], axis=1) + shift0 = jnp.zeros((n, 2)) + shift3 = jnp.tile(jnp.array([0, 1]), (n, 1)) + shift1 = jnp.stack([jnp.ones(n), jnp.where(flip_mask, 1, 0)], axis=1) + shift2 = jnp.stack([-jnp.ones(n), jnp.where(flip_mask, 1, 0)], axis=1) + shifts = jnp.stack([shift0, shift1, shift2, shift3], axis=1) coordinates_expanded = coordinates[:, None, :] new_coordinates = coordinates_expanded + shifts @@ -140,27 +215,27 @@ def neighborhood(self) -> "CoordinateArrayTriangles": coordinates = self.coordinates flip_mask = self.flip_mask - shift0 = np.zeros((coordinates.shape[0], 2)) - shift1 = np.tile(np.array([1, 0]), (coordinates.shape[0], 1)) - shift2 = np.tile(np.array([-1, 0]), (coordinates.shape[0], 1)) - shift3 = np.where( + shift0 = jnp.zeros((coordinates.shape[0], 2)) + shift1 = jnp.tile(jnp.array([1, 0]), (coordinates.shape[0], 1)) + shift2 = jnp.tile(jnp.array([-1, 0]), (coordinates.shape[0], 1)) + shift3 = jnp.where( flip_mask[:, None], - np.tile(np.array([0, 1]), (coordinates.shape[0], 1)), - np.tile(np.array([0, -1]), (coordinates.shape[0], 1)), + jnp.tile(jnp.array([0, 1]), (coordinates.shape[0], 1)), + jnp.tile(jnp.array([0, -1]), (coordinates.shape[0], 1)), ) - shifts = np.stack([shift0, shift1, shift2, shift3], axis=1) + shifts = jnp.stack([shift0, shift1, shift2, shift3], axis=1) coordinates_expanded = coordinates[:, None, :] new_coordinates = coordinates_expanded + shifts new_coordinates = new_coordinates.reshape(-1, 2) expected_size = 4 * coordinates.shape[0] - unique_coords, indices = np.unique( + unique_coords, indices = jnp.unique( new_coordinates, axis=0, size=expected_size, - fill_value=np.nan, + fill_value=jnp.nan, return_index=True, ) @@ -175,22 +250,22 @@ def neighborhood(self) -> "CoordinateArrayTriangles": @cached_property def _vertices_and_indices(self): flat_triangles = self.triangles.reshape(-1, 2) - vertices, inverse_indices = np.unique( + vertices, inverse_indices = jnp.unique( flat_triangles, axis=0, return_inverse=True, size=3 * self.coordinates.shape[0], equal_nan=True, - fill_value=np.nan, + fill_value=jnp.nan, ) - nan_mask = np.isnan(vertices).any(axis=1) - inverse_indices = np.where(nan_mask[inverse_indices], -1, inverse_indices) + nan_mask = jnp.isnan(vertices).any(axis=1) + inverse_indices = jnp.where(nan_mask[inverse_indices], -1, inverse_indices) indices = inverse_indices.reshape(-1, 3) return vertices, indices - def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles: + def with_vertices(self, vertices: jnp.ndarray) -> ArrayTriangles: """ Create a new set of triangles with the vertices replaced. @@ -208,7 +283,7 @@ def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles: vertices=vertices, ) - def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": + def for_indexes(self, indexes: jnp.ndarray) -> "CoordinateArrayTriangles": """ Create a new CoordinateArrayTriangles containing triangles corresponding to the given indexes @@ -222,9 +297,9 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": The new CoordinateArrayTriangles instance. """ mask = indexes == -1 - safe_indexes = np.where(mask, 0, indexes) - coordinates = np.take(self.coordinates, safe_indexes, axis=0) - coordinates = np.where(mask[:, None], np.nan, coordinates) + safe_indexes = jnp.where(mask, 0, indexes) + coordinates = jnp.take(self.coordinates, safe_indexes, axis=0) + coordinates = jnp.where(mask[:, None], jnp.nan, coordinates) return CoordinateArrayTriangles( coordinates=coordinates, @@ -234,5 +309,24 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": flipped=self.flipped, ) - def containing_indices(self, shape: np.ndarray) -> np.ndarray: - raise NotImplementedError("JAX ArrayTriangles are used for this method.") + @property + def vertices(self) -> np.ndarray: + """ + The unique vertices of the triangles. + """ + return self._vertices_and_indices[0] + + @property + def indices(self) -> np.ndarray: + """ + The indices of the vertices of the triangles. + """ + return self._vertices_and_indices[1] + + @property + def means(self): + return jnp.mean(self.triangles, axis=1) + + @property + def area(self): + return (3**0.5 / 4 * self.side_length**2) * len(self) diff --git a/autoarray/structures/triangles/coordinate_array/__init__.py b/autoarray/structures/triangles/coordinate_array/__init__.py deleted file mode 100644 index f70bc8a9a..000000000 --- a/autoarray/structures/triangles/coordinate_array/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .coordinate_array import CoordinateArrayTriangles - -try: - from .jax_coordinate_array import ( - CoordinateArrayTriangles as JAXCoordinateArrayTriangles, - ) -except ImportError: - pass diff --git a/autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py b/autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py deleted file mode 100644 index 5c0fe799c..000000000 --- a/autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py +++ /dev/null @@ -1,162 +0,0 @@ -from abc import abstractmethod, ABC - -import numpy as np - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR, AbstractTriangles -from autoconf import cached_property - - -class AbstractCoordinateArray(AbstractTriangles, ABC): - def __init__( - self, - coordinates: np.ndarray, - side_length: float = 1.0, - x_offset: float = 0.0, - y_offset: float = 0.0, - flipped: bool = False, - ): - """ - Represents a set of triangles by integer coordinates. - - Parameters - ---------- - coordinates - Integer x y coordinates for each triangle. - side_length - The side length of the triangles. - flipped - Whether the triangles are flipped upside down. - y_offset - An y_offset to apply to the y coordinates so that up-sampled triangles align. - """ - self.coordinates = coordinates - self.side_length = side_length - self.flipped = flipped - - self.scaling_factors = self.numpy.array( - [0.5 * side_length, HEIGHT_FACTOR * side_length] - ) - self.x_offset = x_offset - self.y_offset = y_offset - - @property - @abstractmethod - def numpy(self): - pass - - @cached_property - def vertex_coordinates(self) -> np.ndarray: - """ - The vertices of the triangles as an Nx3x2 array. - """ - coordinates = self.coordinates - return self.numpy.concatenate( - [ - coordinates + self.flip_array * np.array([0, 1], dtype=np.int32), - coordinates + self.flip_array * np.array([1, -1], dtype=np.int32), - coordinates + self.flip_array * np.array([-1, -1], dtype=np.int32), - ], - dtype=np.int32, - ) - - @cached_property - def triangles(self) -> np.ndarray: - """ - The vertices of the triangles as an Nx3x2 array. - """ - centres = self.centres - return self.numpy.stack( - ( - centres - + self.flip_array - * self.numpy.array( - [0.0, 0.5 * self.side_length * HEIGHT_FACTOR], - ), - centres - + self.flip_array - * self.numpy.array( - [0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] - ), - centres - + self.flip_array - * self.numpy.array( - [-0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR] - ), - ), - axis=1, - ) - - @property - def centres(self) -> np.ndarray: - """ - The centres of the triangles. - """ - return self.scaling_factors * self.coordinates + self.numpy.array( - [self.x_offset, self.y_offset] - ) - - @cached_property - def flip_mask(self) -> np.ndarray: - """ - A mask for the triangles that are flipped. - - Every other triangle is flipped so that they tessellate. - """ - mask = (self.coordinates[:, 0] + self.coordinates[:, 1]) % 2 != 0 - if self.flipped: - mask = ~mask - return mask - - @cached_property - @abstractmethod - def flip_array(self) -> np.ndarray: - """ - An array of 1s and -1s to flip the triangles. - """ - - def __iter__(self): - return iter(self.triangles) - - @cached_property - @abstractmethod - def _vertices_and_indices(self): - pass - - @property - def vertices(self) -> np.ndarray: - """ - The unique vertices of the triangles. - """ - return self._vertices_and_indices[0] - - @property - def indices(self) -> np.ndarray: - """ - The indices of the vertices of the triangles. - """ - return self._vertices_and_indices[1] - - def with_vertices(self, vertices: np.ndarray) -> AbstractTriangles: - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - - @property - def means(self): - return self.numpy.mean(self.triangles, axis=1) - - @property - def area(self): - return (3**0.5 / 4 * self.side_length**2) * len(self) - - def __len__(self): - return self.numpy.count_nonzero(~self.numpy.isnan(self.coordinates).any(axis=1)) diff --git a/autoarray/structures/triangles/coordinate_array/coordinate_array.py b/autoarray/structures/triangles/coordinate_array/coordinate_array.py deleted file mode 100644 index 997c8ab7f..000000000 --- a/autoarray/structures/triangles/coordinate_array/coordinate_array.py +++ /dev/null @@ -1,188 +0,0 @@ -import numpy as np - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array.abstract_coordinate_array import ( - AbstractCoordinateArray, -) -from autoarray.structures.triangles.array import ArrayTriangles -from autoarray.structures.triangles.shape import Shape -from autoconf import cached_property - - -class CoordinateArrayTriangles(AbstractCoordinateArray): - @cached_property - def flip_array(self) -> np.ndarray: - """ - An array of 1s and -1s to flip the triangles. - """ - array = np.ones( - self.coordinates.shape[0], - dtype=np.int32, - ) - array[self.flip_mask] = -1 - - return array[:, np.newaxis] - - @property - def numpy(self): - return np - - @classmethod - def for_limits_and_scale( - cls, - x_min: float, - x_max: float, - y_min: float, - y_max: float, - scale: float = 1.0, - **_, - ): - x_shift = int(2 * x_min / scale) - y_shift = int(y_min / (HEIGHT_FACTOR * scale)) - - coordinates = [] - - for x in range(x_shift, int(2 * x_max / scale) + 1): - for y in range(y_shift - 1, int(y_max / (HEIGHT_FACTOR * scale)) + 2): - coordinates.append([x, y]) - - return cls( - coordinates=np.array(coordinates, dtype=np.int32), - side_length=scale, - ) - - def up_sample(self) -> "CoordinateArrayTriangles": - """ - Up-sample the triangles by adding a new vertex at the midpoint of each edge. - """ - new_coordinates = np.zeros( - (4 * self.coordinates.shape[0], 2), - dtype=np.int32, - ) - n_normal = 4 * np.sum(~self.flip_mask) - - new_coordinates[:n_normal] = np.vstack( - ( - 2 * self.coordinates[~self.flip_mask], - 2 * self.coordinates[~self.flip_mask] + np.array([1, 0]), - 2 * self.coordinates[~self.flip_mask] + np.array([-1, 0]), - 2 * self.coordinates[~self.flip_mask] + np.array([0, 1]), - ) - ) - new_coordinates[n_normal:] = np.vstack( - ( - 2 * self.coordinates[self.flip_mask], - 2 * self.coordinates[self.flip_mask] + np.array([1, 1]), - 2 * self.coordinates[self.flip_mask] + np.array([-1, 1]), - 2 * self.coordinates[self.flip_mask] + np.array([0, 1]), - ) - ) - - return CoordinateArrayTriangles( - coordinates=new_coordinates, - side_length=self.side_length / 2, - y_offset=self.y_offset + -0.25 * HEIGHT_FACTOR * self.side_length, - x_offset=self.x_offset, - flipped=True, - ) - - def neighborhood(self) -> "CoordinateArrayTriangles": - """ - Create a new set of triangles that are the neighborhood of the current triangles. - - Ensures that the new triangles are unique. - """ - new_coordinates = np.zeros( - (4 * self.coordinates.shape[0], 2), - dtype=np.int32, - ) - n_normal = 4 * np.sum(~self.flip_mask) - - new_coordinates[:n_normal] = np.vstack( - ( - self.coordinates[~self.flip_mask], - self.coordinates[~self.flip_mask] + np.array([1, 0]), - self.coordinates[~self.flip_mask] + np.array([-1, 0]), - self.coordinates[~self.flip_mask] + np.array([0, -1]), - ) - ) - new_coordinates[n_normal:] = np.vstack( - ( - self.coordinates[self.flip_mask], - self.coordinates[self.flip_mask] + np.array([1, 0]), - self.coordinates[self.flip_mask] + np.array([-1, 0]), - self.coordinates[self.flip_mask] + np.array([0, 1]), - ) - ) - return CoordinateArrayTriangles( - coordinates=np.unique(new_coordinates, axis=0), - side_length=self.side_length, - y_offset=self.y_offset, - x_offset=self.x_offset, - flipped=self.flipped, - ) - - @cached_property - def _vertices_and_indices(self): - flat_triangles = self.triangles.reshape(-1, 2) - vertices, inverse_indices = np.unique( - flat_triangles, - axis=0, - return_inverse=True, - ) - indices = inverse_indices.reshape(-1, 3) - return vertices, indices - - def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles: - """ - Create a new set of triangles with the vertices replaced. - - Parameters - ---------- - vertices - The new vertices to use. - - Returns - ------- - The new set of triangles with the new vertices. - """ - return ArrayTriangles( - indices=self.indices, - vertices=vertices, - ) - - def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles": - """ - Create a new CoordinateArrayTriangles containing triangles corresponding to the given indexes - - Parameters - ---------- - indexes - The indexes of the triangles to include in the new CoordinateArrayTriangles. - - Returns - ------- - The new CoordinateArrayTriangles instance. - """ - return CoordinateArrayTriangles( - coordinates=self.coordinates[indexes], - side_length=self.side_length, - y_offset=self.y_offset, - x_offset=self.x_offset, - flipped=self.flipped, - ) - - def containing_indices(self, shape: Shape) -> np.ndarray: - """ - Find the triangles that insect with a given shape. - - Parameters - ---------- - shape - The shape - - Returns - ------- - The indices of triangles that intersect the shape. - """ - return self.with_vertices(self.vertices).containing_indices(shape) diff --git a/autoarray/structures/vectors/irregular.py b/autoarray/structures/vectors/irregular.py index 5895149fb..b49ec3868 100644 --- a/autoarray/structures/vectors/irregular.py +++ b/autoarray/structures/vectors/irregular.py @@ -1,5 +1,6 @@ import logging import numpy as np +import jax.numpy as jnp from typing import List, Tuple, Union from autoarray.structures.vectors.abstract import AbstractVectorYX2D @@ -43,7 +44,6 @@ def __init__( grid The irregular grid of (y,x) coordinates where each vector is located. """ - if type(values) is list: values = np.asarray(values) @@ -120,13 +120,13 @@ def vectors_within_radius( squared_distances = self.grid.distances_to_coordinate_from(coordinate=centre) mask = squared_distances < radius - if np.all(mask == False): + if jnp.all(mask == False): raise exc.VectorYXException( "The input radius removed all vectors / points on the grid." ) return VectorYX2DIrregular( - values=self[mask], grid=Grid2DIrregular(self.grid[mask]) + values=jnp.array(self.array)[mask], grid=Grid2DIrregular(self.grid[mask]) ) def vectors_within_annulus( diff --git a/autoarray/structures/vectors/uniform.py b/autoarray/structures/vectors/uniform.py index 6213ecfbf..88dfe6a9c 100644 --- a/autoarray/structures/vectors/uniform.py +++ b/autoarray/structures/vectors/uniform.py @@ -137,9 +137,6 @@ def __init__( mapping large data arrays to and from the slim / native formats, which can be a computational bottleneck. """ - if type(values) is list: - values = np.asarray(values) - values = grid_2d_util.convert_grid_2d( grid_2d=values, mask_2d=mask, store_native=store_native ) @@ -396,7 +393,10 @@ def magnitudes(self) -> Array2D: """ Returns the magnitude of every vector which are computed as sqrt(y**2 + x**2). """ - return Array2D(values=jnp.sqrt(self.array[:, 0] ** 2.0 + self.array[:, 1] ** 2.0), mask=self.mask) + return Array2D( + values=jnp.sqrt(self.array[:, 0] ** 2.0 + self.array[:, 1] ** 2.0), + mask=self.mask, + ) @property def y(self) -> Array2D: diff --git a/autoarray/structures/visibilities.py b/autoarray/structures/visibilities.py index 2b4113c7c..abfb9713e 100644 --- a/autoarray/structures/visibilities.py +++ b/autoarray/structures/visibilities.py @@ -50,16 +50,8 @@ def __init__(self, visibilities: Union[np.ndarray, List[complex]]): .ravel() ) - self.ordered_1d = np.concatenate( - (np.real(visibilities), np.imag(visibilities)), axis=0 - ) - super().__init__(array=visibilities) - def __array_finalize__(self, obj): - if hasattr(obj, "ordered_1d"): - self.ordered_1d = obj.ordered_1d - @property def slim(self) -> "AbstractVisibilities": return self @@ -74,7 +66,7 @@ def in_array(self) -> np.ndarray: Returns the 1D complex NumPy array of values with shape [total_visibilities] as a NumPy float array of shape [total_visibilities, 2]. """ - return np.stack((np.real(self), np.imag(self)), axis=-1) + return np.stack((np.real(self.array), np.imag(self.array)), axis=-1) @property def in_grid(self) -> Grid2DIrregular: @@ -213,9 +205,7 @@ def __init__(self, visibilities: Union[np.ndarray, List[complex]], *args, **kwar A collection of (real, imag) visibilities noise-map values which are used to represent the noise-map in an `Interferometer` dataset. - This data structure behaves the same as the `Visibilities` structure (see `AbstractVisibilities.__new__`). The - only difference is that it includes a `WeightOperator` used by `Inversion`'s which use `LinearOperators` and - the library `PyLops` to fit `Interferometer` data. + This data structure behaves the same as the `Visibilities` structure (see `AbstractVisibilities.__new__`). Parameters ---------- @@ -234,20 +224,4 @@ def __init__(self, visibilities: Union[np.ndarray, List[complex]], *args, **kwar .ravel() ) - self.ordered_1d = np.concatenate( - (np.real(visibilities), np.imag(visibilities)), axis=0 - ) super().__init__(visibilities=visibilities) - - weight_list = 1.0 / self.in_array**2.0 - - self.weight_list_ordered_1d = np.concatenate( - (weight_list[:, 0], weight_list[:, 1]), axis=0 - ) - - def __array_finalize__(self, obj): - if hasattr(obj, "ordered_1d"): - self.ordered_1d = obj.ordered_1d - - if hasattr(obj, "weight_list_ordered_1d"): - self.weight_list_ordered_1d = obj.weight_list_ordered_1d diff --git a/autoarray/util/__init__.py b/autoarray/util/__init__.py index 46b5afb5b..a9ba1dfd3 100644 --- a/autoarray/util/__init__.py +++ b/autoarray/util/__init__.py @@ -11,12 +11,20 @@ from autoarray.layout import layout_util as layout from autoarray.fit import fit_util as fit from autoarray.inversion.pixelization.mesh import mesh_util as mesh +from autoarray.inversion.pixelization.mesh import mesh_numba_util as mesh_numba from autoarray.inversion.pixelization.mappers import mapper_util as mapper +from autoarray.inversion.pixelization.mappers import mapper_numba_util as mapper_numba from autoarray.inversion.regularization import regularization_util as regularization from autoarray.inversion.inversion import inversion_util as inversion from autoarray.inversion.inversion.imaging import ( inversion_imaging_util as inversion_imaging, ) +from autoarray.inversion.inversion.imaging import ( + inversion_imaging_util as inversion_imaging, +) +from autoarray.inversion.inversion.imaging import ( + inversion_imaging_numba_util as inversion_imaging_numba, +) from autoarray.inversion.inversion.interferometer import ( inversion_interferometer_util as inversion_interferometer, ) diff --git a/autoarray/util/cholesky_funcs.py b/autoarray/util/cholesky_funcs.py deleted file mode 100644 index bd211eeb5..000000000 --- a/autoarray/util/cholesky_funcs.py +++ /dev/null @@ -1,100 +0,0 @@ -import numpy as np -from scipy import linalg -import math -import time -from autoarray import numba_util - - -@numba_util.jit() -def _choldowndate(U, x): - n = x.size - for k in range(n - 1): - Ukk = U[k, k] - xk = x[k] - r = math.sqrt(Ukk**2 - xk**2) - c = r / Ukk - s = xk / Ukk - U[k, k] = r - U[k, k + 1 :] = (U[k, (k + 1) :] - s * x[k + 1 :]) / c - x[k + 1 :] = c * x[k + 1 :] - s * U[k, k + 1 :] - - k = n - 1 - U[k, k] = math.sqrt(U[k, k] ** 2 - x[k] ** 2) - return U - - -@numba_util.jit() -def _cholupdate(U, x): - n = x.size - for k in range(n - 1): - Ukk = U[k, k] - xk = x[k] - - r = np.sqrt(Ukk**2 + xk**2) - - c = r / Ukk - s = xk / Ukk - U[k, k] = r - - U[k, k + 1 :] = (U[k, (k + 1) :] + s * x[k + 1 :]) / c - x[k + 1 :] = c * x[k + 1 :] - s * U[k, k + 1 :] - - k = n - 1 - U[k, k] = np.sqrt(U[k, k] ** 2 + x[k] ** 2) - - return U - - -def cholinsert(U, index, x): - S = np.insert(np.insert(U, index, 0, axis=0), index, 0, axis=1) - - S[:index, index] = S12 = linalg.solve_triangular( - U[:index, :index], x[:index], trans=1, lower=False, overwrite_b=True - ) - - S[index, index] = s22 = math.sqrt(x[index] - S12.dot(S12)) - - if index == U.shape[0]: - return S - else: - S[index, index + 1 :] = S23 = (x[index + 1 :] - S12.T @ U[:index, index:]) / s22 - _choldowndate(S[index + 1 :, index + 1 :], S23) # S33 - return S - - -def cholinsertlast(U, x): - """ - Update the Cholesky matrix U by inserting a vector at the end of the matrix - Inserting a vector to the end of U doesn't require _cholupdate, so save some time. - It's a special case of `cholinsert` (as shown above, if index == U.shape[0]) - As in current Cholesky scheme implemented in fnnls, we only use this kind of insertion, so I - separate it out from the `cholinsert`. - """ - index = U.shape[0] - - S = np.insert(np.insert(U, index, 0, axis=0), index, 0, axis=1) - - S[:index, index] = S12 = linalg.solve_triangular( - U[:index, :index], x[:index], trans=1, lower=False, overwrite_b=True - ) - - S[index, index] = s22 = math.sqrt(x[index] - S12.dot(S12)) - - return S - - -def choldeleteindexes(U, indexes): - indexes = sorted(indexes, reverse=True) - - for index in indexes: - L = np.delete(np.delete(U, index, axis=0), index, axis=1) - - # If the deleted index is at the end of matrix, then we do not need to update the U. - - if index == L.shape[0]: - U = L - else: - _cholupdate(L[index:, index:], U[index, index + 1 :]) - U = L - - return U diff --git a/autoarray/util/fnnls.py b/autoarray/util/fnnls.py deleted file mode 100644 index 3f49c1f2d..000000000 --- a/autoarray/util/fnnls.py +++ /dev/null @@ -1,155 +0,0 @@ -import numpy as np -from scipy import linalg as slg - -from autoarray.util.cholesky_funcs import cholinsertlast, choldeleteindexes - -from autoarray import exc - -""" - This file contains functions use the Bro & Jong (1997) algorithm to solve the non-negative least - square problem. The `fnnls and fix_constraint` is orginally copied from - "https://github.com/jvendrow/fnnls". - For our purpose in PyAutoArray, we create `fnnls_modefied` to take ZTZ and ZTx as inputs directly. - Furthermore, we add two functions `fnnls_Cholesky and fix_constraint_Cholesky` to realize a scheme - that solves the lstsq problem in the algorithm by Cholesky factorisation. For ~ 1000 free - parameters, we see a speed up by 2 times and should be more for more parameters. - We have also noticed that by setting the P_initial to be `sla.solve(ZTZ, ZTx, assume_a='pos') > 0` - will speed up our task (~ 1000 free parameters) by ~ 3 times as it significantly reduces the - iteration time. -""" - - -def fnnls_cholesky( - ZTZ, - ZTx, - P_initial=np.zeros(0, dtype=int), -): - """ - Similar to fnnls, but use solving the lstsq problem by updating Cholesky factorisation. - """ - - lstsq = lambda A, x: slg.solve( - A, - x, - assume_a="pos", - overwrite_a=True, - overwrite_b=True, - ) - - n = np.shape(ZTZ)[0] - epsilon = 2.2204e-16 - tolerance = epsilon * n - max_repetitions = 3 - no_update = 0 - loop_count = 0 - loop_count2 = 0 - - P = np.zeros(n, dtype=bool) - P[P_initial] = True - d = np.zeros(n) - w = ZTx - (ZTZ) @ d - s_chol = np.zeros(n) - - if P_initial.shape[0] != 0: - P_number = np.arange(len(P), dtype="int") - P_inorder = P_number[P_initial] - s_chol[P] = lstsq((ZTZ)[P][:, P], (ZTx)[P]) - d = s_chol.clip(min=0) - else: - P_inorder = np.array([], dtype="int") - - # P_inorder is similar as P. They are both used to select solutions in the passive set. - # P_inorder saves the `indexes` of those passive solutions. - # P saves [True/False] for all solutions. True indicates a solution in the passive set while False - # indicates it's in the active set. - # The benifit of P_inorder is that we are able to not only select out solutions in the passive set - # and can sort them in the order of added to the passive set. This will make updating the - # Cholesky factorisation simpler and thus save time. - - while (not np.all(P)) and np.max(w[~P]) > tolerance: - # make copy of passive set to check for change at end of loop - - current_P = P.copy() - idmax = np.argmax(w * ~P) - P_inorder = np.append(P_inorder, int(idmax)) - - if loop_count == 0: - # We need to initialize the Cholesky factorisation, U, for the first loop. - U = slg.cholesky(ZTZ[P_inorder][:, P_inorder]) - else: - U = cholinsertlast(U, ZTZ[idmax][P_inorder]) - - # solve the lstsq problem by cho_solve - - s_chol[P_inorder] = slg.cho_solve((U, False), ZTx[P_inorder]) - - P[idmax] = True - while np.any(P) and np.min(s_chol[P]) <= tolerance: - s_chol, d, P, P_inorder, U = fix_constraint_cholesky( - ZTx=ZTx, - s_chol=s_chol, - d=d, - P=P, - P_inorder=P_inorder, - U=U, - tolerance=tolerance, - ) - - loop_count2 += 1 - if loop_count2 > 10000: - raise RuntimeError - - d = s_chol.copy() - w = ZTx - (ZTZ) @ d - loop_count += 1 - - if loop_count > 10000: - raise RuntimeError - - if np.all(current_P == P): - no_update += 1 - else: - no_update = 0 - - if no_update >= max_repetitions: - break - - return d - - -def fix_constraint_cholesky(ZTx, s_chol, d, P, P_inorder, U, tolerance): - """ - Similar to fix_constraint, but solve the lstsq by Cholesky factorisation. - If this function is called, it means some solutions in the current passive sets needed to be - taken out and put into the active set. - So, this function involves 3 procedure: - 1. Identifying what solutions should be taken out of the current passive set. - 2. Updating the P, P_inorder and the Cholesky factorisation U. - 3. Solving the lstsq by using the new Cholesky factorisation U. - As some solutions are taken out from the passive set, the Cholesky factorisation needs to be - updated by choldeleteindexes. To realize that, we call the `choldeleteindexes` from - cholesky_funcs. - """ - q = P * (s_chol <= tolerance) - alpha = np.min(d[q] / (d[q] - s_chol[q])) - - # set d as close to s as possible while maintaining non-negativity - d = d + alpha * (s_chol - d) - - id_delete = np.where(d[P_inorder] <= tolerance)[0] - - U = choldeleteindexes(U, id_delete) # update the Cholesky factorisation - - P_inorder = np.delete(P_inorder, id_delete) # update the P_inorder - - P[d <= tolerance] = False # update the P - - # solve the lstsq problem by cho_solve - - if len(P_inorder): - # there could be a case where P_inorder is empty. - s_chol[P_inorder] = slg.cho_solve((U, False), ZTx[P_inorder]) - - s_chol[~P] = 0.0 # set solutions taken out of the passive set to be 0 - - return s_chol, d, P, P_inorder, U diff --git a/optional_requirements.txt b/optional_requirements.txt deleted file mode 100644 index cbb64e17e..000000000 --- a/optional_requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -pylops>=1.10.0,<=2.3.1 -pynufft -#jax==0.4.3 -#jaxlib==0.4.3 -#numba diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..bc96b93ff --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = ["setuptools>=79.0", "setuptools-scm", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "autoarray" +dynamic = ["version"] + description="PyAuto Data Structures" +readme = { file = "README.rst", content-type = "text/x-rst" } +license-files = [ + "LICENSE", +] +requires-python = ">=3.9" +authors = [ + { name = "James Nightingale", email = "James.Nightingale@newcastle.ac.uk" }, + { name = "Richard Hayes", email = "richard@rghsoftware.co.uk" }, +] +classifiers = [ + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Physics", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] +keywords = ["cli"] +dependencies = [ + "astropy>=5.0,<=6.1.2", + "decorator>=4.0.0", + "dill>=0.3.1.1", + "jaxnnls==1.0.1", + "matplotlib>=3.7.0", + "scipy<=1.14.0", + "scikit-image<=0.24.0", + "scikit-learn<=1.5.1" +] + +[project.urls] +Homepage = "https://github.com/Jammy2211/PyAutoArray" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +exclude = ["docs", "test_autoarray", "test_autoarray*"] + +[tool.setuptools_scm] +version_scheme = "post-release" +local_scheme = "no-local-version" + + +[project.optional-dependencies] +optional=[ + "pynufft" +] +test = ["pytest"] +dev = ["pytest", "black"] + +[tool.pytest.ini_options] +testpaths = ["test_autoarray"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index c4b499b41..000000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -astropy>=5.0,<=6.1.2 -decorator>=4.0.0 -dill>=0.3.1.1 -matplotlib>=3.7.0 -scipy<=1.14.0 -scikit-image<=0.24.0 -scikit-learn<=1.5.1 diff --git a/setup.py b/setup.py index bad6a2389..4c77d449f 100644 --- a/setup.py +++ b/setup.py @@ -1,58 +1,8 @@ import os -from codecs import open -from os import environ -from os.path import abspath, dirname, join - -from setuptools import find_packages, setup - -this_dir = abspath(dirname(__file__)) -with open(join(this_dir, "README.rst"), encoding="utf-8") as file: - long_description = file.read() - -with open(join(this_dir, "requirements.txt")) as f: - requirements = f.read().split("\n") - -version = environ.get("VERSION", "1.0.dev0") -requirements.extend([f"autoconf=={version}"]) - - -def config_packages(directory): - paths = [directory.replace("/", ".")] - for path, directories, filenames in os.walk(directory): - for directory in directories: - paths.append(f"{path}/{directory}".replace("/", ".")) - return paths +from setuptools import setup +version = os.environ.get("VERSION", "1.0.dev0") setup( - name="autoarray", version=version, - description="PyAuto Data Structures", - long_description=long_description, - long_description_content_type="text/x-rst", - url="https://github.com/Jammy2211/PyAutoArray", - author="James Nightingale and Richard Hayes", - author_email="james.w.nightingale@durham.ac.uk", - include_package_data=True, - license="MIT License", - classifiers=[ - "Intended Audience :: Science/Research", - "Topic :: Scientific/Engineering :: Physics", - "License :: OSI Approved :: MIT License", - "Natural Language :: English", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.2", - "Programming Language :: Python :: 3.3", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - ], - keywords="cli", - packages=find_packages(exclude=["docs", "test_autoarray", "test_autoarray*"]) - + config_packages("autoarray/config"), - install_requires=requirements, - setup_requires=["pytest-runner"], - tests_require=["pytest"], ) diff --git a/test_autoarray/config/visualize.yaml b/test_autoarray/config/visualize.yaml index 8934bb465..d631ae7e9 100644 --- a/test_autoarray/config/visualize.yaml +++ b/test_autoarray/config/visualize.yaml @@ -4,16 +4,6 @@ general: imshow_origin: upper zoom_around_mask: true disable_zoom_for_fits: true # If True, the zoom-in around the masked region is disabled when outputting .fits files, which is useful to retain the same dimensions as the input data. - include_2d: - border: false - mapper_image_plane_mesh_grid: false - mapper_source_plane_data_grid: false - mapper_source_plane_mesh_grid: false - mask: true - origin: true - parallel_overscan: true - serial_overscan: true - serial_prescan: true subplot_shape: 1: (1, 1) # The shape of subplots for a figure with 1 subplot. 2: (2, 2) # The shape of subplots for a figure with 2 subplots. @@ -28,21 +18,6 @@ general: 64: (8, 8) # The shape of subplots for a figure with 64 (or less than the above value) of subplots. 81: (9, 9) # The shape of subplots for a figure with 81 (or less than the above value) of subplots. 100: (10, 10) # The shape of subplots for a figure with 100 (or less than the above value) of subplots. -include: - include_1d: - mask: false - origin: false - include_2d: - border: false - mapper_image_plane_mesh_grid: false - mapper_source_plane_data_grid: false - mapper_source_plane_mesh_grid: false - mask: true - origin: true - parallel_overscan: true - positions: true - serial_overscan: false - serial_prescan: true mat_wrap: Axis: figure: diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index 657ac4b0d..b322cab3c 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -1,3 +1,10 @@ +import jax.numpy as jnp + + +def pytest_configure(): + _ = jnp.sum(jnp.array([0.0])) # Force backend init + + import os from os import path import pytest @@ -69,6 +76,11 @@ def make_array_2d_7x7(): return fixtures.make_array_2d_7x7() +@pytest.fixture(name="array_2d_rgb_7x7") +def make_array_2d_rgb_7x7(): + return fixtures.make_array_2d_rgb_7x7() + + @pytest.fixture(name="layout_2d_7x7") def make_layout_2d_7x7(): return fixtures.make_layout_2d_7x7() diff --git a/test_autoarray/dataset/abstract/test_dataset.py b/test_autoarray/dataset/abstract/test_dataset.py index dd99c1201..794af761c 100644 --- a/test_autoarray/dataset/abstract/test_dataset.py +++ b/test_autoarray/dataset/abstract/test_dataset.py @@ -32,9 +32,9 @@ def test__signal_to_noise_map(): dataset = ds.AbstractDataset(data=array, noise_map=noise_map) - assert ( - dataset.signal_to_noise_map.native == np.array([[0.1, 0.2], [0.1, 1.0]]) - ).all() + assert dataset.signal_to_noise_map.native == pytest.approx( + np.array([[0.1, 0.2], [0.1, 1.0]]), 1.0e-4 + ) assert dataset.signal_to_noise_max == 1.0 array = aa.Array2D.no_mask([[-1.0, 2.0], [3.0, -4.0]], pixel_scales=1.0) @@ -43,9 +43,9 @@ def test__signal_to_noise_map(): dataset = ds.AbstractDataset(data=array, noise_map=noise_map) - assert ( - dataset.signal_to_noise_map.native == np.array([[0.0, 0.2], [0.1, 0.0]]) - ).all() + assert dataset.signal_to_noise_map.native == pytest.approx( + np.array([[0.0, 0.2], [0.1, 0.0]]), 1.0e-4 + ) assert dataset.signal_to_noise_max == 0.2 @@ -115,15 +115,18 @@ def test__grid_settings__sub_size(image_7x7, noise_map_7x7): def test__new_imaging_with_arrays_trimmed_via_kernel_shape(): - data = aa.Array2D.full(fill_value=20.0, shape_native=(3, 3), pixel_scales=1.0) - data[4] = 5.0 - noise_map_array = aa.Array2D.full( - fill_value=5.0, shape_native=(3, 3), pixel_scales=1.0 + data = aa.Array2D.no_mask( + values=[[20.0, 20.0, 20.0], [20.0, 5.0, 20.0], [20.0, 20.0, 20.0]], + pixel_scales=1.0, + ) + + noise_map = aa.Array2D.no_mask( + values=[[20.0, 20.0, 20.0], [20.0, 2.0, 20.0], [20.0, 20.0, 20.0]], + pixel_scales=1.0, ) - noise_map_array[4] = 2.0 - dataset = ds.AbstractDataset(data=data, noise_map=noise_map_array) + dataset = ds.AbstractDataset(data=data, noise_map=noise_map) dataset_trimmed = dataset.trimmed_after_convolution_from(kernel_shape=(3, 3)) diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index ca33f1b40..ead9e51f5 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -78,7 +78,7 @@ def test__from_fits(): ) assert (dataset.data.native == np.ones((3, 3))).all() - assert (dataset.psf.native == (1.0 / 9.0) * np.ones((3, 3))).all() + assert dataset.psf.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4) assert (dataset.noise_map.native == 3.0 * np.ones((3, 3))).all() assert dataset.pixel_scales == (0.1, 0.1) @@ -96,7 +96,7 @@ def test__from_fits(): ) assert (dataset.data.native == np.ones((3, 3))).all() - assert (dataset.psf.native == (1.0 / 9.0) * np.ones((3, 3))).all() + assert dataset.psf.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4) assert (dataset.noise_map.native == 3.0 * np.ones((3, 3))).all() assert dataset.pixel_scales == (0.1, 0.1) @@ -105,6 +105,7 @@ def test__from_fits(): def test__output_to_fits(imaging_7x7, test_data_path): + imaging_7x7.output_to_fits( data_path=path.join(test_data_path, "data.fits"), psf_path=path.join(test_data_path, "psf.fits"), @@ -139,7 +140,9 @@ def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): == 2.0 * np.ones((7, 7)) * np.invert(mask_2d_7x7) ).all() - assert masked_imaging_7x7.psf.slim == pytest.approx((1.0 / 3.0) * psf_3x3.slim, 1.0e-4) + assert masked_imaging_7x7.psf.slim == pytest.approx( + (1.0 / 3.0) * psf_3x3.slim.array, 1.0e-4 + ) assert type(masked_imaging_7x7.psf) == aa.Kernel2D assert masked_imaging_7x7.w_tilde.curvature_preload.shape == (35,) @@ -156,10 +159,21 @@ def test__apply_noise_scaling(imaging_7x7, mask_2d_7x7): assert masked_imaging_7x7.noise_map.native[4, 4] == 1e5 -def test__apply_noise_scaling__use_signal_to_noise_value(imaging_7x7, mask_2d_7x7): - imaging_7x7 = copy.copy(imaging_7x7) +def test__apply_noise_scaling__use_signal_to_noise_value( + image_7x7, psf_3x3, noise_map_7x7, mask_2d_7x7 +): + + image_7x7 = np.array(image_7x7.native.array) + image_7x7[3, 3] = 2.0 - imaging_7x7.data[24] = 2.0 + image_7x7 = aa.Array2D(values=image_7x7, mask=mask_2d_7x7) + + imaging_7x7 = aa.Imaging( + data=image_7x7, + psf=psf_3x3, + noise_map=noise_map_7x7, + over_sample_size_lp=1, + ) masked_imaging_7x7 = imaging_7x7.apply_noise_scaling( mask=mask_2d_7x7, signal_to_noise_value=0.1, should_zero_data=False @@ -248,5 +262,6 @@ def test__psf_not_odd_x_odd_kernel__raises_error(): noise_map = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) psf = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) - dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=False) - + dataset = aa.Imaging( + data=image, noise_map=noise_map, psf=psf, pad_for_psf=False + ) diff --git a/test_autoarray/dataset/interferometer/test_simulator.py b/test_autoarray/dataset/interferometer/test_simulator.py index e45908cba..ea434c163 100644 --- a/test_autoarray/dataset/interferometer/test_simulator.py +++ b/test_autoarray/dataset/interferometer/test_simulator.py @@ -30,7 +30,7 @@ def test__from_image__setup_with_all_features_off( visibilities = transformer.visibilities_from(image=image) - assert dataset.data == pytest.approx(visibilities, 1.0e-4) + assert dataset.data == pytest.approx(visibilities.array, 1.0e-4) def test__setup_with_noise(uv_wavelengths_7x2, transformer_7x7_7): diff --git a/test_autoarray/dataset/test_preprocess.py b/test_autoarray/dataset/test_preprocess.py index f484fd648..8e99a74e8 100644 --- a/test_autoarray/dataset/test_preprocess.py +++ b/test_autoarray/dataset/test_preprocess.py @@ -133,15 +133,15 @@ def test__noise_map_from_image_exposure_time_map(): data_eps=image, exposure_time_map=exposure_time_map ) - assert ( - poisson_noise_map.native - == np.array( + assert poisson_noise_map.native == pytest.approx( + np.array( [ [np.sqrt(5.0), np.sqrt(6.0) / 2.0], [np.sqrt(30.0) / 3.0, np.sqrt(80.0) / 4.0], ] - ) - ).all() + ), + 1.0e-4, + ) def test__noise_map_from_image_exposure_time_map_and_background_noise_map(): @@ -493,10 +493,10 @@ def test__background_noise_map_via_edges_of_image_from_5(): def test__exposure_time_map_from_exposure_time_and_inverse_noise_map(): exposure_time = 6.0 - background_noise_map = aa.Array2D.full( - fill_value=0.25, shape_native=(3, 3), pixel_scales=1.0 + + background_noise_map = aa.Array2D.no_mask( + [[0.5, 0.25, 0.25], [0.25, 0.25, 0.25], [0.25, 0.25, 0.25]], pixel_scales=1.0 ) - background_noise_map[0] = 0.5 exposure_time_map = ( aa.preprocess.exposure_time_map_via_exposure_time_and_background_noise_map_from( @@ -556,6 +556,7 @@ def test__poisson_noise_from_data(): def test__data_with_poisson_noised_added(): data = aa.Array2D.zeros(shape_native=(2, 2), pixel_scales=1.0) exposure_time_map = aa.Array2D.ones(shape_native=(2, 2), pixel_scales=1.0) + data_with_poisson_noise = aa.preprocess.data_eps_with_poisson_noise_added( data_eps=data, exposure_time_map=exposure_time_map, seed=1 ) @@ -661,49 +662,32 @@ def test__data_with_complex_gaussian_noise_added(): def test__noise_map_with_signal_to_noise_limit_from(): - image = aa.Array2D.full(fill_value=20.0, shape_native=(2, 2), pixel_scales=1.0) - image[3] = 5.0 - noise_map_array = aa.Array2D.full( - fill_value=5.0, shape_native=(2, 2), pixel_scales=1.0 - ) - noise_map_array[3] = 2.0 + image = aa.Array2D.no_mask(values=[[20, 20], [20, 5]], pixel_scales=1.0) + noise_map = aa.Array2D.no_mask(values=[[5, 5], [5, 2]], pixel_scales=1.0) noise_map = aa.preprocess.noise_map_with_signal_to_noise_limit_from( - data=image, noise_map=noise_map_array, signal_to_noise_limit=100.0 + data=image, noise_map=noise_map, signal_to_noise_limit=100.0 ) assert (noise_map.slim == np.array([5.0, 5.0, 5.0, 2.0])).all() - image = aa.Array2D.full(fill_value=20.0, shape_native=(2, 2), pixel_scales=1.0) - image[3] = 5.0 - - noise_map_array = aa.Array2D.full( - fill_value=5.0, shape_native=(2, 2), pixel_scales=1.0 - ) - noise_map_array[3] = 2.0 + noise_map = aa.Array2D.no_mask(values=[[5, 5], [5, 2]], pixel_scales=1.0) noise_map = aa.preprocess.noise_map_with_signal_to_noise_limit_from( - data=image, noise_map=noise_map_array, signal_to_noise_limit=2.0 + data=image, noise_map=noise_map, signal_to_noise_limit=2.0 ) assert (noise_map.native == np.array([[10.0, 10.0], [10.0, 2.5]])).all() - image = aa.Array2D.full(fill_value=20.0, shape_native=(2, 2), pixel_scales=1.0) - image[2] = 5.0 - image[3] = 5.0 - - noise_map_array = aa.Array2D.full( - fill_value=5.0, shape_native=(2, 2), pixel_scales=1.0 - ) - noise_map_array[2] = 2.0 - noise_map_array[3] = 2.0 + image = aa.Array2D.no_mask(values=[[20, 20], [5, 5]], pixel_scales=1.0) + noise_map = aa.Array2D.no_mask(values=[[5, 5], [2, 2]], pixel_scales=1.0) mask = aa.Mask2D(mask=[[True, False], [False, True]], pixel_scales=1.0) noise_map = aa.preprocess.noise_map_with_signal_to_noise_limit_from( data=image, - noise_map=noise_map_array, + noise_map=noise_map, signal_to_noise_limit=2.0, noise_limit_mask=mask, ) diff --git a/test_autoarray/fit/plot/test_fit_imaging_plotters.py b/test_autoarray/fit/plot/test_fit_imaging_plotters.py index 4e5480ca7..22223ff61 100644 --- a/test_autoarray/fit/plot/test_fit_imaging_plotters.py +++ b/test_autoarray/fit/plot/test_fit_imaging_plotters.py @@ -19,7 +19,6 @@ def make_plot_path_setup(): def test__fit_quantities_are_output(fit_imaging_7x7, plot_path, plot_patch): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -63,7 +62,6 @@ def test__fit_quantities_are_output(fit_imaging_7x7, plot_path, plot_patch): def test__fit_sub_plot(fit_imaging_7x7, plot_path, plot_patch): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -77,7 +75,6 @@ def test__output_as_fits__correct_output_format( ): fit_plotter = aplt.FitImagingPlotter( fit=fit_imaging_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="fits")), ) @@ -87,4 +84,4 @@ def test__output_as_fits__correct_output_format( file_path=path.join(plot_path, "data.fits"), hdu=0 ) - assert image_from_plot.shape == (7, 7) + assert image_from_plot.shape == (5, 5) diff --git a/test_autoarray/fit/test_fit_imaging.py b/test_autoarray/fit/test_fit_imaging.py index c0c525e71..c90a453ac 100644 --- a/test_autoarray/fit/test_fit_imaging.py +++ b/test_autoarray/fit/test_fit_imaging.py @@ -30,7 +30,7 @@ def test__data_and_model_are_identical__no_masking__check_values_are_correct(): assert fit.chi_squared == 0.0 assert fit.reduced_chi_squared == 0.0 - assert fit.noise_normalization == np.sum(np.log(2 * np.pi * noise_map**2.0)) + assert fit.noise_normalization == np.sum(np.log(2 * np.pi * noise_map.array**2.0)) assert fit.log_likelihood == -0.5 * (fit.chi_squared + fit.noise_normalization) @@ -59,7 +59,7 @@ def test__data_and_model_are_different__include_masking__check_values_are_correc assert fit.chi_squared == 0.25 assert fit.reduced_chi_squared == 0.25 / 3.0 - assert fit.noise_normalization == np.sum(np.log(2 * np.pi * noise_map**2.0)) + assert fit.noise_normalization == np.sum(np.log(2 * np.pi * noise_map.array**2.0)) assert fit.log_likelihood == -0.5 * (fit.chi_squared + fit.noise_normalization) @@ -92,7 +92,7 @@ def test__data_and_model_are_identical__inversion_included__changes_certain_prop assert fit.chi_squared == 0.0 assert fit.reduced_chi_squared == 0.0 - assert fit.noise_normalization == np.sum(np.log(2 * np.pi * noise_map**2.0)) + assert fit.noise_normalization == np.sum(np.log(2 * np.pi * noise_map.array**2.0)) assert fit.log_likelihood == -0.5 * (fit.chi_squared + fit.noise_normalization) assert fit.log_likelihood_with_regularization == -0.5 * ( @@ -102,28 +102,3 @@ def test__data_and_model_are_identical__inversion_included__changes_certain_prop fit.chi_squared + 2.0 + 3.0 - 4.0 + fit.noise_normalization ) assert fit.figure_of_merit == fit.log_evidence - - -def test__run_time_dict__profiles_appropriate_functions(): - mask = aa.Mask2D(mask=[[False, False], [False, False]], pixel_scales=(1.0, 1.0)) - - data = aa.Array2D(values=[1.0, 2.0, 3.0, 4.0], mask=mask) - noise_map = aa.Array2D(values=[2.0, 2.0, 2.0, 2.0], mask=mask) - - dataset = aa.Imaging(data=data, noise_map=noise_map) - - dataset = dataset.apply_mask(mask=mask) - - model_data = aa.Array2D(values=[1.0, 2.0, 3.0, 4.0], mask=mask) - - run_time_dict = {} - - fit = aa.m.MockFitImaging( - dataset=dataset, - use_mask_in_fit=False, - model_data=model_data, - run_time_dict=run_time_dict, - ) - fit.figure_of_merit - - assert "figure_of_merit_0" in fit.run_time_dict diff --git a/test_autoarray/fit/test_fit_util.py b/test_autoarray/fit/test_fit_util.py index 6bbb5f871..0702dc95a 100644 --- a/test_autoarray/fit/test_fit_util.py +++ b/test_autoarray/fit/test_fit_util.py @@ -1,41 +1,41 @@ -import autoarray as aa - -import jax.numpy as jnp +import numpy as np import pytest +import autoarray as aa + def test__residual_map_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + model_data = np.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) - assert (residual_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() + assert (residual_map == np.array([0.0, 0.0, 0.0, 0.0])).all() - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) - assert (residual_map == jnp.array([-1.0, 0.0, 1.0, 2.0])).all() + assert (residual_map == np.array([-1.0, 0.0, 1.0, 2.0])).all() def test__residual_map_with_mask_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - mask = jnp.array([True, False, False, True]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + mask = np.array([True, False, False, True]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data ) - assert (residual_map == jnp.array([0.0, 0.0, 1.0, 0.0])).all() + assert (residual_map == np.array([0.0, 0.0, 1.0, 0.0])).all() def test__normalized_residual_map_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) - model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + noise_map = np.array([2.0, 2.0, 2.0, 2.0]) + model_data = np.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -43,9 +43,11 @@ def test__normalized_residual_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert normalized_residual_map == pytest.approx(jnp.array([0.0, 0.0, 0.0, 0.0]), 1.0e-4) + assert normalized_residual_map == pytest.approx( + np.array([0.0, 0.0, 0.0, 0.0]), 1.0e-4 + ) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -53,14 +55,16 @@ def test__normalized_residual_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert normalized_residual_map == pytest.approx(jnp.array([-(1.0 / 2.0), 0.0, (1.0 / 2.0), (2.0 / 2.0)]), 1.0e-4) + assert normalized_residual_map == pytest.approx( + np.array([-(1.0 / 2.0), 0.0, (1.0 / 2.0), (2.0 / 2.0)]), 1.0e-4 + ) def test__normalized_residual_map_with_mask_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - mask = jnp.array([True, False, False, True]) - noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + mask = np.array([True, False, False, True]) + noise_map = np.array([2.0, 2.0, 2.0, 2.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -70,15 +74,15 @@ def test__normalized_residual_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - print(normalized_residual_map) - - assert normalized_residual_map == pytest.approx(jnp.array([0.0, 0.0, (1.0 / 2.0), 0.0]), abs=1.0e-4) + assert normalized_residual_map == pytest.approx( + np.array([0.0, 0.0, (1.0 / 2.0), 0.0]), abs=1.0e-4 + ) def test__normalized_residual_map_complex_from(): - data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = jnp.array([2.0 + 2.0j, 2.0 + 2.0j]) - model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = np.array([2.0 + 2.0j, 2.0 + 2.0j]) + model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -86,13 +90,13 @@ def test__normalized_residual_map_complex_from(): residual_map=residual_map, noise_map=noise_map ) - assert (normalized_residual_map == jnp.array([0.5 - 1.0j, 0.5 - 1.0j])).all() + assert (normalized_residual_map == np.array([0.5 - 1.0j, 0.5 - 1.0j])).all() def test__chi_squared_map_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) - model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + noise_map = np.array([2.0, 2.0, 2.0, 2.0]) + model_data = np.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -100,9 +104,9 @@ def test__chi_squared_map_from(): residual_map=residual_map, noise_map=noise_map ) - assert (chi_squared_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() + assert (chi_squared_map == np.array([0.0, 0.0, 0.0, 0.0])).all() - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -112,15 +116,15 @@ def test__chi_squared_map_from(): assert ( chi_squared_map - == jnp.array([(1.0 / 2.0) ** 2.0, 0.0, (1.0 / 2.0) ** 2.0, (2.0 / 2.0) ** 2.0]) + == np.array([(1.0 / 2.0) ** 2.0, 0.0, (1.0 / 2.0) ** 2.0, (2.0 / 2.0) ** 2.0]) ).all() def test__chi_squared_map_with_mask_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - mask = jnp.array([True, False, False, True]) - noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + mask = np.array([True, False, False, True]) + noise_map = np.array([2.0, 2.0, 2.0, 2.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -130,9 +134,9 @@ def test__chi_squared_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (chi_squared_map == jnp.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() + assert (chi_squared_map == np.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -142,13 +146,13 @@ def test__chi_squared_map_with_mask_from(): residual_map=residual_map, mask=mask, noise_map=noise_map ) - assert (chi_squared_map == jnp.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() + assert (chi_squared_map == np.array([0.0, 0.0, (1.0 / 2.0) ** 2.0, 0.0])).all() def test__chi_squared_map_complex_from(): - data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = jnp.array([2.0 + 2.0j, 2.0 + 2.0j]) - model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) + data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = np.array([2.0 + 2.0j, 2.0 + 2.0j]) + model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -156,13 +160,13 @@ def test__chi_squared_map_complex_from(): residual_map=residual_map, noise_map=noise_map ) - assert (chi_squared_map == jnp.array([0.25 + 1.0j, 0.25 + 1.0j])).all() + assert (chi_squared_map == np.array([0.25 + 1.0j, 0.25 + 1.0j])).all() def test__chi_squared_with_noise_covariance_from(): resdiual_map = aa.Array2D.no_mask([[1.0, 1.0], [2.0, 2.0]], pixel_scales=1.0) - noise_covariance_matrix_inv = jnp.array( + noise_covariance_matrix_inv = np.array( [ [1.0, 1.0, 4.0, 0.0], [0.0, 1.0, 9.0, 0.0], @@ -180,10 +184,10 @@ def test__chi_squared_with_noise_covariance_from(): def test__chi_squared_with_mask_fast_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - mask = jnp.array([True, False, False, True]) - noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + mask = np.array([True, False, False, True]) + noise_map = np.array([1.0, 2.0, 3.0, 4.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -206,10 +210,10 @@ def test__chi_squared_with_mask_fast_from(): assert chi_squared == pytest.approx(chi_squared_fast, 1.0e-4) - data = jnp.array([[10.0, 10.0], [10.0, 10.0]]) - mask = jnp.array([[True, False], [False, True]]) - noise_map = jnp.array([[1.0, 2.0], [3.0, 4.0]]) - model_data = jnp.array([[11.0, 10.0], [9.0, 8.0]]) + data = np.array([[10.0, 10.0], [10.0, 10.0]]) + mask = np.array([[True, False], [False, True]]) + noise_map = np.array([[1.0, 2.0], [3.0, 4.0]]) + model_data = np.array([[11.0, 10.0], [9.0, 8.0]]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -234,9 +238,9 @@ def test__chi_squared_with_mask_fast_from(): def test__log_likelihood_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - noise_map = jnp.array([2.0, 2.0, 2.0, 2.0]) - model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + noise_map = np.array([2.0, 2.0, 2.0, 2.0]) + model_data = np.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -254,17 +258,17 @@ def test__log_likelihood_from(): chi_squared = 0.0 noise_normalization = ( - jnp.log(2.0 * jnp.pi * (2.0**2.0)) - + jnp.log(2.0 * jnp.pi * (2.0**2.0)) - + jnp.log(2.0 * jnp.pi * (2.0**2.0)) - + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + np.log(2.0 * np.pi * (2.0**2.0)) + + np.log(2.0 * np.pi * (2.0**2.0)) + + np.log(2.0 * np.pi * (2.0**2.0)) + + np.log(2.0 * np.pi * (2.0**2.0)) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1.0e-4 ) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -287,17 +291,17 @@ def test__log_likelihood_from(): ((1.0 / 2.0) ** 2.0) + 0.0 + ((1.0 / 2.0) ** 2.0) + ((2.0 / 2.0) ** 2.0) ) noise_normalization = ( - jnp.log(2.0 * jnp.pi * (2.0**2.0)) - + jnp.log(2.0 * jnp.pi * (2.0**2.0)) - + jnp.log(2.0 * jnp.pi * (2.0**2.0)) - + jnp.log(2.0 * jnp.pi * (2.0**2.0)) + np.log(2.0 * np.pi * (2.0**2.0)) + + np.log(2.0 * np.pi * (2.0**2.0)) + + np.log(2.0 * np.pi * (2.0**2.0)) + + np.log(2.0 * np.pi * (2.0**2.0)) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1.0e-4 ) - noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) + noise_map = np.array([1.0, 2.0, 3.0, 4.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -317,10 +321,10 @@ def test__log_likelihood_from(): chi_squared = 1.0 + (1.0 / (3.0**2.0)) + 0.25 noise_normalization = ( - jnp.log(2 * jnp.pi * (1.0**2.0)) - + jnp.log(2 * jnp.pi * (2.0**2.0)) - + jnp.log(2 * jnp.pi * (3.0**2.0)) - + jnp.log(2 * jnp.pi * (4.0**2.0)) + np.log(2 * np.pi * (1.0**2.0)) + + np.log(2 * np.pi * (2.0**2.0)) + + np.log(2 * np.pi * (3.0**2.0)) + + np.log(2 * np.pi * (4.0**2.0)) ) assert log_likelihood == pytest.approx( @@ -329,10 +333,10 @@ def test__log_likelihood_from(): def test__log_likelihood_from__with_mask(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - mask = jnp.array([True, False, False, True]) - noise_map = jnp.array([1.0, 2.0, 3.0, 4.0]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + mask = np.array([True, False, False, True]) + noise_map = np.array([1.0, 2.0, 3.0, 4.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -357,18 +361,18 @@ def test__log_likelihood_from__with_mask(): # chi squared = 0, 0.25, (0.25 and 1.0 are masked) chi_squared = 0.0 + (1.0 / 3.0) ** 2.0 - noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( - 2 * jnp.pi * (3.0**2.0) + noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( + 2 * np.pi * (3.0**2.0) ) assert log_likelihood == pytest.approx( -0.5 * (chi_squared + noise_normalization), 1e-4 ) - data = jnp.array([[10.0, 10.0], [10.0, 10.0]]) - mask = jnp.array([[True, False], [False, True]]) - noise_map = jnp.array([[1.0, 2.0], [3.0, 4.0]]) - model_data = jnp.array([[11.0, 10.0], [9.0, 8.0]]) + data = np.array([[10.0, 10.0], [10.0, 10.0]]) + mask = np.array([[True, False], [False, True]]) + noise_map = np.array([[1.0, 2.0], [3.0, 4.0]]) + model_data = np.array([[11.0, 10.0], [9.0, 8.0]]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -393,8 +397,8 @@ def test__log_likelihood_from__with_mask(): # chi squared = 0, 0.25, (0.25 and 1.0 are masked) chi_squared = 0.0 + (1.0 / 3.0) ** 2.0 - noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( - 2 * jnp.pi * (3.0**2.0) + noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( + 2 * np.pi * (3.0**2.0) ) assert log_likelihood == pytest.approx( @@ -403,9 +407,10 @@ def test__log_likelihood_from__with_mask(): def test__log_likelihood_from__complex_data(): - data = jnp.array([10.0 + 10.0j, 10.0 + 10.0j]) - noise_map = jnp.array([2.0 + 1.0j, 2.0 + 1.0j]) - model_data = jnp.array([9.0 + 12.0j, 9.0 + 12.0j]) + + data = np.array([10.0 + 10.0j, 10.0 + 10.0j]) + noise_map = np.array([2.0 + 1.0j, 2.0 + 1.0j]) + model_data = np.array([9.0 + 12.0j, 9.0 + 12.0j]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -426,8 +431,8 @@ def test__log_likelihood_from__complex_data(): # chi squared = 0.25 and 4.0 chi_squared = 4.25 - noise_normalization = jnp.log(2 * jnp.pi * (2.0**2.0)) + jnp.log( - 2 * jnp.pi * (1.0**2.0) + noise_normalization = np.log(2 * np.pi * (2.0**2.0)) + np.log( + 2 * np.pi * (1.0**2.0) ) assert log_likelihood == pytest.approx( @@ -456,8 +461,8 @@ def test__log_evidence_from(): def test__residual_flux_fraction_map_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - model_data = jnp.array([10.0, 10.0, 10.0, 10.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + model_data = np.array([10.0, 10.0, 10.0, 10.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -465,9 +470,9 @@ def test__residual_flux_fraction_map_from(): residual_map=residual_map, data=data ) - assert (residual_flux_fraction_map == jnp.array([0.0, 0.0, 0.0, 0.0])).all() + assert (residual_flux_fraction_map == np.array([0.0, 0.0, 0.0, 0.0])).all() - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) @@ -475,13 +480,13 @@ def test__residual_flux_fraction_map_from(): residual_map=residual_map, data=data ) - assert (residual_flux_fraction_map == jnp.array([-0.1, 0.0, 0.1, 0.2])).all() + assert (residual_flux_fraction_map == np.array([-0.1, 0.0, 0.1, 0.2])).all() def test__residual_flux_fraction_map_with_mask_from(): - data = jnp.array([10.0, 10.0, 10.0, 10.0]) - mask = jnp.array([True, False, False, True]) - model_data = jnp.array([11.0, 10.0, 9.0, 8.0]) + data = np.array([10.0, 10.0, 10.0, 10.0]) + mask = np.array([True, False, False, True]) + model_data = np.array([11.0, 10.0, 9.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -491,9 +496,9 @@ def test__residual_flux_fraction_map_with_mask_from(): residual_map=residual_map, mask=mask, data=data ) - assert (residual_flux_fraction_map == jnp.array([0.0, 0.0, 0.1, 0.0])).all() + assert (residual_flux_fraction_map == np.array([0.0, 0.0, 0.1, 0.0])).all() - model_data = jnp.array([11.0, 9.0, 8.0, 8.0]) + model_data = np.array([11.0, 9.0, 8.0, 8.0]) residual_map = aa.util.fit.residual_map_with_mask_from( data=data, mask=mask, model_data=model_data @@ -503,4 +508,4 @@ def test__residual_flux_fraction_map_with_mask_from(): residual_map=residual_map, mask=mask, data=data ) - assert (residual_flux_fraction_map == jnp.array([0.0, 0.1, 0.2, 0.0])).all() + assert (residual_flux_fraction_map == np.array([0.0, 0.1, 0.2, 0.0])).all() diff --git a/test_autoarray/inversion/inversion/imaging/test_imaging.py b/test_autoarray/inversion/inversion/imaging/test_imaging.py index df615ad85..46bfe55ac 100644 --- a/test_autoarray/inversion/inversion/imaging/test_imaging.py +++ b/test_autoarray/inversion/inversion/imaging/test_imaging.py @@ -11,17 +11,31 @@ directory = path.dirname(path.realpath(__file__)) -def test__operated_mapping_matrix_property(psf_7x7, rectangular_mapper_7x7_3x3): +def test__operated_mapping_matrix_property(psf_3x3, rectangular_mapper_7x7_3x3): + inversion = aa.m.MockInversionImaging( - psf=psf_7x7, linear_obj_list=[rectangular_mapper_7x7_3x3] + mask=rectangular_mapper_7x7_3x3.mapper_grids.mask, + psf=psf_3x3, + linear_obj_list=[rectangular_mapper_7x7_3x3], ) - assert inversion.operated_mapping_matrix_list[0][0, 0] == pytest.approx(1.0, 1e-4) - assert inversion.operated_mapping_matrix[0, 0] == pytest.approx(1.0, 1e-4) + assert inversion.operated_mapping_matrix_list[0][0, 0] == pytest.approx( + 1.61999997, 1e-4 + ) + assert inversion.operated_mapping_matrix[0, 0] == pytest.approx(1.61999997408, 1e-4) + mask = aa.Mask2D( + [ + [True, True, True, True], + [True, False, False, True], + [True, True, True, True], + ], + pixel_scales=1.0, + ) psf = aa.m.MockPSF(operated_mapping_matrix=np.ones((2, 2))) inversion = aa.m.MockInversionImaging( + mask=mask, psf=psf, linear_obj_list=[rectangular_mapper_7x7_3x3, rectangular_mapper_7x7_3x3], ) @@ -42,7 +56,7 @@ def test__operated_mapping_matrix_property(psf_7x7, rectangular_mapper_7x7_3x3): def test__operated_mapping_matrix_property__with_operated_mapping_matrix_override( - psf_7x7, rectangular_mapper_7x7_3x3 + psf_3x3, rectangular_mapper_7x7_3x3 ): psf = aa.m.MockPSF(operated_mapping_matrix=np.ones((2, 2))) @@ -54,7 +68,9 @@ def test__operated_mapping_matrix_property__with_operated_mapping_matrix_overrid ) inversion = aa.m.MockInversionImaging( - psf=psf, linear_obj_list=[rectangular_mapper_7x7_3x3, linear_obj] + mask=rectangular_mapper_7x7_3x3.mapper_grids.mask, + psf=psf, + linear_obj_list=[rectangular_mapper_7x7_3x3, linear_obj], ) operated_mapping_matrix_0 = np.array([[1.0, 1.0], [1.0, 1.0]]) @@ -85,7 +101,7 @@ def test__curvature_matrix(rectangular_mapper_7x7_3x3): ) dataset = aa.DatasetInterface( - data=np.ones(2), + data=aa.Array2D.ones(shape_native=(2, 10), pixel_scales=1.0), noise_map=noise_map, psf=psf, ) @@ -128,7 +144,13 @@ def test__w_tilde_checks_noise_map_and_raises_exception_if_preloads_dont_match_n grid = aa.Grid2D.from_mask(mask=mask) w_tilde = WTildeImaging( - curvature_preload=None, indexes=None, lengths=None, noise_map_value=2.0 + curvature_preload=None, + indexes=None, + lengths=None, + noise_map_value=2.0, + noise_map=None, + psf=None, + mask=mask, ) with pytest.raises(exc.InversionException): @@ -147,4 +169,5 @@ def test__w_tilde_checks_noise_map_and_raises_exception_if_preloads_dont_match_n mapping_matrix=np.ones(matrix_shape), source_plane_data_grid=grid ) ], + settings=aa.SettingsInversion(use_w_tilde=True), ) diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index e05e13350..cafb8722b 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -17,7 +17,7 @@ def test__w_tilde_imaging_from(): native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) - w_tilde = aa.util.inversion_imaging.w_tilde_curvature_imaging_from( + w_tilde = aa.util.inversion_imaging_numba.w_tilde_curvature_imaging_from( noise_map_native=noise_map_2d, kernel_native=kernel, native_index_for_slim_index=native_index_for_slim_index, @@ -87,7 +87,7 @@ def test__w_tilde_curvature_preload_imaging_from(): w_tilde_preload, w_tilde_indexes, w_tilde_lengths, - ) = aa.util.inversion_imaging.w_tilde_curvature_preload_imaging_from( + ) = aa.util.inversion_imaging_numba.w_tilde_curvature_preload_imaging_from( noise_map_native=noise_map_2d, kernel_native=kernel, native_index_for_slim_index=native_index_for_slim_index, @@ -183,7 +183,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): psf = kernel - pixelization = aa.mesh.Rectangular(shape=(20, 20)) + pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) # TODO : Use pytest.parameterize @@ -204,40 +204,46 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): mapping_matrix = mapper.mapping_matrix blurred_mapping_matrix = psf.convolve_mapping_matrix( - mapping_matrix=mapping_matrix + mapping_matrix=mapping_matrix, mask=mask ) data_vector = ( aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=blurred_mapping_matrix, + blurred_mapping_matrix=np.array(blurred_mapping_matrix), image=np.array(image), noise_map=np.array(noise_map), ) ) w_tilde_data = aa.util.inversion_imaging.w_tilde_data_imaging_from( - image_native=np.array(image.native), - noise_map_native=np.array(noise_map.native), - kernel_native=np.array(kernel.native), - native_index_for_slim_index=mask.derive_indexes.native_for_slim, + image_native=np.array(image.native.array), + noise_map_native=np.array(noise_map.native.array), + kernel_native=np.array(kernel.native.array), + native_index_for_slim_index=np.array( + mask.derive_indexes.native_for_slim + ).astype("int"), ) ( data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=w_tilde_data.shape[0], - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + mapper.pix_indexes_for_sub_slim_index + ), + pix_sizes_for_sub_slim_index=np.array(mapper.pix_sizes_for_sub_slim_index), + pix_weights_for_sub_slim_index=np.array( + mapper.pix_weights_for_sub_slim_index + ), pix_pixels=mapper.params, sub_size=np.array(grid.over_sample_size), ) data_vector_via_w_tilde = ( - aa.util.inversion_imaging.data_vector_via_w_tilde_data_imaging_from( - w_tilde_data=w_tilde_data, + aa.util.inversion_imaging_numba.data_vector_via_w_tilde_data_imaging_from( + w_tilde_data=np.array(w_tilde_data), data_to_pix_unique=data_to_pix_unique.astype("int"), data_weights=data_weights, pix_lengths=pix_lengths.astype("int"), @@ -260,7 +266,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): psf = kernel - pixelization = aa.mesh.Rectangular(shape=(20, 20)) + pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) mapper_grids = pixelization.mapper_grids_from( mask=mask, @@ -272,17 +278,21 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): mapping_matrix = mapper.mapping_matrix - w_tilde = aa.util.inversion_imaging.w_tilde_curvature_imaging_from( - noise_map_native=np.array(noise_map.native), - kernel_native=np.array(kernel.native), - native_index_for_slim_index=mask.derive_indexes.native_for_slim, + w_tilde = aa.util.inversion_imaging_numba.w_tilde_curvature_imaging_from( + noise_map_native=np.array(noise_map.native.array), + kernel_native=np.array(kernel.native.array), + native_index_for_slim_index=np.array( + mask.derive_indexes.native_for_slim + ).astype("int"), ) curvature_matrix_via_w_tilde = aa.util.inversion.curvature_matrix_via_w_tilde_from( w_tilde=w_tilde, mapping_matrix=mapping_matrix ) - blurred_mapping_matrix = psf.convolve_mapping_matrix(mapping_matrix=mapping_matrix) + blurred_mapping_matrix = psf.convolve_mapping_matrix( + mapping_matrix=mapping_matrix, mask=mask + ) curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( mapping_matrix=blurred_mapping_matrix, @@ -303,7 +313,7 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): psf = kernel - pixelization = aa.mesh.Rectangular(shape=(20, 20)) + pixelization = aa.mesh.RectangularUniform(shape=(20, 20)) for sub_size in range(1, 2, 3): grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=sub_size) @@ -325,26 +335,32 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): w_tilde_preload, w_tilde_indexes, w_tilde_lengths, - ) = aa.util.inversion_imaging.w_tilde_curvature_preload_imaging_from( - noise_map_native=np.array(noise_map.native), - kernel_native=np.array(kernel.native), - native_index_for_slim_index=mask.derive_indexes.native_for_slim, + ) = aa.util.inversion_imaging_numba.w_tilde_curvature_preload_imaging_from( + noise_map_native=np.array(noise_map.native.array), + kernel_native=np.array(kernel.native.array), + native_index_for_slim_index=np.array( + mask.derive_indexes.native_for_slim + ).astype("int"), ) ( data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=w_tilde_lengths.shape[0], - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, + pix_indexes_for_sub_slim_index=np.array( + mapper.pix_indexes_for_sub_slim_index + ), + pix_sizes_for_sub_slim_index=np.array(mapper.pix_sizes_for_sub_slim_index), + pix_weights_for_sub_slim_index=np.array( + mapper.pix_weights_for_sub_slim_index + ), pix_pixels=mapper.params, sub_size=np.array(grid.over_sample_size), ) - curvature_matrix_via_w_tilde = aa.util.inversion_imaging.curvature_matrix_via_w_tilde_curvature_preload_imaging_from( + curvature_matrix_via_w_tilde = aa.util.inversion_imaging_numba.curvature_matrix_via_w_tilde_curvature_preload_imaging_from( curvature_preload=w_tilde_preload, curvature_indexes=w_tilde_indexes.astype("int"), curvature_lengths=w_tilde_lengths.astype("int"), @@ -355,11 +371,12 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): ) blurred_mapping_matrix = psf.convolve_mapping_matrix( - mapping_matrix=mapping_matrix + mapping_matrix=mapping_matrix, + mask=mask, ) curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=blurred_mapping_matrix, + mapping_matrix=np.array(blurred_mapping_matrix), noise_map=np.array(noise_map), ) diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index 8875ea378..96cad9eed 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -68,23 +68,6 @@ def test__data_vector_via_transformed_mapping_matrix_from(): assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() -def test__inversion_interferometer__via_mapper( - interferometer_7_no_fft, - rectangular_mapper_7x7_3x3, - delaunay_mapper_9_3x3, - voronoi_mapper_9_3x3, - regularization_constant, -): - inversion = aa.Inversion( - dataset=interferometer_7_no_fft, - linear_obj_list=[rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_linear_operators=True), - ) - - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) - assert isinstance(inversion, aa.InversionInterferometerMappingPyLops) - - def test__w_tilde_curvature_interferometer_from(): noise_map = np.array([1.0, 2.0, 3.0]) uv_wavelengths = np.array([[0.0001, 2.0, 3000.0], [3000.0, 2.0, 0.0001]]) @@ -265,13 +248,13 @@ def test__identical_inversion_values_for_two_methods(): inversion_w_tilde = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=True), + settings=aa.SettingsInversion(use_w_tilde=True, use_positive_only_solver=True), ) inversion_mapping_matrices = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), ) assert (inversion_w_tilde.data == inversion_mapping_matrices.data).all() @@ -289,20 +272,24 @@ def test__identical_inversion_values_for_two_methods(): == inversion_mapping_matrices.regularization_matrix ).all() + assert inversion_w_tilde.data_vector == pytest.approx( + inversion_mapping_matrices.data_vector, 1.0e-8 + ) assert inversion_w_tilde.curvature_matrix == pytest.approx( inversion_mapping_matrices.curvature_matrix, 1.0e-8 ) assert inversion_w_tilde.curvature_reg_matrix == pytest.approx( inversion_mapping_matrices.curvature_reg_matrix, 1.0e-8 ) + assert inversion_w_tilde.reconstruction == pytest.approx( - inversion_mapping_matrices.reconstruction, 1.0e-2 + inversion_mapping_matrices.reconstruction, abs=1.0e-1 ) - assert inversion_w_tilde.mapped_reconstructed_image == pytest.approx( - inversion_mapping_matrices.mapped_reconstructed_image, 1.0e-2 + assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( + inversion_mapping_matrices.mapped_reconstructed_image.array, abs=1.0e-1 ) - assert inversion_w_tilde.mapped_reconstructed_data == pytest.approx( - inversion_mapping_matrices.mapped_reconstructed_data, 1.0e-2 + assert inversion_w_tilde.mapped_reconstructed_data.array == pytest.approx( + inversion_mapping_matrices.mapped_reconstructed_data.array, abs=1.0e-1 ) @@ -360,13 +347,17 @@ def test__identical_inversion_source_and_image_loops(): inversion_image_loop = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=True, use_source_loop=False), + settings=aa.SettingsInversion( + use_w_tilde=True, use_source_loop=False, use_positive_only_solver=True + ), ) inversion_source_loop = aa.Inversion( dataset=dataset, linear_obj_list=[mapper], - settings=aa.SettingsInversion(use_w_tilde=True, use_source_loop=True), + settings=aa.SettingsInversion( + use_w_tilde=True, use_source_loop=True, use_positive_only_solver=True + ), ) assert (inversion_image_loop.data == inversion_source_loop.data).all() @@ -393,9 +384,9 @@ def test__identical_inversion_source_and_image_loops(): assert inversion_image_loop.reconstruction == pytest.approx( inversion_source_loop.reconstruction, 1.0e-2 ) - assert inversion_image_loop.mapped_reconstructed_image == pytest.approx( - inversion_source_loop.mapped_reconstructed_image, 1.0e-2 + assert inversion_image_loop.mapped_reconstructed_image.array == pytest.approx( + inversion_source_loop.mapped_reconstructed_image.array, 1.0e-2 ) - assert inversion_image_loop.mapped_reconstructed_data == pytest.approx( - inversion_source_loop.mapped_reconstructed_data, 1.0e-2 + assert inversion_image_loop.mapped_reconstructed_data.array == pytest.approx( + inversion_source_loop.mapped_reconstructed_data.array, 1.0e-2 ) diff --git a/test_autoarray/inversion/inversion/test_abstract.py b/test_autoarray/inversion/inversion/test_abstract.py index 5ce2ad473..5f1a918cf 100644 --- a/test_autoarray/inversion/inversion/test_abstract.py +++ b/test_autoarray/inversion/inversion/test_abstract.py @@ -6,7 +6,6 @@ from autoarray import exc - directory = path.dirname(path.realpath(__file__)) @@ -116,8 +115,8 @@ def test__curvature_matrix__via_w_tilde__identical_to_mapping(): grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=1) - mesh_0 = aa.mesh.Rectangular(shape=(3, 3)) - mesh_1 = aa.mesh.Rectangular(shape=(4, 4)) + mesh_0 = aa.mesh.RectangularUniform(shape=(3, 3)) + mesh_1 = aa.mesh.RectangularUniform(shape=(4, 4)) mapper_grids_0 = mesh_0.mapper_grids_from( mask=mask, @@ -243,7 +242,7 @@ def test__curvature_reg_matrix_reduced(): curvature_reg_matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) linear_obj_list = [ - aa.m.MockLinearObj(parameters=2, regularization=1), + aa.m.MockMapper(parameters=2, regularization=aa.m.MockRegularization()), aa.m.MockLinearObj(parameters=1, regularization=None), ] @@ -251,38 +250,13 @@ def test__curvature_reg_matrix_reduced(): linear_obj_list=linear_obj_list, curvature_reg_matrix=curvature_reg_matrix ) + print(inversion.curvature_reg_matrix_reduced) + assert ( inversion.curvature_reg_matrix_reduced == np.array([[1.0, 2.0], [4.0, 5.0]]) ).all() -# def test__curvature_reg_matrix_solver__edge_pixels_set_to_zero(): -# -# curvature_reg_matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) -# -# linear_obj_list = [ -# aa.m.MockMapper(parameters=3, regularization=None, edge_pixel_list=[0]) -# ] -# -# inversion = aa.m.MockInversion( -# linear_obj_list=linear_obj_list, -# curvature_reg_matrix=curvature_reg_matrix, -# settings=aa.SettingsInversion(force_edge_pixels_to_zeros=True), -# ) -# -# curvature_reg_matrix = np.array( -# [ -# [0.0, 2.0, 3.0], -# [0.0, 5.0, 6.0], -# [0.0, 8.0, 9.0], -# ] -# ) -# -# assert inversion.curvature_reg_matrix_solver == pytest.approx( -# curvature_reg_matrix, 1.0e-4 -# ) - - def test__regularization_matrix(): reg_0 = aa.m.MockRegularization(regularization_matrix=np.ones((2, 2))) reg_1 = aa.m.MockRegularization(regularization_matrix=2.0 * np.ones((3, 3))) @@ -309,7 +283,7 @@ def test__regularization_matrix(): def test__reconstruction_reduced(): linear_obj_list = [ - aa.m.MockLinearObj(parameters=2, regularization=aa.m.MockRegularization()), + aa.m.MockMapper(parameters=2, regularization=aa.m.MockRegularization()), aa.m.MockLinearObj(parameters=1, regularization=None), ] @@ -452,17 +426,6 @@ def test__data_subtracted_dict(): assert (inversion.data_subtracted_dict[linear_obj_1] == 2.0 * np.ones(3)).all() -def test__reconstruction_raises_exception_for_linalg_error(): - # noinspection PyTypeChecker - inversion = aa.m.MockInversion( - data_vector=np.ones(3), curvature_reg_matrix=np.ones((3, 3)) - ) - - with pytest.raises(exc.InversionException): - # noinspection PyStatementEffect - inversion.reconstruction - - def test__regularization_term(): reconstruction = np.array([1.0, 1.0, 1.0]) diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index 79ea56020..5c0b11b88 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -1,6 +1,7 @@ import copy import numpy as np import pytest +from dill import settings import autoarray as aa @@ -43,6 +44,7 @@ def test__inversion_imaging__via_linear_obj_func_list(masked_imaging_7x7_no_blur linear_obj = aa.m.MockLinearObjFuncList( parameters=2, grid=grid, mapping_matrix=np.full(fill_value=0.5, shape=(9, 2)) ) + linear_obj.mapping_matrix[0, 0] = 1.0 inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, @@ -52,8 +54,8 @@ def test__inversion_imaging__via_linear_obj_func_list(masked_imaging_7x7_no_blur assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObjFuncList) assert isinstance(inversion, aa.InversionImagingMapping) + assert inversion.reconstruction == pytest.approx(np.array([0.0, 2.0]), abs=1.0e-4) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) - assert inversion.reconstruction == pytest.approx(np.array([1.0, 1.0]), 1.0e-4) def test__inversion_imaging__via_mapper( @@ -67,9 +69,14 @@ def test__inversion_imaging__via_mapper( settings=aa.SettingsInversion(use_w_tilde=False), ) - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingMapping) - assert inversion.log_det_curvature_reg_matrix_term == pytest.approx(6.9546, 1.0e-4) + assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( + 7.2571757082, 1.0e-4 + ) + # assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( + # 4.609440907938719, 1.0e-4 + # ) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) inversion = aa.Inversion( @@ -78,9 +85,11 @@ def test__inversion_imaging__via_mapper( settings=aa.SettingsInversion(use_w_tilde=True), ) - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingWTilde) - assert inversion.log_det_curvature_reg_matrix_term == pytest.approx(6.9546, 1.0e-4) + assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( + 7.257175708246, 1.0e-4 + ) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) inversion = aa.Inversion( @@ -183,6 +192,24 @@ def test__inversion_imaging__via_regularizations( assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) +def test__inversion_imaging__source_pixel_zeroed_indices( + masked_imaging_7x7_no_blur, + rectangular_mapper_7x7_3x3, +): + inversion = aa.Inversion( + dataset=masked_imaging_7x7_no_blur, + linear_obj_list=[rectangular_mapper_7x7_3x3], + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), + preloads=aa.Preloads( + mapper_indices=range(0, 9), source_pixel_zeroed_indices=np.array([0]) + ), + ) + + assert inversion.reconstruction.shape[0] == 9 + assert inversion.reconstruction[0] == 0.0 + assert inversion.reconstruction[1] > 0.0 + + def test__inversion_imaging__via_linear_obj_func_and_mapper( masked_imaging_7x7_no_blur, rectangular_mapper_7x7_3x3, @@ -209,10 +236,10 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper( ) assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) - assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionImagingMapping) assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( - 6.95465245, 1.0e-4 + 7.2571757082469945, 1.0e-4 ) assert inversion.reconstruction_dict[linear_obj] == pytest.approx( np.array([2.0]), 1.0e-4 @@ -220,6 +247,22 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper( assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][0] < 1.0e-4 assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) + inversion = aa.Inversion( + dataset=masked_imaging_7x7_no_blur, + linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], + settings=aa.SettingsInversion( + use_w_tilde=True, + no_regularization_add_to_curvature_diag_value=False, + ), + ) + + assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) + assert isinstance(inversion.linear_obj_list[1], aa.MapperRectangularUniform) + assert isinstance(inversion, aa.InversionImagingWTilde) + assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( + 7.2571757082469945, 1.0e-4 + ) + def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_to_zero( masked_imaging_7x7_no_blur, delaunay_mapper_9_3x3 @@ -231,7 +274,9 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_t linear_obj = aa.m.MockLinearObj( parameters=1, grid=grid, - mapping_matrix=np.full(fill_value=0.5, shape=(9, 1)), + mapping_matrix=np.array( + [[1.0], [2.0], [3.0], [2.0], [3.0], [4.0], [3.0], [1.0], [2.0]] + ), regularization=None, ) @@ -260,12 +305,14 @@ def test__inversion_imaging__via_linear_obj_func_and_mapper__force_edge_pixels_t ), ) + mapper_edge_pixel_list = inversion.mapper_edge_pixel_list + assert isinstance(inversion.linear_obj_list[0], aa.m.MockLinearObj) assert isinstance(inversion.linear_obj_list[1], aa.MapperDelaunay) assert isinstance(inversion, aa.InversionImagingMapping) - assert inversion.reconstruction == pytest.approx( - np.array([2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), 1.0e-4 - ) + # assert inversion.reconstruction[mapper_edge_pixel_list[0]] == pytest.approx(0.0, abs=1.0e-2) + # assert inversion.reconstruction[mapper_edge_pixel_list[1]] == pytest.approx(0.0, abs=1.0e-2) + # assert inversion.reconstruction[mapper_edge_pixel_list[2]] == pytest.approx(0.0, abs=1.0e-2) def test__inversion_imaging__compare_mapping_and_w_tilde_values( @@ -289,8 +336,8 @@ def test__inversion_imaging__compare_mapping_and_w_tilde_values( assert inversion_w_tilde.reconstruction == pytest.approx( inversion_mapping.reconstruction, 1.0e-4 ) - assert inversion_w_tilde.mapped_reconstructed_image == pytest.approx( - inversion_mapping.mapped_reconstructed_image, 1.0e-4 + assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx( + inversion_mapping.mapped_reconstructed_image.array, 1.0e-4 ) assert inversion_w_tilde.log_det_curvature_reg_matrix_term == pytest.approx( inversion_mapping.log_det_curvature_reg_matrix_term @@ -309,13 +356,19 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( grid = aa.Grid2D.from_mask(mask=mask) linear_obj = aa.m.MockLinearObj( - parameters=2, grid=grid, mapping_matrix=np.full(fill_value=0.5, shape=(9, 2)) + parameters=2, + grid=grid, + mapping_matrix=np.full(fill_value=0.5, shape=(9, 2)), ) inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion( + use_w_tilde=False, + use_positive_only_solver=True, + force_edge_pixels_to_zeros=False, + ), ) masked_imaging_7x7_no_blur = copy.copy(masked_imaging_7x7_no_blur) @@ -327,7 +380,11 @@ def test__inversion_imaging__linear_obj_func_and_non_func_give_same_terms( inversion_no_linear_func = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion( + use_w_tilde=False, + use_positive_only_solver=True, + force_edge_pixels_to_zeros=False, + ), ) assert inversion.regularization_term == pytest.approx( @@ -346,11 +403,6 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( rectangular_mapper_7x7_3x3, delaunay_mapper_9_3x3, ): - masked_imaging_7x7 = copy.copy(masked_imaging_7x7) - masked_imaging_7x7.data[4] = 2.0 - masked_imaging_7x7.noise_map[3] = 4.0 - masked_imaging_7x7.psf[0] = 0.1 - masked_imaging_7x7.psf[4] = 0.9 mask = masked_imaging_7x7.mask @@ -367,13 +419,13 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( inversion_mapping = aa.Inversion( dataset=masked_imaging_7x7, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=False), + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), ) inversion_w_tilde = aa.Inversion( dataset=masked_imaging_7x7, linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], - settings=aa.SettingsInversion(use_w_tilde=True), + settings=aa.SettingsInversion(use_w_tilde=True, use_positive_only_solver=True), ) assert inversion_mapping.data_vector == pytest.approx( @@ -382,8 +434,14 @@ def test__inversion_imaging__linear_obj_func_with_w_tilde( assert inversion_mapping.curvature_matrix == pytest.approx( inversion_w_tilde.curvature_matrix, 1.0e-4 ) - assert inversion_mapping.mapped_reconstructed_image == pytest.approx( - inversion_w_tilde.mapped_reconstructed_image, 1.0e-4 + assert inversion_mapping.curvature_reg_matrix == pytest.approx( + inversion_w_tilde.curvature_reg_matrix, 1.0e-4 + ) + assert inversion_mapping.reconstruction == pytest.approx( + inversion_w_tilde.reconstruction, 1.0e-4 + ) + assert inversion_mapping.mapped_reconstructed_image.array == pytest.approx( + inversion_w_tilde.mapped_reconstructed_image.array, 1.0e-4 ) linear_obj_1 = aa.m.MockLinearObjFuncList( @@ -437,7 +495,7 @@ def test__inversion_interferometer__via_mapper( settings=aa.SettingsInversion(use_w_tilde=False), ) - assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangular) + assert isinstance(inversion.linear_obj_list[0], aa.MapperRectangularUniform) assert isinstance(inversion, aa.InversionInterferometerMapping) assert inversion.mapped_reconstructed_data == pytest.approx( 1.0 + 0.0j * np.ones(shape=(7,)), 1.0e-4 @@ -470,19 +528,19 @@ def test__inversion_matrices__x2_mappers( delaunay_mapper_9_3x3, regularization_constant, ): + inversion = aa.Inversion( dataset=masked_imaging_7x7_no_blur, linear_obj_list=[rectangular_mapper_7x7_3x3, delaunay_mapper_9_3x3], + settings=aa.SettingsInversion(use_positive_only_solver=True), ) - assert ( - inversion.operated_mapping_matrix[0:9, 0:9] - == rectangular_mapper_7x7_3x3.mapping_matrix - ).all() - assert ( - inversion.operated_mapping_matrix[0:9, 9:18] - == delaunay_mapper_9_3x3.mapping_matrix - ).all() + assert inversion.operated_mapping_matrix[0:9, 0:9] == pytest.approx( + rectangular_mapper_7x7_3x3.mapping_matrix, abs=1.0e-4 + ) + assert inversion.operated_mapping_matrix[0:9, 9:18] == pytest.approx( + delaunay_mapper_9_3x3.mapping_matrix, abs=1.0e-4 + ) operated_mapping_matrix = np.hstack( [ @@ -518,26 +576,21 @@ def test__inversion_matrices__x2_mappers( assert (inversion.regularization_matrix[0:9, 9:18] == np.zeros((9, 9))).all() assert (inversion.regularization_matrix[9:18, 0:9] == np.zeros((9, 9))).all() - reconstruction_0 = 0.5 * np.ones(9) - reconstruction_1 = 0.5 * np.ones(9) - - assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3] == pytest.approx( - reconstruction_0, 1.0e-4 - ) - assert inversion.reconstruction_dict[delaunay_mapper_9_3x3] == pytest.approx( - reconstruction_1, 1.0e-4 - ) - assert inversion.reconstruction == pytest.approx( - np.concatenate([reconstruction_0, reconstruction_1]), 1.0e-4 + assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][ + 4 + ] == pytest.approx(0.5000029374603968, 1.0e-4) + assert inversion.reconstruction_dict[delaunay_mapper_9_3x3][4] == pytest.approx( + 0.4999970390886761, 1.0e-4 ) + assert inversion.reconstruction[13] == pytest.approx(0.49999703908867, 1.0e-4) - assert inversion.mapped_reconstructed_data_dict[ - rectangular_mapper_7x7_3x3 - ] == pytest.approx(0.5 * np.ones(9), 1.0e-4) - assert inversion.mapped_reconstructed_data_dict[ - delaunay_mapper_9_3x3 - ] == pytest.approx(0.5 * np.ones(9), 1.0e-4) - assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) + assert inversion.mapped_reconstructed_data_dict[rectangular_mapper_7x7_3x3][ + 4 + ] == pytest.approx(0.5000029, 1.0e-4) + assert inversion.mapped_reconstructed_data_dict[delaunay_mapper_9_3x3][ + 3 + ] == pytest.approx(0.49999704, 1.0e-4) + assert inversion.mapped_reconstructed_image[4] == pytest.approx(0.99999998, 1.0e-4) def test__inversion_imaging__positive_only_solver(masked_imaging_7x7_no_blur): @@ -559,3 +612,37 @@ def test__inversion_imaging__positive_only_solver(masked_imaging_7x7_no_blur): assert isinstance(inversion, aa.InversionImagingMapping) assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4) assert inversion.reconstruction == pytest.approx(np.array([2.0]), 1.0e-4) + + +def test__data_linear_func_matrix_dict( + masked_imaging_7x7, + rectangular_mapper_7x7_3x3, +): + + mask = masked_imaging_7x7.mask + + grid = aa.Grid2D.from_mask(mask=mask) + + mapping_matrix = np.full(fill_value=0.5, shape=(9, 2)) + mapping_matrix[0, 0] = 0.8 + mapping_matrix[1, 1] = 0.4 + + linear_obj = aa.m.MockLinearObjFuncList( + parameters=2, grid=grid, mapping_matrix=mapping_matrix + ) + + inversion_mapping = aa.Inversion( + dataset=masked_imaging_7x7, + linear_obj_list=[linear_obj, rectangular_mapper_7x7_3x3], + settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True), + ) + + assert inversion_mapping.data_linear_func_matrix_dict[linear_obj][ + 0 + ] == pytest.approx([0.075, 0.05972222], 1.0e-4) + assert inversion_mapping.data_linear_func_matrix_dict[linear_obj][ + 1 + ] == pytest.approx([0.09166667, 0.07847222], 1.0e-4) + assert inversion_mapping.data_linear_func_matrix_dict[linear_obj][ + 2 + ] == pytest.approx([0.06458333, 0.05972222], 1.0e-4) diff --git a/test_autoarray/inversion/inversion/test_inversion_util.py b/test_autoarray/inversion/inversion/test_inversion_util.py index 86b722812..0b1661f2e 100644 --- a/test_autoarray/inversion/inversion/test_inversion_util.py +++ b/test_autoarray/inversion/inversion/test_inversion_util.py @@ -81,28 +81,11 @@ def test__reconstruction_positive_negative_from(): reconstruction = aa.util.inversion.reconstruction_positive_negative_from( data_vector=data_vector, curvature_reg_matrix=curvature_reg_matrix, - mapper_param_range_list=[[0, 3]], ) assert reconstruction == pytest.approx(np.array([1.0, -1.0, 3.0]), 1.0e-4) -def test__reconstruction_positive_negative_from__check_solution_raises_error_cause_all_values_identical(): - data_vector = np.array([1.0, 1.0, 1.0]) - - curvature_reg_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - - # reconstruction = np.array([1.0, 1.0, 1.0]) - - with pytest.raises(aa.exc.InversionException): - aa.util.inversion.reconstruction_positive_negative_from( - data_vector=data_vector, - curvature_reg_matrix=curvature_reg_matrix, - mapper_param_range_list=[[0, 3]], - force_check_reconstruction=True, - ) - - def test__mapped_reconstructed_data_via_mapping_matrix_from(): mapping_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) @@ -138,7 +121,7 @@ def test__mapped_reconstructed_data_via_image_to_pix_unique_from(): data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=3, pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, pix_sizes_for_sub_slim_index=pix_indexes_for_sub_slim_index_sizes, @@ -149,13 +132,11 @@ def test__mapped_reconstructed_data_via_image_to_pix_unique_from(): reconstruction = np.array([1.0, 1.0, 2.0]) - mapped_reconstructed_data = ( - aa.util.inversion.mapped_reconstructed_data_via_image_to_pix_unique_from( - data_to_pix_unique=data_to_pix_unique.astype("int"), - data_weights=data_weights, - pix_lengths=pix_lengths.astype("int"), - reconstruction=reconstruction, - ) + mapped_reconstructed_data = aa.util.inversion_imaging_numba.mapped_reconstructed_data_via_image_to_pix_unique_from( + data_to_pix_unique=data_to_pix_unique.astype("int"), + data_weights=data_weights, + pix_lengths=pix_lengths.astype("int"), + reconstruction=reconstruction, ) assert (mapped_reconstructed_data == np.array([1.0, 1.0, 2.0])).all() @@ -170,7 +151,7 @@ def test__mapped_reconstructed_data_via_image_to_pix_unique_from(): data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=3, pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, pix_sizes_for_sub_slim_index=pix_indexes_for_sub_slim_index_sizes, @@ -181,13 +162,11 @@ def test__mapped_reconstructed_data_via_image_to_pix_unique_from(): reconstruction = np.array([1.0, 1.0, 2.0]) - mapped_reconstructed_data = ( - aa.util.inversion.mapped_reconstructed_data_via_image_to_pix_unique_from( - data_to_pix_unique=data_to_pix_unique.astype("int"), - data_weights=data_weights, - pix_lengths=pix_lengths.astype("int"), - reconstruction=reconstruction, - ) + mapped_reconstructed_data = aa.util.inversion_imaging_numba.mapped_reconstructed_data_via_image_to_pix_unique_from( + data_to_pix_unique=data_to_pix_unique.astype("int"), + data_weights=data_weights, + pix_lengths=pix_lengths.astype("int"), + reconstruction=reconstruction, ) assert (mapped_reconstructed_data == np.array([1.25, 1.0, 1.75])).all() diff --git a/test_autoarray/inversion/inversion/test_settings_dict.py b/test_autoarray/inversion/inversion/test_settings_dict.py index 6a6f8ca9f..21540bdd3 100644 --- a/test_autoarray/inversion/inversion/test_settings_dict.py +++ b/test_autoarray/inversion/inversion/test_settings_dict.py @@ -17,8 +17,6 @@ def make_settings_dict(): "use_positive_only_solver": False, "positive_only_uses_p_initial": False, "force_edge_pixels_to_zeros": True, - "force_edge_image_pixels_to_zeros": False, - "image_pixels_source_zero": None, "no_regularization_add_to_curvature_diag_value": 1e-08, "use_w_tilde_numpy": False, "use_source_loop": False, diff --git a/test_autoarray/inversion/pixelization/image_mesh/test_overlay.py b/test_autoarray/inversion/pixelization/image_mesh/test_overlay.py index 22536d5c8..54f242236 100644 --- a/test_autoarray/inversion/pixelization/image_mesh/test_overlay.py +++ b/test_autoarray/inversion/pixelization/image_mesh/test_overlay.py @@ -323,13 +323,13 @@ def test__image_plane_mesh_grid_from__simple(): total_pixels = overlay_util.total_pixels_2d_from( mask_2d=mask.array, - overlaid_centres=overlaid_centres, + overlaid_centres=np.array(overlaid_centres), ) overlay_for_mask_2d_util = overlay_util.overlay_for_mask_from( total_pixels=total_pixels, mask=mask.array, - overlaid_centres=overlaid_centres, + overlaid_centres=np.array(overlaid_centres), ).astype("int") image_mesh_util = overlay_util.overlay_via_unmasked_overlaid_from( diff --git a/test_autoarray/inversion/pixelization/mappers/test_abstract.py b/test_autoarray/inversion/pixelization/mappers/test_abstract.py index 7217dc79a..271e03772 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_abstract.py +++ b/test_autoarray/inversion/pixelization/mappers/test_abstract.py @@ -16,7 +16,7 @@ def test__pix_indexes_for_slim_indexes__different_types_of_lists_input(): parameters=9, ) - pixe_indexes_for_slim_indexes = mapper.pix_indexes_for_slim_indexes( + pixe_indexes_for_slim_indexes = mapper.slim_indexes_for_pix_indexes( pix_indexes=[0, 1] ) @@ -31,7 +31,7 @@ def test__pix_indexes_for_slim_indexes__different_types_of_lists_input(): parameters=9, ) - pixe_indexes_for_slim_indexes = mapper.pix_indexes_for_slim_indexes( + pixe_indexes_for_slim_indexes = mapper.slim_indexes_for_pix_indexes( pix_indexes=[[0], [4]] ) @@ -69,38 +69,6 @@ def test__sub_slim_indexes_for_pix_index(): [0, 1, 2, 3, 4, 5, 6, 7], ] - ( - sub_slim_indexes_for_pix_index, - sub_slim_sizes_for_pix_index, - sub_slim_weights_for_pix_index, - ) = mapper.sub_slim_indexes_for_pix_index_arr - - assert ( - sub_slim_indexes_for_pix_index - == np.array( - [ - [0, 3, 6, -1, -1, -1, -1, -1], - [1, 4, -1, -1, -1, -1, -1, -1], - [2, -1, -1, -1, -1, -1, -1, -1], - [5, 7, -1, -1, -1, -1, -1, -1], - [0, 1, 2, 3, 4, 5, 6, 7], - ] - ) - ).all() - assert (sub_slim_sizes_for_pix_index == np.array([3, 2, 1, 2, 8])).all() - assert ( - sub_slim_weights_for_pix_index - == np.array( - [ - [0.1, 0.4, 0.7, -1, -1, -1, -1, -1], - [0.2, 0.5, -1, -1, -1, -1, -1, -1], - [0.3, -1, -1, -1, -1, -1, -1, -1], - [0.6, 0.8, -1, -1, -1, -1, -1, -1], - [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2], - ] - ) - ).all() - def test__data_weight_total_for_pix_from(): mapper = aa.m.MockMapper( @@ -222,8 +190,8 @@ def test__mapped_to_source_from(grid_2d_7x7): ) mapped_to_source_util = aa.util.mapper.mapped_to_source_via_mapping_matrix_from( - mapping_matrix=mapper.mapping_matrix, - array_slim=np.array(array_slim), + mapping_matrix=np.array(mapper.mapping_matrix), + array_slim=array_slim, ) mapped_to_source_mapper = mapper.mapped_to_source_from(array=array_slim) diff --git a/test_autoarray/inversion/pixelization/mappers/test_delaunay.py b/test_autoarray/inversion/pixelization/mappers/test_delaunay.py index 33b03fbb4..070d47598 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_delaunay.py +++ b/test_autoarray/inversion/pixelization/mappers/test_delaunay.py @@ -28,7 +28,7 @@ def test__pix_indexes_for_sub_slim_index__matches_util(grid_2d_sub_1_7x7): ( pix_indexes_for_sub_slim_index_util, sizes, - ) = aa.util.mapper.pix_indexes_for_sub_slim_index_delaunay_from( + ) = aa.util.mapper_numba.pix_indexes_for_sub_slim_index_delaunay_from( source_plane_data_grid=np.array(mapper.source_plane_data_grid), simplex_index_for_sub_slim_index=simplex_index_for_sub_slim_index, pix_indexes_for_simplex_index=pix_indexes_for_simplex_index, diff --git a/test_autoarray/inversion/pixelization/mappers/test_factory.py b/test_autoarray/inversion/pixelization/mappers/test_factory.py index c08bca937..d8f68507d 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_factory.py +++ b/test_autoarray/inversion/pixelization/mappers/test_factory.py @@ -24,7 +24,7 @@ def test__rectangular_mapper(): grid.over_sampled[0, 0] = -2.0 grid.over_sampled[0, 1] = 2.0 - mesh = aa.mesh.Rectangular(shape=(3, 3)) + mesh = aa.mesh.RectangularUniform(shape=(3, 3)) mapper_grids = mesh.mapper_grids_from( mask=mask, @@ -35,25 +35,25 @@ def test__rectangular_mapper(): mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) - assert isinstance(mapper, aa.MapperRectangular) + assert isinstance(mapper, aa.MapperRectangularUniform) assert mapper.image_plane_mesh_grid == None assert mapper.source_plane_mesh_grid.geometry.shape_native_scaled == pytest.approx( (5.0, 5.0), 1.0e-4 ) assert mapper.source_plane_mesh_grid.origin == pytest.approx((0.5, 0.5), 1.0e-4) - assert ( - mapper.mapping_matrix - == np.array( + assert mapper.mapping_matrix == pytest.approx( + np.array( [ - [0.0, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0675, 0.5775, 0.18, 0.0075, -0.065, -0.1425, 0.0, 0.0375, 0.3375], + [0.18, -0.03, 0.0, 0.84, -0.14, 0.0, 0.18, -0.03, 0.0], + [0.0225, 0.105, 0.0225, 0.105, 0.49, 0.105, 0.0225, 0.105, 0.0225], + [0.0, -0.03, 0.18, 0.0, -0.14, 0.84, 0.0, -0.03, 0.18], + [0.0, 0.0, 0.0, -0.03, -0.14, -0.03, 0.18, 0.84, 0.18], ] - ) - ).all() + ), + 1.0e-4, + ) assert mapper.shape_native == (3, 3) diff --git a/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py b/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py index a5b41a15f..2a862ea63 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py +++ b/test_autoarray/inversion/pixelization/mappers/test_mapper_util.py @@ -13,62 +13,6 @@ def make_five_pixels(): return np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]) -def _test__sub_slim_indexes_for_pix_index(): - pix_indexes_for_sub_slim_index = np.array( - [[0, 4], [1, 4], [2, 4], [0, 4], [1, 4], [3, 4], [0, 4], [3, 4]] - ).astype("int") - pix_pixels = 5 - pix_weights_for_sub_slim_index = np.array( - [ - [0.1, 0.9], - [0.2, 0.8], - [0.3, 0.7], - [0.4, 0.6], - [0.5, 0.5], - [0.6, 0.4], - [0.7, 0.3], - [0.8, 0.2], - ] - ) - - ( - sub_slim_indexes_for_pix_index, - sub_slim_sizes_for_pix_index, - sub_slim_weights_for_pix_index, - ) = aa.util.mapper.sub_slim_indexes_for_pix_index( - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - pix_pixels=pix_pixels, - ) - - assert ( - sub_slim_indexes_for_pix_index - == np.array( - [ - [0, 3, 6, -1, -1, -1, -1, -1], - [1, 4, -1, -1, -1, -1, -1, -1], - [2, -1, -1, -1, -1, -1, -1, -1], - [5, 7, -1, -1, -1, -1, -1, -1], - [0, 1, 2, 3, 4, 5, 6, 7], - ] - ) - ).all() - assert (sub_slim_sizes_for_pix_index == np.array([3, 2, 1, 2, 8])).all() - - assert ( - sub_slim_weights_for_pix_index - == np.array( - [ - [0.1, 0.4, 0.7, -1, -1, -1, -1, -1], - [0.2, 0.5, -1, -1, -1, -1, -1, -1], - [0.3, -1, -1, -1, -1, -1, -1, -1], - [0.6, 0.8, -1, -1, -1, -1, -1, -1], - [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2], - ] - ) - ).all() - - def test__mapping_matrix(three_pixels, five_pixels): pix_indexes_for_sub_slim_index = np.array([[0], [1], [2]]) slim_index_for_sub_slim_index = np.array([0, 1, 2]) @@ -334,7 +278,7 @@ def test__data_to_pix_unique_from(): data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=image_pixels, pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, pix_sizes_for_sub_slim_index=pix_size_for_sub_slim_index, @@ -370,7 +314,7 @@ def test__data_to_pix_unique_from(): data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=image_pixels, pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, pix_sizes_for_sub_slim_index=pix_size_for_sub_slim_index, @@ -399,7 +343,7 @@ def test__weights(): pix_indexes_for_sub_slim_index = np.array([[0, 1, 2], [2, -1, -1]]) - pixel_weights = aa.util.mapper.pixel_weights_delaunay_from( + pixel_weights = aa.util.mapper_numba.pixel_weights_delaunay_from( source_plane_data_grid=source_plane_data_grid, source_plane_mesh_grid=source_plane_mesh_grid, slim_index_for_sub_slim_index=slim_index_for_sub_slim_index, diff --git a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py index 026e230df..f80c67cde 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_rectangular.py +++ b/test_autoarray/inversion/pixelization/mappers/test_rectangular.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import autoarray as aa @@ -21,7 +22,7 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): over_sample_size=1, ) - mesh_grid = aa.Mesh2DRectangular.overlay_grid( + mesh_grid = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid.over_sampled ) @@ -31,24 +32,22 @@ def test__pix_indexes_for_sub_slim_index__matches_util(): mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) - pix_indexes_for_sub_slim_index_util = np.array( - [ - aa.util.geometry.grid_pixel_indexes_2d_slim_from( - grid_scaled_2d_slim=np.array(grid.over_sampled), - shape_native=mesh_grid.shape_native, - pixel_scales=mesh_grid.pixel_scales, - origin=mesh_grid.origin, - ).astype("int") - ] - ).T + mappings, weights = ( + aa.util.mapper.rectangular_mappings_weights_via_interpolation_from( + shape_native=(3, 3), + source_plane_mesh_grid=mesh_grid.array, + source_plane_data_grid=aa.Grid2DIrregular( + mapper_grids.source_plane_data_grid.over_sampled + ).array, + ) + ) - assert ( - mapper.pix_indexes_for_sub_slim_index == pix_indexes_for_sub_slim_index_util - ).all() + assert (mapper.pix_sub_weights.mappings == mappings).all() + assert (mapper.pix_sub_weights.weights == weights).all() def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7): - mesh_grid = aa.Mesh2DRectangular.overlay_grid( + mesh_grid = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid_2d_sub_1_7x7.over_sampled ) @@ -70,7 +69,77 @@ def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7): pix_size_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index, pixel_weights=mapper.pix_weights_for_sub_slim_index, slim_index_for_sub_slim_index=grid_2d_sub_1_7x7.over_sampler.slim_for_sub_slim, - adapt_data=np.array(image_7x7), + adapt_data=image_7x7, ) assert (pixel_signals == pixel_signals_util).all() + + +def test__areas_transformed(mask_2d_7x7): + + grid = aa.Grid2DIrregular( + [ + [-1.5, -1.5], + [-1.5, 0.0], + [-1.5, 1.5], + [0.0, -1.5], + [0.0, 0.0], + [0.0, 1.5], + [1.5, -1.5], + [1.5, 0.0], + [1.5, 1.5], + ], + ) + + mesh = aa.Mesh2DRectangularUniform.overlay_grid( + shape_native=(3, 3), grid=grid, buffer=1e-8 + ) + + mapper_grids = aa.MapperGrids( + mask=mask_2d_7x7, + source_plane_data_grid=grid, + source_plane_mesh_grid=mesh, + ) + + mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) + + assert mapper.areas_transformed[4] == pytest.approx( + 4.0, + abs=1e-8, + ) + + +def test__edges_transformed(mask_2d_7x7): + + grid = aa.Grid2DIrregular( + [ + [-1.5, -1.5], + [-1.5, 0.0], + [-1.5, 1.5], + [0.0, -1.5], + [0.0, 0.0], + [0.0, 1.5], + [1.5, -1.5], + [1.5, 0.0], + [1.5, 1.5], + ], + ) + + mesh = aa.Mesh2DRectangularUniform.overlay_grid( + shape_native=(3, 3), grid=grid, buffer=1e-8 + ) + + mapper_grids = aa.MapperGrids( + mask=mask_2d_7x7, + source_plane_data_grid=grid, + source_plane_mesh_grid=mesh, + ) + + mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None) + + assert mapper.edges_transformed[4] == pytest.approx( + np.array( + [1.5, 1.5], # left + ), + abs=1e-8, + ) diff --git a/test_autoarray/inversion/pixelization/mappers/test_voronoi.py b/test_autoarray/inversion/pixelization/mappers/test_voronoi.py index 57e971aae..92bc5995a 100644 --- a/test_autoarray/inversion/pixelization/mappers/test_voronoi.py +++ b/test_autoarray/inversion/pixelization/mappers/test_voronoi.py @@ -35,7 +35,7 @@ def test__pix_indexes_for_sub_slim_index__matches_util(grid_2d_sub_1_7x7): pix_indexes_for_sub_slim_index_util, sizes, weights, - ) = aa.util.mapper.pix_size_weights_voronoi_nn_from( + ) = aa.util.mapper_numba.pix_size_weights_voronoi_nn_from( grid=grid_2d_sub_1_7x7, mesh_grid=source_plane_mesh_grid ) diff --git a/test_autoarray/inversion/pixelization/mesh/test_mesh_util.py b/test_autoarray/inversion/pixelization/mesh/test_mesh_util.py index 25cca246c..b9058dede 100644 --- a/test_autoarray/inversion/pixelization/mesh/test_mesh_util.py +++ b/test_autoarray/inversion/pixelization/mesh/test_mesh_util.py @@ -110,7 +110,7 @@ def test__voronoi_neighbors_from(): points = np.array([[1.0, -1.0], [1.0, 1.0], [0.0, 0.0], [-1.0, -1.0], [-1.0, 1.0]]) voronoi = scipy.spatial.Voronoi(points, qhull_options="Qbb Qc Qx Qm") - (neighbors, neighbors_sizes) = aa.util.mesh.voronoi_neighbors_from( + (neighbors, neighbors_sizes) = aa.util.mesh_numba.voronoi_neighbors_from( pixels=5, ridge_points=np.array(voronoi.ridge_points) ) @@ -139,7 +139,7 @@ def test__voronoi_neighbors_from(): ) voronoi = scipy.spatial.Voronoi(points, qhull_options="Qbb Qc Qx Qm") - (neighbors, neighbors_sizes) = aa.util.mesh.voronoi_neighbors_from( + (neighbors, neighbors_sizes) = aa.util.mesh_numba.voronoi_neighbors_from( pixels=9, ridge_points=np.array(voronoi.ridge_points) ) @@ -171,7 +171,7 @@ def test__delaunay_interpolated_grid_from(): values = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - interpolated_grid = aa.util.mesh.delaunay_interpolated_array_from( + interpolated_grid = aa.util.mesh_numba.delaunay_interpolated_array_from( shape_native=shape_native, interpolation_grid_slim=grid_interpolate_slim, pixel_values=values, diff --git a/test_autoarray/inversion/plot/test_inversion_plotters.py b/test_autoarray/inversion/plot/test_inversion_plotters.py index 735365ecd..62737ec87 100644 --- a/test_autoarray/inversion/plot/test_inversion_plotters.py +++ b/test_autoarray/inversion/plot/test_inversion_plotters.py @@ -25,7 +25,7 @@ def test__individual_attributes_are_output_for_all_mappers( ): inversion_plotter = aplt.InversionPlotter( inversion=rectangular_inversion_7x7_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0], pix_indexes=[1]), + visuals_2d=aplt.Visuals2D(indexes=[0]), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -55,7 +55,7 @@ def test__individual_attributes_are_output_for_all_mappers( inversion_plotter = aplt.InversionPlotter( inversion=voronoi_inversion_9_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0], pix_indexes=[1]), + visuals_2d=aplt.Visuals2D(indexes=[0]), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) @@ -101,7 +101,7 @@ def test__inversion_subplot_of_mapper__is_output_for_all_inversions( ): inversion_plotter = aplt.InversionPlotter( inversion=rectangular_inversion_7x7_3x3, - visuals_2d=aplt.Visuals2D(indexes=[0], pix_indexes=[1]), + visuals_2d=aplt.Visuals2D(indexes=[0]), mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), ) diff --git a/test_autoarray/inversion/plot/test_mapper_plotters.py b/test_autoarray/inversion/plot/test_mapper_plotters.py index 3dabe5514..1b91c3ad3 100644 --- a/test_autoarray/inversion/plot/test_mapper_plotters.py +++ b/test_autoarray/inversion/plot/test_mapper_plotters.py @@ -14,91 +14,6 @@ def make_plot_path_setup(): ) -def test__get_2d__via_mapper_for_data_from(rectangular_mapper_7x7_3x3): - include = aplt.Include2D( - origin=True, mask=True, mapper_image_plane_mesh_grid=True, border=True - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_data_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert get_2d.origin.in_list == [(0.0, 0.0)] - assert (get_2d.mask == rectangular_mapper_7x7_3x3.mapper_grids.mask).all() - assert get_2d.grid == None - - include = aplt.Include2D( - origin=False, mask=False, mapper_image_plane_mesh_grid=False, border=False - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_data_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert get_2d.origin == None - assert get_2d.mask == None - assert get_2d.grid == None - assert get_2d.border == None - - -def test__get_2d__via_mapper_for_source_from(rectangular_mapper_7x7_3x3): - include = aplt.Include2D( - origin=True, - mapper_source_plane_data_grid=True, - mapper_source_plane_mesh_grid=True, - border=True, - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert mapper_plotter.visuals_2d.origin == None - assert get_2d.origin.in_list == [(0.0, 0.0)] - assert ( - get_2d.grid == rectangular_mapper_7x7_3x3.source_plane_data_grid.over_sampled - ).all() - assert (get_2d.mesh_grid == rectangular_mapper_7x7_3x3.source_plane_mesh_grid).all() - border_grid = ( - rectangular_mapper_7x7_3x3.mapper_grids.source_plane_data_grid.over_sampled[ - rectangular_mapper_7x7_3x3.border_relocator.sub_border_slim - ] - ) - assert (get_2d.border == border_grid).all() - - include = aplt.Include2D( - origin=False, - border=False, - mapper_source_plane_data_grid=False, - mapper_source_plane_mesh_grid=False, - ) - - mapper_plotter = aplt.MapperPlotter( - mapper=rectangular_mapper_7x7_3x3, include_2d=include - ) - - get_2d = mapper_plotter.get_2d.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert get_2d.origin == None - assert get_2d.grid == None - assert get_2d.mesh_grid == None - assert get_2d.border == None - - def test__figure_2d( rectangular_mapper_7x7_3x3, delaunay_mapper_9_3x3, @@ -107,20 +22,17 @@ def test__figure_2d( plot_patch, ): visuals_2d = aplt.Visuals2D( - indexes=[[(0, 0), (0, 1)], [(1, 2)]], pix_indexes=[[0, 1], [2]] + indexes=[[(0, 0), (0, 1)], [(1, 2)]], ) mat_plot_2d = aplt.MatPlot2D( output=aplt.Output(path=plot_path, filename="mapper1", format="png") ) - include_2d = aplt.Include2D(origin=True, mapper_source_plane_mesh_grid=True) - mapper_plotter = aplt.MapperPlotter( mapper=rectangular_mapper_7x7_3x3, visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, ) mapper_plotter.figure_2d() @@ -133,10 +45,9 @@ def test__figure_2d( mapper=delaunay_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, ) - mapper_plotter.figure_2d() + mapper_plotter.figure_2d(interpolate_to_uniform=True) assert path.join(plot_path, "mapper1.png") in plot_patch.paths @@ -151,10 +62,9 @@ def test__figure_2d( mapper=voronoi_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d, - include_2d=include_2d, ) - mapper_plotter.figure_2d() + mapper_plotter.figure_2d(interpolate_to_uniform=True) assert path.join(plot_path, "mapper1.png") in plot_patch.paths @@ -167,16 +77,17 @@ def test__subplot_image_and_mapper( plot_path, plot_patch, ): - visuals_2d = aplt.Visuals2D(indexes=[0, 1, 2], pix_indexes=[[0, 1], [2]]) + visuals_2d = aplt.Visuals2D(indexes=[0, 1, 2]) mapper_plotter = aplt.MapperPlotter( mapper=rectangular_mapper_7x7_3x3, visuals_2d=visuals_2d, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - include_2d=aplt.Include2D(mapper_source_plane_mesh_grid=True), ) - mapper_plotter.subplot_image_and_mapper(image=imaging_7x7.data) + mapper_plotter.subplot_image_and_mapper( + image=imaging_7x7.data, interpolate_to_uniform=True + ) assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths plot_patch.paths = [] @@ -185,10 +96,11 @@ def test__subplot_image_and_mapper( mapper=delaunay_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - include_2d=aplt.Include2D(mapper_source_plane_mesh_grid=True), ) - mapper_plotter.subplot_image_and_mapper(image=imaging_7x7.data) + mapper_plotter.subplot_image_and_mapper( + image=imaging_7x7.data, interpolate_to_uniform=True + ) assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths pytest.importorskip( @@ -202,8 +114,9 @@ def test__subplot_image_and_mapper( mapper=voronoi_mapper_9_3x3, visuals_2d=visuals_2d, mat_plot_2d=aplt.MatPlot2D(output=aplt.Output(path=plot_path, format="png")), - include_2d=aplt.Include2D(mapper_source_plane_mesh_grid=True), ) - mapper_plotter.subplot_image_and_mapper(image=imaging_7x7.data) + mapper_plotter.subplot_image_and_mapper( + image=imaging_7x7.data, interpolate_to_uniform=True + ) assert path.join(plot_path, "subplot_image_and_mapper.png") in plot_patch.paths diff --git a/test_autoarray/inversion/regularizations/test_adaptive_brightness.py b/test_autoarray/inversion/regularizations/test_adaptive_brightness.py index b808682b7..b3cf4f132 100644 --- a/test_autoarray/inversion/regularizations/test_adaptive_brightness.py +++ b/test_autoarray/inversion/regularizations/test_adaptive_brightness.py @@ -55,7 +55,6 @@ def test__regularization_matrix__matches_util(): aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) ) diff --git a/test_autoarray/inversion/regularizations/test_regularization_util.py b/test_autoarray/inversion/regularizations/test_regularization_util.py index e9af395d4..05a4bd0d4 100644 --- a/test_autoarray/inversion/regularizations/test_regularization_util.py +++ b/test_autoarray/inversion/regularizations/test_regularization_util.py @@ -227,8 +227,6 @@ def test__brightness_zeroth_regularization_weights_from(): def test__weighted_regularization_matrix_from(): neighbors = np.array([[2], [3], [0], [1]]) - neighbors_sizes = np.array([1, 1, 1, 1]) - b_matrix = np.array([[-1, 0, 1, 0], [0, -1, 0, 1], [1, 0, -1, 0], [0, 1, 0, -1]]) test_regularization_matrix = np.matmul(b_matrix.T, b_matrix) @@ -238,7 +236,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -251,8 +248,6 @@ def test__weighted_regularization_matrix_from(): neighbors = np.array([[1, 2], [0, -1], [0, -1]]) - neighbors_sizes = np.array([2, 1, 1]) - b_matrix_1 = np.array( [[-1, 1, 0], [-1, 0, 1], [1, -1, 0]] # Pair 1 # Pair 2 ) # Pair 1 flip @@ -272,7 +267,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -291,14 +285,11 @@ def test__weighted_regularization_matrix_from(): neighbors = np.array([[1, 3], [0, 2], [1, 3], [0, 2]]) - neighbors_sizes = np.array([2, 2, 2, 2]) - regularization_weights = np.ones((4,)) regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -383,7 +374,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -415,12 +405,9 @@ def test__weighted_regularization_matrix_from(): [[1, 2, -1, -1], [0, 2, 3, -1], [0, 1, -1, -1], [1, -1, -1, -1]] ) - neighbors_sizes = np.array([2, 3, 2, 1]) - regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) @@ -436,7 +423,6 @@ def test__weighted_regularization_matrix_from(): ] ) - neighbors_sizes = np.array([2, 3, 4, 2, 4, 3]) regularization_weights = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) # I'm inputting the regularization weight_list directly thiss time, as it'd be a pain to multiply with a @@ -503,7 +489,6 @@ def test__weighted_regularization_matrix_from(): regularization_matrix = aa.util.regularization.weighted_regularization_matrix_from( regularization_weights=regularization_weights, neighbors=neighbors, - neighbors_sizes=neighbors_sizes, ) assert regularization_matrix == pytest.approx(test_regularization_matrix, 1.0e-4) diff --git a/test_autoarray/inversion/test_linear_obj.py b/test_autoarray/inversion/test_linear_obj.py index c54204350..f906c2989 100644 --- a/test_autoarray/inversion/test_linear_obj.py +++ b/test_autoarray/inversion/test_linear_obj.py @@ -27,7 +27,7 @@ def test__data_to_pix_unique_from(): data_to_pix_unique, data_weights, pix_lengths, - ) = aa.util.mapper.data_slim_to_pixelization_unique_from( + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( data_pixels=image_pixels, pix_indexes_for_sub_slim_index=pix_index_for_sub_slim_index, pix_sizes_for_sub_slim_index=pix_sizes_for_sub_slim_index, diff --git a/test_autoarray/layout/test_region.py b/test_autoarray/layout/test_region.py index 643b8532a..690cfa9bb 100644 --- a/test_autoarray/layout/test_region.py +++ b/test_autoarray/layout/test_region.py @@ -152,7 +152,7 @@ def test__slice_2d__addition(): image = np.ones((2, 2)) region = aa.Region2D(region=(0, 1, 0, 1)) - array = array.at[region.slice].add(image[region.slice]) + array[region.slice] += image[region.slice] assert (array == np.array([[1.0, 0.0], [0.0, 0.0]])).all() @@ -161,7 +161,7 @@ def test__slice_2d__addition(): image = np.ones((2, 2)) region = aa.Region2D(region=(0, 1, 0, 1)) - array = array.at[region.slice].add(image[region.slice]) + array[region.slice] += image[region.slice] assert (array == np.array([[2.0, 1.0], [1.0, 1.0]])).all() @@ -170,7 +170,7 @@ def test__slice_2d__addition(): image = np.ones((3, 3)) region = aa.Region2D(region=(1, 3, 2, 3)) - array = array.at[region.slice].add(image[region.slice]) + array[region.slice] += image[region.slice] assert ( array == np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [1.0, 1.0, 2.0]]) @@ -183,7 +183,7 @@ def test__slice_2d__set_to_zerose(): region = aa.Region2D(region=(0, 1, 0, 1)) - array = array.at[region.slice].set(0) + array[region.slice] = 0 assert (array == np.array([[0.0, 1.0], [1.0, 1.0]])).all() @@ -192,7 +192,7 @@ def test__slice_2d__set_to_zerose(): region = aa.Region2D(region=(1, 3, 2, 3)) - array = array.at[region.slice].set(0) + array[region.slice] = 0 assert ( array == np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) diff --git a/test_autoarray/mask/derive/test_zoom_2d.py b/test_autoarray/mask/derive/test_zoom_2d.py new file mode 100644 index 000000000..0578138a1 --- /dev/null +++ b/test_autoarray/mask/derive/test_zoom_2d.py @@ -0,0 +1,213 @@ +import numpy as np + +import autoarray as aa + + +def test__quantities(): + mask = aa.Mask2D.all_false(shape_native=(4, 6), pixel_scales=(1.0, 1.0)) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (1.5, 2.5) + assert zoom.offset_pixels == (0, 0) + assert zoom.shape_native == (6, 6) + + mask = aa.Mask2D.all_false(shape_native=(6, 4), pixel_scales=(1.0, 1.0)) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (2.5, 1.5) + assert zoom.offset_pixels == (0, 0) + assert zoom.shape_native == (6, 6) + + mask = aa.Mask2D( + mask=np.array([[True, True, True], [True, True, False], [True, True, True]]), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (1, 2) + assert zoom.offset_pixels == (0, 1) + assert zoom.shape_native == (1, 1) + + mask = aa.Mask2D( + mask=np.array([[True, True, True], [True, True, True], [True, False, True]]), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (2, 1) + assert zoom.offset_pixels == (1, 0) + assert zoom.shape_native == (1, 1) + + mask = aa.Mask2D( + mask=np.array([[False, True, False], [True, True, True], [True, True, True]]), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (0, 1) + assert zoom.offset_pixels == (-1, 0) + assert zoom.shape_native == (3, 3) + + mask = aa.Mask2D( + mask=np.array([[False, False, True], [True, True, True], [True, True, True]]), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (0, 0.5) + assert zoom.offset_pixels == (-1, -0.5) + assert zoom.shape_native == (1, 2) + + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True, True, True, True, True], + [True, True, True, True, True, True, True], + [True, True, True, True, True, True, False], + ] + ), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (2, 6) + assert zoom.offset_pixels == (1, 3) + + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, False], + ] + ), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (4, 2) + assert zoom.offset_pixels == (2, 1) + + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, True], + [True, True, False], + ] + ), + pixel_scales=(1.0, 1.0), + ) + zoom = aa.Zoom2D(mask=mask) + + assert zoom.centre == (6, 2) + assert zoom.offset_pixels == (3, 1) + + +def test__array_2d_from(): + array_2d = [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + + mask = aa.Mask2D( + mask=[ + [True, True, True, True], + [True, False, False, True], + [True, False, False, True], + [True, True, True, True], + ], + pixel_scales=(1.0, 1.0), + ) + + arr = aa.Array2D(values=array_2d, mask=mask) + zoom = aa.Zoom2D(mask=mask) + arr_zoomed = zoom.array_2d_from(array=arr, buffer=0) + + assert (arr_zoomed.native == np.array([[6.0, 7.0], [10.0, 11.0]])).all() + + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True, True], + [True, False, False, True], + [False, False, False, True], + [True, True, True, True], + ] + ), + pixel_scales=(1.0, 1.0), + ) + + arr = aa.Array2D(values=array_2d, mask=mask) + zoom = aa.Zoom2D(mask=mask) + arr_zoomed = zoom.array_2d_from(array=arr, buffer=0) + + assert (arr_zoomed.native == np.array([[0.0, 6.0, 7.0], [9.0, 10.0, 11.0]])).all() + + mask = aa.Mask2D( + mask=np.array( + [ + [True, False, True, True], + [True, False, False, True], + [True, False, False, True], + [True, True, True, True], + ] + ), + pixel_scales=(1.0, 1.0), + ) + + arr = aa.Array2D(values=array_2d, mask=mask) + zoom = aa.Zoom2D(mask=mask) + arr_zoomed = zoom.array_2d_from(array=arr, buffer=0) + + assert (arr_zoomed.native == np.array([[2.0, 0.0], [6.0, 7.0], [10.0, 11.0]])).all() + + array_2d = np.ones(shape=(4, 4)) + + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True, True], + [True, False, False, True], + [True, False, False, True], + [True, True, True, True], + ] + ), + pixel_scales=(1.0, 1.0), + ) + + arr = aa.Array2D(values=array_2d, mask=mask) + zoom = aa.Zoom2D(mask=mask) + arr_zoomed = zoom.array_2d_from(array=arr, buffer=0) + + assert arr_zoomed.mask.origin == (0.0, 0.0) + + array_2d = np.ones(shape=(6, 6)) + + mask = aa.Mask2D( + mask=np.array( + [ + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, False, False, True], + [True, True, True, False, False, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ] + ), + pixel_scales=(1.0, 1.0), + ) + + arr = aa.Array2D(values=array_2d, mask=mask) + zoom = aa.Zoom2D(mask=mask) + arr_zoomed = zoom.array_2d_from(array=arr, buffer=0) + + assert arr_zoomed.mask.origin == (0.0, 1.0) diff --git a/test_autoarray/mask/test_mask_1d_util.py b/test_autoarray/mask/test_mask_1d_util.py index c8251471f..1b032c37f 100644 --- a/test_autoarray/mask/test_mask_1d_util.py +++ b/test_autoarray/mask/test_mask_1d_util.py @@ -1,14 +1,6 @@ -from autoarray import exc from autoarray import util import numpy as np -import pytest - - -def test__total_image_pixels_1d_from(): - mask_1d = np.array([False, True, False, False, False, True]) - - assert util.mask_1d.total_pixels_1d_from(mask_1d=mask_1d) == 4 def test__native_index_for_slim_index_1d_from(): diff --git a/test_autoarray/mask/test_mask_2d.py b/test_autoarray/mask/test_mask_2d.py index 399510a0f..ed5f35fe0 100644 --- a/test_autoarray/mask/test_mask_2d.py +++ b/test_autoarray/mask/test_mask_2d.py @@ -490,256 +490,6 @@ def test__resized_from(): assert (mask_resized == mask_resized_manual).all() -def test__zoom_quantities(): - mask = aa.Mask2D.all_false(shape_native=(3, 5), pixel_scales=(1.0, 1.0)) - assert mask.zoom_centre == (1.0, 2.0) - assert mask.zoom_offset_pixels == (0, 0) - assert mask.zoom_shape_native == (5, 5) - - mask = aa.Mask2D.all_false(shape_native=(5, 3), pixel_scales=(1.0, 1.0)) - assert mask.zoom_centre == (2.0, 1.0) - assert mask.zoom_offset_pixels == (0, 0) - assert mask.zoom_shape_native == (5, 5) - - mask = aa.Mask2D.all_false(shape_native=(4, 6), pixel_scales=(1.0, 1.0)) - assert mask.zoom_centre == (1.5, 2.5) - assert mask.zoom_offset_pixels == (0, 0) - assert mask.zoom_shape_native == (6, 6) - - mask = aa.Mask2D.all_false(shape_native=(6, 4), pixel_scales=(1.0, 1.0)) - assert mask.zoom_centre == (2.5, 1.5) - assert mask.zoom_offset_pixels == (0, 0) - assert mask.zoom_shape_native == (6, 6) - - -def test__mask_is_single_false__extraction_centre_is_central_pixel(): - mask = aa.Mask2D( - mask=np.array([[False, True, True], [True, True, True], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (0, 0) - assert mask.zoom_offset_pixels == (-1, -1) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, True, False], [True, True, True], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (0, 2) - assert mask.zoom_offset_pixels == (-1, 1) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, True, True], [True, True, True], [False, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (2, 0) - assert mask.zoom_offset_pixels == (1, -1) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, True, True], [True, True, True], [True, True, False]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (2, 2) - assert mask.zoom_offset_pixels == (1, 1) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, False, True], [True, True, True], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (0, 1) - assert mask.zoom_offset_pixels == (-1, 0) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, True, True], [False, True, True], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (1, 0) - assert mask.zoom_offset_pixels == (0, -1) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, True, True], [True, True, False], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (1, 2) - assert mask.zoom_offset_pixels == (0, 1) - assert mask.zoom_shape_native == (1, 1) - - mask = aa.Mask2D( - mask=np.array([[True, True, True], [True, True, True], [True, False, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (2, 1) - assert mask.zoom_offset_pixels == (1, 0) - assert mask.zoom_shape_native == (1, 1) - - -def test__mask_is_x2_false__extraction_centre_is_central_pixel(): - mask = aa.Mask2D( - mask=np.array([[False, True, True], [True, True, True], [True, True, False]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (1, 1) - assert mask.zoom_offset_pixels == (0, 0) - assert mask.zoom_shape_native == (3, 3) - - mask = aa.Mask2D( - mask=np.array([[False, True, True], [True, True, True], [False, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (1, 0) - assert mask.zoom_offset_pixels == (0, -1) - assert mask.zoom_shape_native == (3, 3) - - mask = aa.Mask2D( - mask=np.array([[False, True, False], [True, True, True], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (0, 1) - assert mask.zoom_offset_pixels == (-1, 0) - assert mask.zoom_shape_native == (3, 3) - - mask = aa.Mask2D( - mask=np.array([[False, False, True], [True, True, True], [True, True, True]]), - pixel_scales=(1.0, 1.0), - ) - assert mask.zoom_centre == (0, 0.5) - assert mask.zoom_offset_pixels == (-1, -0.5) - assert mask.zoom_shape_native == (1, 2) - - -def test__rectangular_mask(): - mask = aa.Mask2D( - mask=np.array( - [ - [False, True, True, True], - [True, True, True, True], - [True, True, True, True], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - assert mask.zoom_centre == (0, 0) - assert mask.zoom_offset_pixels == (-1.0, -1.5) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True], - [True, True, True, True], - [True, True, True, False], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - assert mask.zoom_centre == (2, 3) - assert mask.zoom_offset_pixels == (1.0, 1.5) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True, True], - [True, True, True, True, True], - [True, True, True, True, False], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - assert mask.zoom_centre == (2, 4) - assert mask.zoom_offset_pixels == (1, 2) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, True], - [True, True, True, True, True, True, False], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - assert mask.zoom_centre == (2, 6) - assert mask.zoom_offset_pixels == (1, 3) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True], - [True, True, True], - [True, True, True], - [True, True, True], - [True, True, False], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - assert mask.zoom_centre == (4, 2) - assert mask.zoom_offset_pixels == (2, 1) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True], - [True, True, True], - [True, True, True], - [True, True, True], - [True, True, True], - [True, True, True], - [True, True, False], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - assert mask.zoom_centre == (6, 2) - assert mask.zoom_offset_pixels == (3, 1) - - -def test__zoom_mask_unmasked(): - mask = aa.Mask2D( - mask=np.array( - [ - [False, True, True, True], - [True, False, True, True], - [True, True, True, True], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - zoom_mask = mask.zoom_mask_unmasked - - assert (zoom_mask == np.array([[False, False], [False, False]])).all() - assert zoom_mask.origin == (0.5, -1.0) - - mask = aa.Mask2D( - mask=np.array( - [ - [False, True, True, True], - [True, False, True, True], - [True, False, True, True], - ] - ), - pixel_scales=(1.0, 2.0), - ) - - zoom_mask = mask.zoom_mask_unmasked - - assert ( - zoom_mask == np.array([[False, False], [False, False], [False, False]]) - ).all() - assert zoom_mask.origin == (0.0, -2.0) - - def test__mask_centre(): mask = np.array( [ diff --git a/test_autoarray/mask/test_mask_2d_util.py b/test_autoarray/mask/test_mask_2d_util.py index f3db2938e..ee79a9543 100644 --- a/test_autoarray/mask/test_mask_2d_util.py +++ b/test_autoarray/mask/test_mask_2d_util.py @@ -738,8 +738,6 @@ def test__edge_1d_indexes_from(): edge_pixels = util.mask_2d.edge_1d_indexes_from(mask_2d=mask) - print(edge_pixels) - assert (edge_pixels == np.array([0])).all() mask = np.array( diff --git a/test_autoarray/operators/over_sample/test_decorator.py b/test_autoarray/operators/over_sample/test_decorator.py index 47dae1e88..72af0215a 100644 --- a/test_autoarray/operators/over_sample/test_decorator.py +++ b/test_autoarray/operators/over_sample/test_decorator.py @@ -28,9 +28,23 @@ def test__in_grid_2d__over_sample_uniform__out_ndarray_1d(): over_sample_uniform = aa.OverSampler(mask=mask, sub_size=2) - mask_sub_2 = aa.util.over_sample.oversample_mask_2d_from( - mask=np.array(mask), sub_size=2 - ) + def oversample_mask_2d_from(mask: np.ndarray, sub_size: int) -> np.ndarray: + + oversample_mask = np.full( + (mask.shape[0] * sub_size, mask.shape[1] * sub_size), True + ) + + for y in range(mask.shape[0]): + for x in range(mask.shape[1]): + if not mask[y, x]: + oversample_mask[ + y * sub_size : (y + 1) * sub_size, + x * sub_size : (x + 1) * sub_size, + ] = False + + return oversample_mask + + mask_sub_2 = oversample_mask_2d_from(mask=np.array(mask), sub_size=2) mask_sub_2 = aa.Mask2D(mask=mask_sub_2, pixel_scales=(0.5, 0.5)) diff --git a/test_autoarray/operators/over_sample/test_over_sample_util.py b/test_autoarray/operators/over_sample/test_over_sample_util.py index c0552d8a0..fb68da898 100644 --- a/test_autoarray/operators/over_sample/test_over_sample_util.py +++ b/test_autoarray/operators/over_sample/test_over_sample_util.py @@ -12,77 +12,6 @@ def test__total_sub_pixels_2d_from(): ) -def test__native_sub_index_for_slim_sub_index_2d_from(): - mask = np.array([[True, True, True], [True, False, True], [True, True, True]]) - - sub_mask_index_for_sub_mask_1d_index = ( - util.over_sample.native_sub_index_for_slim_sub_index_2d_from( - mask_2d=mask, sub_size=np.array([2]) - ) - ) - - assert ( - sub_mask_index_for_sub_mask_1d_index - == np.array([[2, 2], [2, 3], [3, 2], [3, 3]]) - ).all() - - mask = np.array([[True, False, True], [False, False, False], [True, False, True]]) - - sub_mask_index_for_sub_mask_1d_index = ( - util.over_sample.native_sub_index_for_slim_sub_index_2d_from( - mask_2d=mask, sub_size=np.array([2, 2, 2, 2, 2]) - ) - ) - - assert ( - sub_mask_index_for_sub_mask_1d_index - == np.array( - [ - [0, 2], - [0, 3], - [1, 2], - [1, 3], - [2, 0], - [2, 1], - [3, 0], - [3, 1], - [2, 2], - [2, 3], - [3, 2], - [3, 3], - [2, 4], - [2, 5], - [3, 4], - [3, 5], - [4, 2], - [4, 3], - [5, 2], - [5, 3], - ] - ) - ).all() - - mask = np.array( - [ - [True, True, True], - [True, False, True], - [True, True, True], - [True, True, False], - ] - ) - - sub_mask_index_for_sub_mask_1d_index = ( - util.over_sample.native_sub_index_for_slim_sub_index_2d_from( - mask_2d=mask, sub_size=np.array([2, 2]) - ) - ) - - assert ( - sub_mask_index_for_sub_mask_1d_index - == np.array([[2, 2], [2, 3], [3, 2], [3, 3], [6, 4], [6, 5], [7, 4], [7, 5]]) - ).all() - - def test__slim_index_for_sub_slim_index_via_mask_2d_from(): mask = np.array([[True, True, True], [True, False, True], [True, True, True]]) @@ -150,119 +79,11 @@ def test__slim_index_for_sub_slim_index_via_mask_2d_from(): ).all() -def test__sub_slim_index_for_sub_native_index_from(): - mask = np.full(fill_value=False, shape=(3, 3)) - - sub_mask_1d_index_for_sub_mask_index = ( - util.over_sample.sub_slim_index_for_sub_native_index_from(sub_mask_2d=mask) - ) - - assert ( - sub_mask_1d_index_for_sub_mask_index - == np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - ).all() - - mask = np.full(fill_value=False, shape=(2, 3)) - - sub_mask_1d_index_for_sub_mask_index = ( - util.over_sample.sub_slim_index_for_sub_native_index_from(sub_mask_2d=mask) - ) - - assert ( - sub_mask_1d_index_for_sub_mask_index == np.array([[0, 1, 2], [3, 4, 5]]) - ).all() - - mask = np.full(fill_value=False, shape=(3, 2)) - - sub_mask_1d_index_for_sub_mask_index = ( - util.over_sample.sub_slim_index_for_sub_native_index_from(sub_mask_2d=mask) - ) - - assert ( - sub_mask_1d_index_for_sub_mask_index == np.array([[0, 1], [2, 3], [4, 5]]) - ).all() - - mask = np.array([[False, True, False], [True, True, False], [False, False, True]]) - - sub_mask_1d_index_for_sub_mask_index = ( - util.over_sample.sub_slim_index_for_sub_native_index_from(sub_mask_2d=mask) - ) - - assert ( - sub_mask_1d_index_for_sub_mask_index - == np.array([[0, -1, 1], [-1, -1, 2], [3, 4, -1]]) - ).all() - - mask = np.array( - [ - [False, True, True, False], - [True, True, False, False], - [False, False, True, False], - ] - ) - - sub_mask_1d_index_for_sub_mask_index = ( - util.over_sample.sub_slim_index_for_sub_native_index_from(sub_mask_2d=mask) - ) - - assert ( - sub_mask_1d_index_for_sub_mask_index - == np.array([[0, -1, -1, 1], [-1, -1, 2, 3], [4, 5, -1, 6]]) - ).all() - - mask = np.array( - [ - [False, True, False], - [True, True, False], - [False, False, True], - [False, False, True], - ] - ) - - sub_mask_1d_index_for_sub_mask_index = ( - util.over_sample.sub_slim_index_for_sub_native_index_from(sub_mask_2d=mask) - ) - - assert ( - sub_mask_1d_index_for_sub_mask_index - == np.array([[0, -1, 1], [-1, -1, 2], [3, 4, -1], [5, 6, -1]]) - ).all() - - -def test__oversample_mask_from(): - mask = np.array( - [ - [True, True, True, True], - [True, False, False, True], - [True, False, False, True], - [True, True, True, True], - ] - ) - - oversample_mask = util.over_sample.oversample_mask_2d_from(mask=mask, sub_size=2) - - assert ( - oversample_mask - == np.array( - [ - [True, True, True, True, True, True, True, True], - [True, True, True, True, True, True, True, True], - [True, True, False, False, False, False, True, True], - [True, True, False, False, False, False, True, True], - [True, True, False, False, False, False, True, True], - [True, True, False, False, False, False, True, True], - [True, True, True, True, True, True, True, True], - [True, True, True, True, True, True, True, True], - ] - ) - ).all() - - def test__grid_2d_slim_over_sampled_via_mask_from(): mask = np.array([[True, True, False], [False, False, False], [True, True, False]]) grid = aa.util.over_sample.grid_2d_slim_over_sampled_via_mask_from( - mask_2d=mask, pixel_scales=(3.0, 3.0), sub_size=np.array([2, 2, 2, 2, 2]) + mask_2d=mask, pixel_scales=(3.0, 3.0), sub_size=2 ) assert ( diff --git a/test_autoarray/operators/over_sample/test_over_sampler.py b/test_autoarray/operators/over_sample/test_over_sampler.py index a32b11e8f..e7d601614 100644 --- a/test_autoarray/operators/over_sample/test_over_sampler.py +++ b/test_autoarray/operators/over_sample/test_over_sampler.py @@ -91,26 +91,6 @@ def test__binned_array_2d_from(): assert binned_array_2d.slim == pytest.approx(np.array([1.0, 8.0]), 1.0e-4) -def test__sub_mask_index_for_sub_mask_1d_index(): - mask = aa.Mask2D( - mask=[[True, True, True], [True, False, False], [True, True, False]], - pixel_scales=1.0, - sub_size=2, - ) - - over_sampling = aa.OverSampler(mask=mask, sub_size=2) - - sub_mask_index_for_sub_mask_1d_index = ( - aa.util.over_sample.native_sub_index_for_slim_sub_index_2d_from( - mask_2d=np.array(mask), sub_size=np.array([2, 2, 2]) - ) - ) - - assert over_sampling.sub_mask_native_for_sub_mask_slim == pytest.approx( - sub_mask_index_for_sub_mask_1d_index, 1e-4 - ) - - def test__slim_index_for_sub_slim_index(): mask = aa.Mask2D( mask=[[True, False, True], [False, False, False], [True, False, False]], diff --git a/test_autoarray/operators/test_transformer.py b/test_autoarray/operators/test_transformer.py index 33ea1ca42..4d432cf6b 100644 --- a/test_autoarray/operators/test_transformer.py +++ b/test_autoarray/operators/test_transformer.py @@ -4,259 +4,133 @@ import pytest -class MockDeriveMask2D: - def __init__(self, grid): - self.mask = grid.derive_mask.all_false - self.grid = grid - - @property - def sub_1(self): - return self - - @property - def derive_grid(self): - return MockDeriveGrid2D( - grid=self.grid, - ) - - -class MockDeriveGrid2D: - def __init__(self, grid): - self.unmasked = MockMaskedGrid(grid=grid) - - -class MockRealSpaceMask: - def __init__(self, grid): - self.grid = grid - self.unmasked = MockMaskedGrid(grid=grid) - - @property - def pixels_in_mask(self): - return self.unmasked.slim.in_radians.shape[0] - - @property - def derive_mask(self): - return MockDeriveMask2D( - grid=self.grid, - ) - - @property - def derive_grid(self): - return MockDeriveGrid2D( - grid=self.grid, - ) - - @property - def pixel_scales(self): - return self.grid.pixel_scales - - @property - def origin(self): - return self.grid.origin - - -class MockMaskedGrid: - def __init__(self, grid): - self.in_radians = grid - self.slim = grid - - -def test__dft__visibilities_from(): - uv_wavelengths = np.ones(shape=(4, 2)) - - grid_radians = aa.Grid2D.no_mask(values=[[[1.0, 1.0]]], pixel_scales=1.0) - - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__visibilities_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - image = aa.Array2D.ones(shape_native=(1, 1), pixel_scales=1.0) - - visibilities = transformer.visibilities_from(image=image) - - assert visibilities == pytest.approx( - np.array([1.0 + 0.0j, 1.0 + 0.0j, 1.0 + 0.0j, 1.0 + 0.0j]), 1.0e-4 - ) - - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - - real_space_mask = MockRealSpaceMask(grid=grid_radians) - - transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, + image = aa.Array2D( + values=[ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + mask=mask_2d_7x7, ) - image = aa.Array2D.ones(shape_native=(1, 2), pixel_scales=1.0) - visibilities = transformer.visibilities_from(image=image) - assert visibilities == pytest.approx( + assert visibilities[0:3] == pytest.approx( np.array( - [-0.091544 - 1.45506j, -0.73359736 - 0.781201j, -0.613160 - 0.077460j] + [ + -0.06434514 - 0.61763293j, + 1.71143349 - 1.184022j, + 0.90200541 + 0.03726693j, + ] ), 1.0e-4, ) - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__image_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - image = aa.Array2D.no_mask([[3.0, 6.0]], pixel_scales=1.0) - - visibilities = transformer.visibilities_from(image=image) + image = transformer.image_from(visibilities=visibilities_7) - assert visibilities == pytest.approx( - np.array([-2.46153 - 6.418822j, -5.14765 - 1.78146j, -3.11681 + 2.48210j]), - 1.0e-4, - ) + assert image[0:3] == pytest.approx([-1.49022481, -0.22395855, -0.45588535], 1.0e-4) -def test__dft__visibilities_from__preload_and_non_preload_give_same_answer(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__visibilities_from__preload_and_non_preload_give_same_answer( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): transformer_preload = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=True, ) transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - image = aa.Array2D.no_mask([[2.0, 6.0]], pixel_scales=1.0) + image = aa.Array2D( + values=[ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + mask=mask_2d_7x7, + ) visibilities_via_preload = transformer_preload.visibilities_from(image=image) visibilities = transformer.visibilities_from(image=image) - assert (visibilities_via_preload == visibilities).all() + assert visibilities_via_preload == pytest.approx(visibilities.array, 1.0e-4) -def test__dft__transform_mapping_matrix(): - uv_wavelengths = np.ones(shape=(4, 2)) - grid_radians = aa.Grid2D.no_mask(values=[[[1.0, 1.0]]], pixel_scales=1.0) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__transform_mapping_matrix( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - mapping_matrix = np.ones(shape=(1, 1)) + mapping_matrix = np.ones(shape=(9, 1)) transformed_mapping_matrix = transformer.transform_mapping_matrix( mapping_matrix=mapping_matrix ) - assert transformed_mapping_matrix == pytest.approx( - np.array([[1.0 + 0.0j], [1.0 + 0.0j], [1.0 + 0.0j], [1.0 + 0.0j]]), 1.0e-4 - ) - - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) - - transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - mapping_matrix = np.ones(shape=(2, 2)) - - transformed_mapping_matrix = transformer.transform_mapping_matrix( - mapping_matrix=mapping_matrix - ) - - assert transformed_mapping_matrix == pytest.approx( - np.array( - [ - [-0.091544 - 1.45506j, -0.091544 - 1.45506j], - [-0.733597 - 0.78120j, -0.733597 - 0.78120j], - [-0.61316 - 0.07746j, -0.61316 - 0.07746j], - ] - ), - 1.0e-4, - ) - - grid_radians = aa.Grid2D.no_mask( - [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) - - uv_wavelengths = np.array([[0.7, 0.8], [0.9, 1.0]]) - - transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - mapping_matrix = np.array([[0.0, 0.5], [0.0, 0.2], [1.0, 0.0]]) - - transformed_mapping_matrix = transformer.transform_mapping_matrix( - mapping_matrix=mapping_matrix - ) - - assert transformed_mapping_matrix == pytest.approx( + assert transformed_mapping_matrix[0:3, :] == pytest.approx( np.array( [ - [0.42577 + 0.90482j, -0.10473 - 0.46607j], - [0.968583 - 0.24868j, -0.20085 - 0.32227j], + [1.48496084 + 0.00000000e00j], + [3.02988906 + 4.44089210e-16], + [0.86395556 + 0.00000000e00], ] ), - 1.0e-4, + abs=1.0e-4, ) -def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_answer(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.no_mask( - values=[[[0.1, 0.2], [0.3, 0.4]]], pixel_scales=1.0 - ) - real_space_mask = MockRealSpaceMask(grid=grid_radians) +def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_answer( + visibilities_7, uv_wavelengths_7x2, mask_2d_7x7 +): transformer_preload = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=True, ) transformer = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, preload_transform=False, ) - mapping_matrix = np.array([[3.0, 5.0], [1.0, 2.0]]) + mapping_matrix = np.ones(shape=(9, 1)) transformed_mapping_matrix_preload = transformer_preload.transform_mapping_matrix( mapping_matrix=mapping_matrix @@ -270,53 +144,40 @@ def test__dft__transformed_mapping_matrix__preload_and_non_preload_give_same_ans def test__nufft__visibilities_from(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) - grid_radians = aa.Grid2D.uniform(shape_native=(5, 5), pixel_scales=0.005).in_radians - real_space_mask = MockRealSpaceMask(grid=grid_radians) + uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) + real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) image = aa.Array2D.ones( - shape_native=grid_radians.shape_native, - pixel_scales=grid_radians.pixel_scales, + shape_native=(5, 5), + pixel_scales=0.005, ) - transformer_dft = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) - - visibilities_dft = transformer_dft.visibilities_from(image=image.native) - - real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) - transformer_nufft = aa.TransformerNUFFT( uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask ) visibilities_nufft = transformer_nufft.visibilities_from(image=image.native) - assert visibilities_dft == pytest.approx(visibilities_nufft, 2.0) assert visibilities_nufft[0] == pytest.approx(25.02317617953263 + 0.0j, 1.0e-7) -def test__nufft__transform_mapping_matrix(): - uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) +def test__nufft__image_from(visibilities_7, uv_wavelengths_7x2, mask_2d_7x7): - grid_radians = aa.Grid2D.uniform(shape_native=(5, 5), pixel_scales=0.005) - real_space_mask = MockRealSpaceMask(grid=grid_radians) + transformer = aa.TransformerNUFFT( + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, + ) - mapping_matrix = np.ones(shape=(25, 3)) + image = transformer.image_from(visibilities=visibilities_7) - transformer_dft = aa.TransformerDFT( - uv_wavelengths=uv_wavelengths, - real_space_mask=real_space_mask, - preload_transform=False, - ) + assert image[0:3] == pytest.approx([0.00726546, 0.01149121, 0.01421022], 1.0e-4) - transformed_mapping_matrix_dft = transformer_dft.transform_mapping_matrix( - mapping_matrix=mapping_matrix - ) + +def test__nufft__transform_mapping_matrix(): + uv_wavelengths = np.array([[0.2, 1.0], [0.5, 1.1], [0.8, 1.2]]) + + mapping_matrix = np.ones(shape=(25, 3)) real_space_mask = aa.Mask2D.all_false(shape_native=(5, 5), pixel_scales=0.005) @@ -328,13 +189,6 @@ def test__nufft__transform_mapping_matrix(): mapping_matrix=mapping_matrix ) - assert transformed_mapping_matrix_dft == pytest.approx( - transformed_mapping_matrix_nufft, 2.0 - ) - assert transformed_mapping_matrix_dft == pytest.approx( - transformed_mapping_matrix_nufft, 2.0 - ) - assert transformed_mapping_matrix_nufft[0, 0] == pytest.approx( 25.02317 + 0.0j, 1.0e-4 ) diff --git a/test_autoarray/plot/get_visuals/__init__.py b/test_autoarray/plot/get_visuals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autoarray/plot/get_visuals/test_one_d.py b/test_autoarray/plot/get_visuals/test_one_d.py deleted file mode 100644 index 73f05bf43..000000000 --- a/test_autoarray/plot/get_visuals/test_one_d.py +++ /dev/null @@ -1,32 +0,0 @@ -from os import path -import pytest - -import autoarray.plot as aplt - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" - ) - - -def test__via_array_1d_from(array_1d_7): - visuals_1d = aplt.Visuals1D(origin=(1.0, 1.0)) - include_1d = aplt.Include1D(origin=True, mask=True) - - get_visuals = aplt.GetVisuals1D(include=include_1d, visuals=visuals_1d) - - visuals_1d_via = get_visuals.via_array_1d_from(array_1d=array_1d_7) - - assert visuals_1d_via.origin == (1.0, 1.0) - assert (visuals_1d_via.mask == array_1d_7.mask).all() - - include_1d = aplt.Include1D(origin=False, mask=False) - - get_visuals = aplt.GetVisuals1D(include=include_1d, visuals=visuals_1d) - - visuals_1d_via = get_visuals.via_array_1d_from(array_1d=array_1d_7) - - assert visuals_1d_via.origin == (1.0, 1.0) - assert visuals_1d_via.mask == None diff --git a/test_autoarray/plot/get_visuals/test_two_d.py b/test_autoarray/plot/get_visuals/test_two_d.py deleted file mode 100644 index 5494bcf22..000000000 --- a/test_autoarray/plot/get_visuals/test_two_d.py +++ /dev/null @@ -1,164 +0,0 @@ -from os import path -import pytest - -import autoarray.plot as aplt - - -@pytest.fixture(name="plot_path") -def make_plot_path_setup(): - return path.join( - "{}".format(path.dirname(path.realpath(__file__))), "files", "plots", "imaging" - ) - - -def test__via_mask_from(mask_2d_7x7): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0), vectors=2) - include_2d = aplt.Include2D(origin=True, mask=True, border=True) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mask_from(mask=mask_2d_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == mask_2d_7x7).all() - assert (visuals_2d_via.border == mask_2d_7x7.derive_grid.border).all() - assert visuals_2d_via.vectors == 2 - - include_2d = aplt.Include2D(origin=False, mask=False, border=False) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mask_from(mask=mask_2d_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert visuals_2d_via.mask == None - assert visuals_2d_via.border == None - assert visuals_2d_via.vectors == 2 - - -def test__via_grid_from(grid_2d_7x7): - visuals_2d = aplt.Visuals2D() - include_2d = aplt.Include2D(origin=True) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_grid_from(grid=grid_2d_7x7) - - assert (visuals_2d_via.origin == grid_2d_7x7.origin).all() - - include_2d = aplt.Include2D(origin=False) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_grid_from(grid=grid_2d_7x7) - - assert visuals_2d_via.origin == None - - -def test__via_mapper_for_data_from(voronoi_mapper_9_3x3): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0)) - include_2d = aplt.Include2D( - origin=True, mask=True, border=True, mapper_image_plane_mesh_grid=True - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_data_from(mapper=voronoi_mapper_9_3x3) - - assert visuals_2d.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == voronoi_mapper_9_3x3.mapper_grids.mask).all() - assert ( - visuals_2d_via.border - == voronoi_mapper_9_3x3.mapper_grids.mask.derive_grid.border - ).all() - - assert ( - visuals_2d_via.mesh_grid == voronoi_mapper_9_3x3.image_plane_mesh_grid - ).all() - - include_2d = aplt.Include2D( - origin=False, mask=False, border=False, mapper_image_plane_mesh_grid=False - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_data_from(mapper=voronoi_mapper_9_3x3) - - assert visuals_2d.origin == (1.0, 1.0) - assert visuals_2d_via.mask == None - assert visuals_2d_via.border == None - assert visuals_2d_via.mesh_grid == None - - -def test__via_mapper_for_source_from(rectangular_mapper_7x7_3x3): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0)) - include_2d = aplt.Include2D( - origin=True, - border=True, - mapper_source_plane_data_grid=True, - mapper_source_plane_mesh_grid=True, - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert visuals_2d.origin == (1.0, 1.0) - assert ( - visuals_2d_via.grid - == rectangular_mapper_7x7_3x3.source_plane_data_grid.over_sampled - ).all() - border_grid = ( - rectangular_mapper_7x7_3x3.mapper_grids.source_plane_data_grid.over_sampled[ - rectangular_mapper_7x7_3x3.border_relocator.sub_border_slim - ] - ) - assert (visuals_2d_via.border == border_grid).all() - assert ( - visuals_2d_via.mesh_grid == rectangular_mapper_7x7_3x3.source_plane_mesh_grid - ).all() - - include_2d = aplt.Include2D( - origin=False, - border=False, - mapper_source_plane_data_grid=False, - mapper_source_plane_mesh_grid=False, - ) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_mapper_for_source_from( - mapper=rectangular_mapper_7x7_3x3 - ) - - assert visuals_2d.origin == (1.0, 1.0) - assert visuals_2d_via.grid == None - assert visuals_2d_via.border == None - assert visuals_2d_via.mesh_grid == None - - -def test__via_fit_imaging_from(fit_imaging_7x7): - visuals_2d = aplt.Visuals2D(origin=(1.0, 1.0), vectors=2) - include_2d = aplt.Include2D(origin=True, mask=True, border=True) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_fit_imaging_from(fit=fit_imaging_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert (visuals_2d_via.mask == fit_imaging_7x7.mask).all() - assert (visuals_2d_via.border == fit_imaging_7x7.mask.derive_grid.border).all() - assert visuals_2d_via.vectors == 2 - - include_2d = aplt.Include2D(origin=False, mask=False, border=False) - - get_visuals = aplt.GetVisuals2D(include=include_2d, visuals=visuals_2d) - - visuals_2d_via = get_visuals.via_fit_imaging_from(fit=fit_imaging_7x7) - - assert visuals_2d_via.origin == (1.0, 1.0) - assert visuals_2d_via.mask == None - assert visuals_2d_via.border == None - assert visuals_2d_via.vectors == 2 diff --git a/test_autoarray/plot/include/__init__.py b/test_autoarray/plot/include/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autoarray/plot/include/test_include.py b/test_autoarray/plot/include/test_include.py deleted file mode 100644 index b32616e9d..000000000 --- a/test_autoarray/plot/include/test_include.py +++ /dev/null @@ -1,21 +0,0 @@ -import autoarray.plot as aplt - - -def test__loads_default_values_from_config_if_not_input(): - include = aplt.Include2D() - - assert include.origin is True - assert include.mask == True - assert include.border is False - assert include.parallel_overscan is True - assert include.serial_prescan is True - assert include.serial_overscan is False - - include = aplt.Include2D(origin=False, border=False, serial_overscan=True) - - assert include.origin is False - assert include.mask == True - assert include.border is False - assert include.parallel_overscan is True - assert include.serial_prescan is True - assert include.serial_overscan is True diff --git a/test_autoarray/plot/test_abstract_plotters.py b/test_autoarray/plot/test_abstract_plotters.py index 652c97422..05c905f7c 100644 --- a/test_autoarray/plot/test_abstract_plotters.py +++ b/test_autoarray/plot/test_abstract_plotters.py @@ -24,28 +24,28 @@ def test__get_subplot_shape(): plotter.mat_plot_2d.get_subplot_shape(number_subplots=1000) -def test__get_subplot_figsize(): - plotter = abstract_plotters.AbstractPlotter( - mat_plot_2d=aplt.MatPlot2D(figure=aplt.Figure(figsize="auto")) - ) - - figsize = plotter.get_subplot_figsize(number_subplots=1) - - assert figsize == (6, 6) - - figsize = plotter.get_subplot_figsize(number_subplots=4) - - assert figsize == (12, 12) - - figure = aplt.Figure(figsize=(20, 20)) - - plotter = abstract_plotters.AbstractPlotter( - mat_plot_2d=aplt.MatPlot2D(figure=figure) - ) - - figsize = plotter.get_subplot_figsize(number_subplots=4) - - assert figsize == (20, 20) +# def test__get_subplot_figsize(): +# plotter = abstract_plotters.AbstractPlotter( +# mat_plot_2d=aplt.MatPlot2D(figure=aplt.Figure(figsize="auto")) +# ) +# +# figsize = plotter.get_subplot_figsize(number_subplots=1) +# +# assert figsize == (7, 7) +# +# figsize = plotter.get_subplot_figsize(number_subplots=4) +# +# assert figsize == (7, 7) +# +# figure = aplt.Figure(figsize=(20, 20)) +# +# plotter = abstract_plotters.AbstractPlotter( +# mat_plot_2d=aplt.MatPlot2D(figure=figure) +# ) +# +# figsize = plotter.get_subplot_figsize(number_subplots=4) +# +# assert figsize == (20, 20) def test__open_and_close_subplot_figures(): @@ -105,33 +105,3 @@ def test__uses_figure_or_subplot_configs_correctly(): assert plotter.mat_plot_2d.figure.config_dict["aspect"] == "square" assert plotter.mat_plot_2d.cmap.config_dict["cmap"] == "default" assert plotter.mat_plot_2d.cmap.config_dict["norm"] == "linear" - - -def test__get__visuals(): - visuals_2d = aplt.Visuals2D() - include_2d = aplt.Include2D(origin=False) - - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=1) - - assert attr == None - - include_2d = aplt.Include2D(origin=True) - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=1) - - assert attr == 1 - - visuals_2d = aplt.Visuals2D(origin=10) - - include_2d = aplt.Include2D(origin=False) - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=2) - - assert attr == 10 - - include_2d = aplt.Include2D(origin=True) - plotter = abstract_plotters.Plotter(visuals_2d=visuals_2d, include_2d=include_2d) - attr = plotter.get_2d.get(name="origin", value=2) - - assert attr == 10 diff --git a/test_autoarray/plot/test_multi_plotters.py b/test_autoarray/plot/test_multi_plotters.py index 49494e144..9c2048ac3 100644 --- a/test_autoarray/plot/test_multi_plotters.py +++ b/test_autoarray/plot/test_multi_plotters.py @@ -45,16 +45,14 @@ def __init__( self, y, x, - mat_plot_1d: aplt.MatPlot1D = aplt.MatPlot1D(), - visuals_1d: aplt.Visuals1D = aplt.Visuals1D(), - include_1d: aplt.Include1D = aplt.Include1D(), + mat_plot_1d: aplt.MatPlot1D = None, + visuals_1d: aplt.Visuals1D = None, ): super().__init__( y=y, x=x, mat_plot_1d=mat_plot_1d, visuals_1d=visuals_1d, - include_1d=include_1d, ) def figures_1d(self, figure_name=False): diff --git a/test_autoarray/plot/wrap/base/test_ticks.py b/test_autoarray/plot/wrap/base/test_ticks.py index ead11e885..51f8174e8 100644 --- a/test_autoarray/plot/wrap/base/test_ticks.py +++ b/test_autoarray/plot/wrap/base/test_ticks.py @@ -60,7 +60,9 @@ def test__yticks__set(): units = aplt.Units(use_scaled=True, ticks_convert_factor=None) yticks = aplt.YTicks(fontsize=34) - extent = array.extent_of_zoomed_array(buffer=1) + zoom = aa.Zoom2D(mask=array.mask) + array_zoom = zoom.array_2d_from(array=array, buffer=1) + extent = array_zoom.geometry.extent yticks.set(min_value=extent[2], max_value=extent[3], units=units) yticks = aplt.YTicks(fontsize=34) @@ -105,7 +107,9 @@ def test__xticks__set(): array = aa.Array2D.ones(shape_native=(2, 2), pixel_scales=1.0) units = aplt.Units(use_scaled=True, ticks_convert_factor=None) xticks = aplt.XTicks(fontsize=34) - extent = array.extent_of_zoomed_array(buffer=1) + zoom = aa.Zoom2D(mask=array.mask) + array_zoom = zoom.array_2d_from(array=array, buffer=1) + extent = array_zoom.geometry.extent xticks.set(min_value=extent[0], max_value=extent[1], units=units) xticks = aplt.XTicks(fontsize=34) diff --git a/test_autoarray/structures/arrays/test_array_2d_util.py b/test_autoarray/structures/arrays/test_array_2d_util.py index 9468dd510..0c6ae8bc7 100644 --- a/test_autoarray/structures/arrays/test_array_2d_util.py +++ b/test_autoarray/structures/arrays/test_array_2d_util.py @@ -287,96 +287,6 @@ def test__resized_array_2d_from__padding_with_new_origin(): ).all() -def test__replace_noise_map_2d_values_where_image_2d_values_are_negative(): - image_2d = np.ones(shape=(2, 2)) - - noise_map_2d = np.array([[1.0, 2.0], [3.0, 4.0]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=1.0 - ) - ) - - assert (noise_map_2d == noise_map_2d).all() - - image_2d = -1.0 * np.ones(shape=(2, 2)) - - noise_map_2d = np.array([[1.0, 0.5], [0.25, 0.125]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=10.0 - ) - ) - - assert (noise_map_2d == noise_map_2d).all() - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=4.0 - ) - ) - - assert (noise_map_2d == np.array([[1.0, 0.5], [0.25, 0.25]])).all() - - noise_map_2d = np.array([[1.0, 0.5], [0.25, 0.125]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=2.0 - ) - ) - - assert (noise_map_2d == np.array([[1.0, 0.5], [0.5, 0.5]])).all() - - noise_map_2d = np.array([[1.0, 0.5], [0.25, 0.125]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=1.0 - ) - ) - - assert (noise_map_2d == np.array([[1.0, 1.0], [1.0, 1.0]])).all() - - noise_map_2d = np.array([[1.0, 0.5], [0.25, 0.125]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=0.5 - ) - ) - - assert (noise_map_2d == np.array([[2.0, 2.0], [2.0, 2.0]])).all() - - -def test__same_as_above__image_not_all_negative(): - image_2d = np.array([[1.0, -2.0], [5.0, -4.0]]) - - noise_map_2d = np.array([[3.0, 1.0], [4.0, 8.0]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=1.0 - ) - ) - - assert (noise_map_2d == np.array([[3.0, 2.0], [4.0, 8.0]])).all() - - image_2d = np.array([[-10.0, -20.0], [100.0, -30.0]]) - - noise_map_2d = np.array([[1.0, 2.0], [40.0, 3.0]]) - - noise_map_2d = ( - util.array_2d.replace_noise_map_2d_values_where_image_2d_values_are_negative( - image_2d=image_2d, noise_map_2d=noise_map_2d, target_signal_to_noise=5.0 - ) - ) - - assert (noise_map_2d == np.array([[2.0, 4.0], [40.0, 6.0]])).all() - - def test__index_2d_for_index_slim_from(): indexes_1d = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) @@ -521,25 +431,6 @@ def test__array_2d_slim_from(): assert (array_2d_slim == np.array([2, 4, 5, 6, 8])).all() -def test__array_2d_slim_from__complex_array(): - array_2d = np.array( - [ - [1 + 1j, 2 + 2j, 3 + 3], - [4 + 4j, 5 + 5j, 6 + 6j], - [7 + 7j, 8 + 8j, 9 + 9j], - ] - ) - - mask = np.array([[True, True, True], [True, False, True], [True, True, True]]) - - array_2d_slim = util.array_2d.array_2d_slim_complex_from( - mask=mask, - array_2d_native=array_2d, - ) - - assert (array_2d_slim == np.array([5 + 5j])).all() - - def test__array_2d_native_from(): array_2d_slim = np.array([1.0, 2.0, 3.0, 4.0]) @@ -584,19 +475,3 @@ def test__array_2d_native_from(): [[1.0, 2.0, 0.0, 0.0], [3.0, 0.0, 0.0, 0.0], [-1.0, -2.0, 0.0, -3.0]] ) ).all() - - -def test__array_2d_native_from__compelx_array(): - array_2d_slim = np.array( - [1.0 + 1j, 2.0 + 2j, 3.0 + 3j, 4.0 + 4j], dtype="complex128" - ) - - array_2d = util.array_2d.array_2d_native_complex_via_indexes_from( - array_2d_slim=array_2d_slim, - shape_native=(2, 2), - native_index_for_slim_index_2d=np.array( - [[0, 0], [0, 1], [1, 0], [1, 1]], dtype="int" - ), - ) - - assert (array_2d == np.array([[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]])).all() diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 4cf8b92d7..6fa4e7295 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -1,13 +1,11 @@ from astropy import units from astropy.modeling import functional_models from astropy.coordinates import Angle -import jax.numpy as jnp import numpy as np import pytest from os import path import autoarray as aa -from autoarray import exc test_data_path = path.join("{}".format(path.dirname(path.realpath(__file__))), "files") @@ -150,7 +148,7 @@ def test__rescaled_with_odd_dimensions_from__evens_to_odds(): rescale_factor=0.5, normalize=True ) assert kernel_2d.pixel_scales == (2.0, 2.0) - assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all() + assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4) array_2d = np.ones((9, 9)) kernel_2d = aa.Kernel2D.no_mask(values=array_2d, pixel_scales=1.0, normalize=False) @@ -158,7 +156,7 @@ def test__rescaled_with_odd_dimensions_from__evens_to_odds(): rescale_factor=0.333333333333333, normalize=True ) assert kernel_2d.pixel_scales == (3.0, 3.0) - assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all() + assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4) array_2d = np.ones((18, 6)) kernel_2d = aa.Kernel2D.no_mask(values=array_2d, pixel_scales=1.0, normalize=False) @@ -166,7 +164,7 @@ def test__rescaled_with_odd_dimensions_from__evens_to_odds(): rescale_factor=0.5, normalize=True ) assert kernel_2d.pixel_scales == (2.0, 2.0) - assert (kernel_2d.native == (1.0 / 27.0) * np.ones((9, 3))).all() + assert kernel_2d.native == pytest.approx((1.0 / 27.0) * np.ones((9, 3)), 1.0e-4) array_2d = np.ones((6, 18)) kernel_2d = aa.Kernel2D.no_mask(values=array_2d, pixel_scales=1.0, normalize=False) @@ -174,7 +172,7 @@ def test__rescaled_with_odd_dimensions_from__evens_to_odds(): rescale_factor=0.5, normalize=True ) assert kernel_2d.pixel_scales == (2.0, 2.0) - assert (kernel_2d.native == (1.0 / 27.0) * np.ones((3, 9))).all() + assert kernel_2d.native == pytest.approx((1.0 / 27.0) * np.ones((3, 9)), 1.0e-4) def test__rescaled_with_odd_dimensions_from__different_scalings(): @@ -183,7 +181,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): rescale_factor=2.0, normalize=True ) assert kernel_2d.pixel_scales == (0.4, 0.4) - assert (kernel_2d.native == (1.0 / 25.0) * np.ones((5, 5))).all() + assert kernel_2d.native == pytest.approx((1.0 / 25.0) * np.ones((5, 5)), 1.0e-4) kernel_2d = aa.Kernel2D.ones( shape_native=(40, 40), pixel_scales=1.0, normalize=False @@ -192,7 +190,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): rescale_factor=0.1, normalize=True ) assert kernel_2d.pixel_scales == (8.0, 8.0) - assert (kernel_2d.native == (1.0 / 25.0) * np.ones((5, 5))).all() + assert kernel_2d.native == pytest.approx((1.0 / 25.0) * np.ones((5, 5)), 1.0e-4) kernel_2d = aa.Kernel2D.ones(shape_native=(2, 4), pixel_scales=1.0, normalize=False) kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from( @@ -201,7 +199,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): assert kernel_2d.pixel_scales[0] == pytest.approx(0.4, 1.0e-4) assert kernel_2d.pixel_scales[1] == pytest.approx(0.4444444, 1.0e-4) - assert (kernel_2d.native == (1.0 / 45.0) * np.ones((5, 9))).all() + assert kernel_2d.native == pytest.approx((1.0 / 45.0) * np.ones((5, 9)), 1.0e-4) kernel_2d = aa.Kernel2D.ones(shape_native=(4, 2), pixel_scales=1.0, normalize=False) kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from( @@ -209,7 +207,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): ) assert kernel_2d.pixel_scales[0] == pytest.approx(0.4444444, 1.0e-4) assert kernel_2d.pixel_scales[1] == pytest.approx(0.4, 1.0e-4) - assert (kernel_2d.native == (1.0 / 45.0) * np.ones((9, 5))).all() + assert kernel_2d.native == pytest.approx((1.0 / 45.0) * np.ones((9, 5)), 1.0e-4) kernel_2d = aa.Kernel2D.ones(shape_native=(6, 4), pixel_scales=1.0, normalize=False) kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from( @@ -217,7 +215,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): ) assert kernel_2d.pixel_scales == pytest.approx((2.0, 1.3333333333), 1.0e-4) - assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all() + assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4) kernel_2d = aa.Kernel2D.ones( shape_native=(9, 12), pixel_scales=1.0, normalize=False @@ -227,7 +225,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): ) assert kernel_2d.pixel_scales == pytest.approx((3.0, 2.4), 1.0e-4) - assert (kernel_2d.native == (1.0 / 15.0) * np.ones((3, 5))).all() + assert kernel_2d.native == pytest.approx((1.0 / 15.0) * np.ones((3, 5)), 1.0e-4) kernel_2d = aa.Kernel2D.ones(shape_native=(4, 6), pixel_scales=1.0, normalize=False) kernel_2d = kernel_2d.rescaled_with_odd_dimensions_from( @@ -235,7 +233,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): ) assert kernel_2d.pixel_scales == pytest.approx((1.33333333333, 2.0), 1.0e-4) - assert (kernel_2d.native == (1.0 / 9.0) * np.ones((3, 3))).all() + assert kernel_2d.native == pytest.approx((1.0 / 9.0) * np.ones((3, 3)), 1.0e-4) kernel_2d = aa.Kernel2D.ones( shape_native=(12, 9), pixel_scales=1.0, normalize=False @@ -244,7 +242,7 @@ def test__rescaled_with_odd_dimensions_from__different_scalings(): rescale_factor=0.33333333333, normalize=True ) assert kernel_2d.pixel_scales == pytest.approx((2.4, 3.0), 1.0e-4) - assert (kernel_2d.native == (1.0 / 15.0) * np.ones((5, 3))).all() + assert kernel_2d.native == pytest.approx((1.0 / 15.0) * np.ones((5, 3)), 1.0e-4) def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy_gaussian_model(): diff --git a/test_autoarray/structures/arrays/test_uniform_2d.py b/test_autoarray/structures/arrays/test_uniform_2d.py index 39cbc71df..0c55b0c38 100644 --- a/test_autoarray/structures/arrays/test_uniform_2d.py +++ b/test_autoarray/structures/arrays/test_uniform_2d.py @@ -355,135 +355,6 @@ def test__trimmed_after_convolution_from(): assert new_arr.mask.pixel_scales == (1.0, 1.0) -def test__zoomed_around_mask(): - array_2d = [ - [1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0], - ] - - mask = aa.Mask2D( - mask=[ - [True, True, True, True], - [True, False, False, True], - [True, False, False, True], - [True, True, True, True], - ], - pixel_scales=(1.0, 1.0), - ) - - arr_masked = aa.Array2D(values=array_2d, mask=mask) - - arr_zoomed = arr_masked.zoomed_around_mask(buffer=0) - - assert (arr_zoomed.native == np.array([[6.0, 7.0], [10.0, 11.0]])).all() - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True], - [True, False, False, True], - [False, False, False, True], - [True, True, True, True], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - arr_masked = aa.Array2D(values=array_2d, mask=mask) - arr_zoomed = arr_masked.zoomed_around_mask(buffer=0) - - assert (arr_zoomed.native == np.array([[0.0, 6.0, 7.0], [9.0, 10.0, 11.0]])).all() - - mask = aa.Mask2D( - mask=np.array( - [ - [True, False, True, True], - [True, False, False, True], - [True, False, False, True], - [True, True, True, True], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - arr_masked = aa.Array2D(values=array_2d, mask=mask) - arr_zoomed = arr_masked.zoomed_around_mask(buffer=0) - assert (arr_zoomed.native == np.array([[2.0, 0.0], [6.0, 7.0], [10.0, 11.0]])).all() - - -def test__zoomed_around_mask__origin_updated(): - array_2d = np.ones(shape=(4, 4)) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True], - [True, False, False, True], - [True, False, False, True], - [True, True, True, True], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - arr_masked = aa.Array2D(values=array_2d, mask=mask) - - arr_zoomed = arr_masked.zoomed_around_mask(buffer=0) - - assert arr_zoomed.mask.origin == (0.0, 0.0) - - array_2d = np.ones(shape=(6, 6)) - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, True, True, True], - [True, True, True, True, True, True], - [True, True, True, False, False, True], - [True, True, True, False, False, True], - [True, True, True, True, True, True], - [True, True, True, True, True, True], - ] - ), - pixel_scales=(1.0, 1.0), - ) - - arr_masked = aa.Array2D(values=array_2d, mask=mask) - - arr_zoomed = arr_masked.zoomed_around_mask(buffer=0) - - assert arr_zoomed.mask.origin == (0.0, 1.0) - - -def test__extent_of_zoomed_array(): - array_2d = [ - [1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0], - ] - - mask = aa.Mask2D( - mask=np.array( - [ - [True, True, True, False], - [True, False, False, True], - [True, False, False, True], - [True, True, True, True], - ] - ), - pixel_scales=(1.0, 2.0), - ) - - arr_masked = aa.Array2D(values=array_2d, mask=mask) - - extent = arr_masked.extent_of_zoomed_array(buffer=1) - - assert extent == pytest.approx(np.array([-4.0, 6.0, -2.0, 3.0]), 1.0e-4) - - def test__binned_across_rows(): array = aa.Array2D.no_mask(values=np.ones((4, 3)), pixel_scales=1.0) diff --git a/test_autoarray/structures/grids/test_grid_2d_util.py b/test_autoarray/structures/grids/test_grid_2d_util.py index edf1d2497..79c127310 100644 --- a/test_autoarray/structures/grids/test_grid_2d_util.py +++ b/test_autoarray/structures/grids/test_grid_2d_util.py @@ -147,6 +147,112 @@ def test__grid_2d_slim_via_shape_native_from(): ).all() +def test__grid_2d_slim_via_shape_native_not_mask_from(): + grid_2d = aa.util.grid_2d.grid_2d_slim_via_shape_native_not_mask_from( + shape_native=(2, 3), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [0.5, -1.0], + [0.5, 0.0], + [0.5, 1.0], + [-0.5, -1.0], + [-0.5, 0.0], + [-0.5, 1.0], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_slim_via_shape_native_not_mask_from( + shape_native=(3, 2), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [1.0, -0.5], + [1.0, 0.5], + [0.0, -0.5], + [0.0, 0.5], + [-1.0, -0.5], + [-1.0, 0.5], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_slim_via_shape_native_not_mask_from( + shape_native=(3, 2), pixel_scales=(1.0, 1.0), origin=(3.0, -2.0) + ) + + assert ( + grid_2d + == np.array( + [ + [4.0, -2.5], + [4.0, -1.5], + [3.0, -2.5], + [3.0, -1.5], + [2.0, -2.5], + [2.0, -1.5], + ] + ) + ).all() + + +def test__grid_2d_via_shape_native_from(): + grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( + shape_native=(2, 3), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [[0.5, -1.0], [0.5, 0.0], [0.5, 1.0]], + [[-0.5, -1.0], [-0.5, 0.0], [-0.5, 1.0]], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( + shape_native=(3, 2), + pixel_scales=(1.0, 1.0), + ) + + assert ( + grid_2d + == np.array( + [ + [[1.0, -0.5], [1.0, 0.5]], + [[0.0, -0.5], [0.0, 0.5]], + [[-1.0, -0.5], [-1.0, 0.5]], + ] + ) + ).all() + + grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( + shape_native=(3, 2), pixel_scales=(1.0, 1.0), origin=(3.0, -2.0) + ) + + assert ( + grid_2d + == np.array( + [ + [[4.0, -2.5], [4.0, -1.5]], + [[3.0, -2.5], [3.0, -1.5]], + [[2.0, -2.5], [2.0, -1.5]], + ] + ) + ).all() + + def test__grid_2d_via_shape_native_from(): grid_2d = aa.util.grid_2d.grid_2d_via_shape_native_from( shape_native=(2, 3), @@ -236,7 +342,7 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.0, 1.0), ) - assert (grid_radii == np.array([[0.0, 0.0], [0.0, 1.0]])).all() + assert grid_radii == pytest.approx(np.array([[0.0, 0.0], [0.0, 1.0]]), abs=1.0e-4) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-1.0, 3.0, -1.0, 1.0]), @@ -244,9 +350,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.0, 1.0), ) - assert ( - grid_radii == np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]) - ).all() + assert grid_radii == pytest.approx( + np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]), abs=1.0e-4 + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-1.0, 3.0, -1.0, 1.0]), @@ -254,7 +360,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.0, 1.0), ) - assert (grid_radii == np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])).all() + assert grid_radii == pytest.approx( + np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]), abs=1.0e-4 + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-2.0, 1.0, -1.0, 1.0]), @@ -262,9 +370,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.0, 1.0), ) - assert ( - grid_radii == np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0], [0.0, 4.0]]) - ).all() + assert grid_radii == pytest.approx( + np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0], [0.0, 4.0]]), abs=1.0e-4 + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-1.0, 1.0, -1.0, 1.0]), @@ -272,10 +380,10 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(0.1, 0.5), ) - assert ( - grid_radii - == np.array([[0.0, 1.0], [0.0, 1.5], [0.0, 2.0], [0.0, 2.5], [0.0, 3.0]]) - ).all() + assert grid_radii == pytest.approx( + np.array([[0.0, 1.0], [0.0, 1.5], [0.0, 2.0], [0.0, 2.5], [0.0, 3.0]]), + abs=1.0e-4, + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([5.0, 8.0, 99.9, 100.1]), @@ -283,9 +391,8 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(10.0, 0.25), ) - assert ( - grid_radii - == np.array( + assert grid_radii == pytest.approx( + np.array( [ [100.0, 7.0], [100.0, 7.25], @@ -297,8 +404,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): [100.0, 8.75], [100.0, 9.0], ] - ) - ).all() + ), + abs=1.0e-4, + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-1.0, 1.0, -1.0, 3.0]), @@ -306,9 +414,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.0, 1.0), ) - assert ( - grid_radii == np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]) - ).all() + assert grid_radii == pytest.approx( + np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]), abs=1.0e-4 + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-1.0, 1.0, -2.0, 1.0]), @@ -316,9 +424,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.0, 1.0), ) - assert ( - grid_radii == np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]) - ).all() + assert grid_radii == pytest.approx( + np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]), abs=1.0e-4 + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([-1.0, 1.0, -1.0, 1.0]), @@ -326,10 +434,10 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(0.5, 0.1), ) - assert ( - grid_radii - == np.array([[1.0, 0.0], [1.0, 0.5], [1.0, 1.0], [1.0, 1.5], [1.0, 2.0]]) - ).all() + assert grid_radii == pytest.approx( + np.array([[1.0, 0.0], [1.0, 0.5], [1.0, 1.0], [1.0, 1.5], [1.0, 2.0]]), + abs=1.0e-4, + ) grid_radii = aa.util.grid_2d.grid_scaled_2d_slim_radial_projected_from( extent=np.array([99.9, 100.1, -1.0, 3.0]), @@ -337,7 +445,9 @@ def test__grid_scaled_2d_slim_radial_projected_from(): pixel_scales=(1.5, 10.0), ) - assert (grid_radii == np.array([[-1.0, 100.0], [-1.0, 101.5], [-1.0, 103.0]])).all() + assert grid_radii == pytest.approx( + np.array([[-1.0, 100.0], [-1.0, 101.5], [-1.0, 103.0]]), abs=1.0e-4 + ) def test__grid_2d_slim_from(): @@ -457,107 +567,3 @@ def test__grid_2d_native_from(): ] ) ).all() - - grid_slim = np.array( - [ - [1.0, 1.0], - [1.0, 1.0], - [1.0, 1.0], - [1.0, 1.0], - [2.0, 2.0], - [2.0, 2.0], - [2.0, 2.0], - [2.0, 2.0], - [3.0, 3.0], - [3.0, 3.0], - [3.0, 3.0], - [4.0, 4.0], - ] - ) - - -def test__grid_2d_slim_upscaled_from(): - grid_slim = np.array([[1.0, 1.0]]) - - grid_upscaled_2d = aa.util.grid_2d.grid_2d_slim_upscaled_from( - grid_slim=grid_slim, upscale_factor=1, pixel_scales=(2.0, 2.0) - ) - - assert (grid_upscaled_2d == np.array([[1.0, 1.0]])).all() - - grid_upscaled_2d = aa.util.grid_2d.grid_2d_slim_upscaled_from( - grid_slim=grid_slim, upscale_factor=2, pixel_scales=(2.0, 2.0) - ) - - assert ( - grid_upscaled_2d == np.array([[1.5, 0.5], [1.5, 1.5], [0.5, 0.5], [0.5, 1.5]]) - ).all() - - grid_slim = np.array([[1.0, 1.0], [1.0, 3.0]]) - - grid_upscaled_2d = aa.util.grid_2d.grid_2d_slim_upscaled_from( - grid_slim=grid_slim, upscale_factor=2, pixel_scales=(2.0, 2.0) - ) - - assert ( - grid_upscaled_2d - == np.array( - [ - [1.5, 0.5], - [1.5, 1.5], - [0.5, 0.5], - [0.5, 1.5], - [1.5, 2.5], - [1.5, 3.5], - [0.5, 2.5], - [0.5, 3.5], - ] - ) - ).all() - - grid_slim = np.array([[1.0, 1.0], [3.0, 1.0]]) - - grid_upscaled_2d = aa.util.grid_2d.grid_2d_slim_upscaled_from( - grid_slim=grid_slim, upscale_factor=2, pixel_scales=(2.0, 2.0) - ) - - assert ( - grid_upscaled_2d - == np.array( - [ - [1.5, 0.5], - [1.5, 1.5], - [0.5, 0.5], - [0.5, 1.5], - [3.5, 0.5], - [3.5, 1.5], - [2.5, 0.5], - [2.5, 1.5], - ] - ) - ).all() - - grid_slim = np.array([[1.0, 1.0]]) - - grid_upscaled_2d = aa.util.grid_2d.grid_2d_slim_upscaled_from( - grid_slim=grid_slim, upscale_factor=2, pixel_scales=(3.0, 2.0) - ) - - assert ( - grid_upscaled_2d - == np.array([[1.75, 0.5], [1.75, 1.5], [0.25, 0.5], [0.25, 1.5]]) - ).all() - - grid_upscaled_2d = aa.util.grid_2d.grid_2d_slim_upscaled_from( - grid_slim=grid_slim, upscale_factor=3, pixel_scales=(2.0, 2.0) - ) - - assert grid_upscaled_2d[0] == pytest.approx(np.array([1.666, 0.333]), 1.0e-2) - assert grid_upscaled_2d[1] == pytest.approx(np.array([1.666, 1.0]), 1.0e-2) - assert grid_upscaled_2d[2] == pytest.approx(np.array([1.666, 1.666]), 1.0e-2) - assert grid_upscaled_2d[3] == pytest.approx(np.array([1.0, 0.333]), 1.0e-2) - assert grid_upscaled_2d[4] == pytest.approx(np.array([1.0, 1.0]), 1.0e-2) - assert grid_upscaled_2d[5] == pytest.approx(np.array([1.0, 1.666]), 1.0e-2) - assert grid_upscaled_2d[6] == pytest.approx(np.array([0.333, 0.333]), 1.0e-2) - assert grid_upscaled_2d[7] == pytest.approx(np.array([0.333, 1.0]), 1.0e-2) - assert grid_upscaled_2d[8] == pytest.approx(np.array([0.333, 1.666]), 1.0e-2) diff --git a/test_autoarray/structures/grids/test_irregular_2d.py b/test_autoarray/structures/grids/test_irregular_2d.py index 245848224..f98ee6a90 100644 --- a/test_autoarray/structures/grids/test_irregular_2d.py +++ b/test_autoarray/structures/grids/test_irregular_2d.py @@ -95,18 +95,22 @@ def test__furthest_distances_to_other_coordinates(): def test__grid_of_closest_from(): grid = aa.Grid2DIrregular(values=[(0.0, 0.0), (0.0, 1.0)]) - grid_of_closest = grid.grid_of_closest_from(grid_pair=np.array([[0.0, 0.1]])) + grid_of_closest = grid.grid_of_closest_from( + grid_pair=aa.Grid2DIrregular(np.array([[0.0, 0.1]])) + ) assert (grid_of_closest == np.array([[0.0, 0.0]])).all() grid_of_closest = grid.grid_of_closest_from( - grid_pair=np.array([[0.0, 0.1], [0.0, 0.2], [0.0, 0.3]]) + grid_pair=aa.Grid2DIrregular(np.array([[0.0, 0.1], [0.0, 0.2], [0.0, 0.3]])) ) assert (grid_of_closest == np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])).all() grid_of_closest = grid.grid_of_closest_from( - grid_pair=np.array([[0.0, 0.1], [0.0, 0.2], [0.0, 0.9], [0.0, -0.1]]) + grid_pair=aa.Grid2DIrregular( + np.array([[0.0, 0.1], [0.0, 0.2], [0.0, 0.9], [0.0, -0.1]]) + ) ) assert ( diff --git a/test_autoarray/structures/grids/test_uniform_1d.py b/test_autoarray/structures/grids/test_uniform_1d.py index b99da3608..e641c4015 100644 --- a/test_autoarray/structures/grids/test_uniform_1d.py +++ b/test_autoarray/structures/grids/test_uniform_1d.py @@ -123,7 +123,7 @@ def test__grid_2d_radial_projected_from(): assert type(grid_2d) == aa.Grid2DIrregular assert grid_2d.slim == pytest.approx( - np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0], [0.0, 4.0]]), 1.0e-4 + np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0], [0.0, 4.0]]), abs=1.0e-4 ) grid_2d = grid_1d.grid_2d_radial_projected_from(angle=90.0) diff --git a/test_autoarray/structures/grids/test_uniform_2d.py b/test_autoarray/structures/grids/test_uniform_2d.py index 78813329e..9352dfbb1 100644 --- a/test_autoarray/structures/grids/test_uniform_2d.py +++ b/test_autoarray/structures/grids/test_uniform_2d.py @@ -441,7 +441,7 @@ def test__from_mask(): mask = aa.Mask2D(mask=mask, pixel_scales=(2.0, 2.0)) grid_via_util = aa.util.grid_2d.grid_2d_slim_via_mask_from( - mask_2d=np.array(mask), pixel_scales=(2.0, 2.0) + mask_2d=mask, pixel_scales=(2.0, 2.0) ) grid_2d = aa.Grid2D.from_mask(mask=mask) @@ -451,8 +451,8 @@ def test__from_mask(): assert grid_2d.pixel_scales == (2.0, 2.0) grid_2d_native = aa.util.grid_2d.grid_2d_native_from( - grid_2d_slim=np.array(grid_2d), - mask_2d=np.array(mask), + grid_2d_slim=grid_2d.array, + mask_2d=mask, ) assert (grid_2d_native == grid_2d.native).all() @@ -559,7 +559,7 @@ def test__grid_2d_radial_projected_shape_slim_from(): pixel_scales=grid_2d.pixel_scales, ) - assert (grid_radii == grid_radii_util).all() + assert grid_radii == pytest.approx(grid_radii_util, 1.0e-4) assert grid_radial_shape_slim == grid_radii_util.shape[0] grid_2d = aa.Grid2D.uniform(shape_native=(3, 4), pixel_scales=(3.0, 2.0)) diff --git a/test_autoarray/structures/mesh/test_rectangular.py b/test_autoarray/structures/mesh/test_rectangular.py index a63489733..431cf4867 100644 --- a/test_autoarray/structures/mesh/test_rectangular.py +++ b/test_autoarray/structures/mesh/test_rectangular.py @@ -12,8 +12,8 @@ def test__neighbors__compare_to_mesh_util(): # I8 I 9I10I11I # I12I13I14I15I - mesh = aa.Mesh2DRectangular.overlay_grid( - shape_native=(7, 5), grid=np.zeros((2, 2)), buffer=1e-8 + mesh = aa.Mesh2DRectangularUniform.overlay_grid( + shape_native=(7, 5), grid=aa.Grid2DIrregular(np.zeros((2, 2))), buffer=1e-8 ) (neighbors_util, neighbors_sizes_util) = aa.util.mesh.rectangular_neighbors_from( @@ -25,7 +25,7 @@ def test__neighbors__compare_to_mesh_util(): def test__edge_pixel_list(): - grid = np.array( + grid = aa.Grid2DIrregular( [ [-1.0, -1.0], [-1.0, 0.0], @@ -39,7 +39,7 @@ def test__edge_pixel_list(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -47,7 +47,7 @@ def test__edge_pixel_list(): def test__shape_native_and_pixel_scales(): - grid = np.array( + grid = aa.Grid2DIrregular( [ [-1.0, -1.0], [-1.0, 0.0], @@ -61,14 +61,14 @@ def test__shape_native_and_pixel_scales(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) assert mesh.shape_native == (3, 3) assert mesh.pixel_scales == pytest.approx((2.0 / 3.0, 2.0 / 3.0), 1e-2) - grid = np.array( + grid = aa.Grid2DIrregular( [ [1.0, -1.0], [1.0, 0.0], @@ -82,16 +82,16 @@ def test__shape_native_and_pixel_scales(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(5, 4), grid=grid, buffer=1e-8 ) assert mesh.shape_native == (5, 4) assert mesh.pixel_scales == pytest.approx((2.0 / 5.0, 2.0 / 4.0), 1e-2) - grid = np.array([[2.0, 1.0], [4.0, 3.0], [6.0, 5.0], [8.0, 7.0]]) + grid = aa.Grid2DIrregular([[2.0, 1.0], [4.0, 3.0], [6.0, 5.0], [8.0, 7.0]]) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -100,7 +100,7 @@ def test__shape_native_and_pixel_scales(): def test__pixel_centres__3x3_grid__pixel_centres(): - grid = np.array( + grid = aa.Grid2DIrregular( [ [1.0, -1.0], [1.0, 0.0], @@ -114,7 +114,7 @@ def test__pixel_centres__3x3_grid__pixel_centres(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(3, 3), grid=grid, buffer=1e-8 ) @@ -134,7 +134,7 @@ def test__pixel_centres__3x3_grid__pixel_centres(): ) ) - grid = np.array( + grid = aa.Grid2DIrregular( [ [1.0, -1.0], [1.0, 0.0], @@ -148,7 +148,7 @@ def test__pixel_centres__3x3_grid__pixel_centres(): ] ) - mesh = aa.Mesh2DRectangular.overlay_grid( + mesh = aa.Mesh2DRectangularUniform.overlay_grid( shape_native=(4, 3), grid=grid, buffer=1e-8 ) @@ -179,7 +179,7 @@ def test__interpolated_array_from(): pixel_scales=1.0, ) - grid_rectangular = aa.Mesh2DRectangular( + grid_rectangular = aa.Mesh2DRectangularUniform( values=grid, shape_native=grid.shape_native, pixel_scales=grid.pixel_scales ) diff --git a/test_autoarray/structures/mesh/test_voronoi.py b/test_autoarray/structures/mesh/test_voronoi.py index 12ef7bc30..6e3fa5887 100644 --- a/test_autoarray/structures/mesh/test_voronoi.py +++ b/test_autoarray/structures/mesh/test_voronoi.py @@ -29,7 +29,7 @@ def test__neighbors__compare_to_mesh_util(): np.asarray([grid[:, 1], grid[:, 0]]).T, qhull_options="Qbb Qc Qx Qm" ) - (neighbors_util, neighbors_sizes_util) = aa.util.mesh.voronoi_neighbors_from( + (neighbors_util, neighbors_sizes_util) = aa.util.mesh_numba.voronoi_neighbors_from( pixels=9, ridge_points=np.array(voronoi.ridge_points) ) diff --git a/test_autoarray/structures/plot/test_structure_plotters.py b/test_autoarray/structures/plot/test_structure_plotters.py index d455c86f4..53b796798 100644 --- a/test_autoarray/structures/plot/test_structure_plotters.py +++ b/test_autoarray/structures/plot/test_structure_plotters.py @@ -3,7 +3,6 @@ from os import path import pytest import numpy as np -import jax.numpy as jnp import shutil directory = path.dirname(path.realpath(__file__)) @@ -58,7 +57,6 @@ def test__array( array_plotter = aplt.Array2DPlotter( array=array_2d_7x7, - include_2d=aplt.Include2D(origin=True, mask=True, border=True), mat_plot_2d=aplt.MatPlot2D( output=aplt.Output(path=plot_path, filename="array2", format="png") ), @@ -139,7 +137,6 @@ def test__grid( mat_plot_2d=aplt.MatPlot2D( output=aplt.Output(path=plot_path, filename="grid2", format="png") ), - include_2d=aplt.Include2D(origin=True, mask=True, border=True), ) grid_2d_plotter.figure_2d(color_array=color_array) @@ -168,3 +165,20 @@ def test__grid( grid_2d_plotter.figure_2d(color_array=color_array) assert path.join(plot_path, "grid3.png") in plot_patch.paths + + +def test__array_rgb( + array_2d_rgb_7x7, + plot_path, + plot_patch, +): + array_plotter = aplt.Array2DPlotter( + array=array_2d_rgb_7x7, + mat_plot_2d=aplt.MatPlot2D( + output=aplt.Output(path=plot_path, filename="array_rgb", format="png") + ), + ) + + array_plotter.figure_2d() + + assert path.join(plot_path, "array_rgb.png") in plot_patch.paths diff --git a/test_autoarray/structures/test_visibilities.py b/test_autoarray/structures/test_visibilities.py index f029ef3ed..82c6bddf0 100644 --- a/test_autoarray/structures/test_visibilities.py +++ b/test_autoarray/structures/test_visibilities.py @@ -16,7 +16,6 @@ def test__manual__makes_visibilities_without_other_inputs(): assert type(visibilities) == vis.Visibilities assert (visibilities.slim == np.array([1.0 + 2.0j, 3.0 + 4.0j])).all() assert (visibilities.in_array == np.array([[1.0, 2.0], [3.0, 4.0]])).all() - assert (visibilities.ordered_1d == np.array([1.0, 3.0, 2.0, 4.0])).all() assert (visibilities.amplitudes == np.array([np.sqrt(5), 5.0])).all() assert visibilities.phases == pytest.approx( np.array([1.10714872, 0.92729522]), 1.0e-4 @@ -29,7 +28,6 @@ def test__manual__makes_visibilities_without_other_inputs(): assert ( visibilities.in_array == np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ).all() - assert (visibilities.ordered_1d == np.array([1.0, 3.0, 5.0, 2.0, 4.0, 6.0])).all() def test__manual__makes_visibilities_with_converted_input_as_list(): @@ -121,7 +119,3 @@ def test__visibilities_noise_has_attributes(): assert (noise_map.slim == np.array([1.0 + 2.0j, 3.0 + 4.0j])).all() assert (noise_map.amplitudes == np.array([np.sqrt(5), 5.0])).all() assert noise_map.phases == pytest.approx(np.array([1.10714872, 0.92729522]), 1.0e-4) - assert (noise_map.ordered_1d == np.array([1.0, 3.0, 2.0, 4.0])).all() - assert ( - noise_map.weight_list_ordered_1d == np.array([1.0, 1.0 / 9.0, 0.25, 0.0625]) - ).all() diff --git a/test_autoarray/structures/triangles/conftest.py b/test_autoarray/structures/triangles/conftest.py index a8d8580a3..9b943c224 100644 --- a/test_autoarray/structures/triangles/conftest.py +++ b/test_autoarray/structures/triangles/conftest.py @@ -1,5 +1,6 @@ from autoarray.numpy_wrapper import np from autoarray.structures.triangles.array import ArrayTriangles +from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles from matplotlib import pyplot as plt @@ -54,3 +55,19 @@ def triangles(): ] ), ) + + +@pytest.fixture +def one_triangle(): + return CoordinateArrayTriangles( + coordinates=np.array([[0, 0]]), + side_length=1.0, + ) + + +@pytest.fixture +def two_triangles(): + return CoordinateArrayTriangles( + coordinates=np.array([[0, 0], [1, 0]]), + side_length=1.0, + ) diff --git a/test_autoarray/structures/triangles/coordinate/__init__.py b/test_autoarray/structures/triangles/coordinate/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autoarray/structures/triangles/coordinate/conftest.py b/test_autoarray/structures/triangles/coordinate/conftest.py deleted file mode 100644 index 302b565f7..000000000 --- a/test_autoarray/structures/triangles/coordinate/conftest.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -import numpy as np - -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles - - -@pytest.fixture -def one_triangle(): - return CoordinateArrayTriangles( - coordinates=np.array([[0, 0]]), - side_length=1.0, - ) - - -@pytest.fixture -def two_triangles(): - return CoordinateArrayTriangles( - coordinates=np.array([[0, 0], [1, 0]]), - side_length=1.0, - ) diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py deleted file mode 100644 index 545d16da0..000000000 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py +++ /dev/null @@ -1,290 +0,0 @@ -import pytest - -import numpy as np - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles -from autoarray.structures.triangles.shape import Point - - -def test_two(two_triangles): - assert np.all(two_triangles.centres == np.array([[0, 0], [0.5, 0]])) - assert np.all( - two_triangles.triangles - == [ - [ - [0.0, HEIGHT_FACTOR / 2], - [0.5, -HEIGHT_FACTOR / 2], - [-0.5, -HEIGHT_FACTOR / 2], - ], - [ - [0.5, -HEIGHT_FACTOR / 2], - [0.0, HEIGHT_FACTOR / 2], - [1.0, HEIGHT_FACTOR / 2], - ], - ] - ) - - -def test_trivial_triangles(one_triangle): - assert one_triangle.flip_array == np.array([1]) - assert np.all(one_triangle.centres == np.array([[0, 0]])) - assert np.all( - one_triangle.triangles - == [ - [ - [0.0, HEIGHT_FACTOR / 2], - [0.5, -HEIGHT_FACTOR / 2], - [-0.5, -HEIGHT_FACTOR / 2], - ], - ] - ) - - -def test_above(): - triangles = CoordinateArrayTriangles( - coordinates=np.array([[0, 1]]), - side_length=1.0, - ) - assert np.all( - triangles.up_sample().triangles - == [ - [ - [0.0, 0.43301270189221935], - [-0.25, 0.8660254037844386], - [0.25, 0.8660254037844386], - ], - [ - [0.25, 0.8660254037844388], - [0.0, 1.299038105676658], - [0.5, 1.299038105676658], - ], - [ - [-0.25, 0.8660254037844388], - [-0.5, 1.299038105676658], - [0.0, 1.299038105676658], - ], - [ - [0.0, 1.299038105676658], - [0.25, 0.8660254037844388], - [-0.25, 0.8660254037844388], - ], - ] - ) - - -@pytest.fixture -def upside_down(): - return CoordinateArrayTriangles( - coordinates=np.array([[1, 0]]), - side_length=1.0, - ) - - -def test_upside_down(upside_down): - assert np.all(upside_down.centres == np.array([[0.5, 0]])) - assert np.all( - upside_down.triangles - == [ - [ - [0.5, -HEIGHT_FACTOR / 2], - [0.0, HEIGHT_FACTOR / 2], - [1.0, HEIGHT_FACTOR / 2], - ], - ] - ) - - -def test_up_sample(one_triangle): - up_sampled = one_triangle.up_sample() - assert up_sampled.side_length == 0.5 - assert np.all( - up_sampled.triangles - == [ - [[0.0, -0.4330127018922193], [-0.25, 0.0], [0.25, 0.0]], - [[0.25, 0.0], [0.5, -0.4330127018922193], [0.0, -0.4330127018922193]], - [[-0.25, 0.0], [0.0, -0.4330127018922193], [-0.5, -0.4330127018922193]], - [[0.0, 0.4330127018922193], [0.25, 0.0], [-0.25, 0.0]], - ] - ) - - -def test_up_sample_upside_down(upside_down): - up_sampled = upside_down.up_sample() - assert up_sampled.side_length == 0.5 - assert np.all( - up_sampled.triangles - == [ - [[0.5, -0.4330127018922193], [0.25, 0.0], [0.75, 0.0]], - [[0.75, 0.0], [0.5, 0.4330127018922193], [1.0, 0.4330127018922193]], - [[0.25, 0.0], [0.0, 0.4330127018922193], [0.5, 0.4330127018922193]], - [[0.5, 0.4330127018922193], [0.75, 0.0], [0.25, 0.0]], - ] - ) - - -def _test_up_sample_twice(one_triangle, plot): - plot(one_triangle) - one = one_triangle.up_sample() - two = one.up_sample() - three = two.up_sample() - plot(three, color="blue") - plot(two, color="green") - plot(one, color="red") - - -def test_neighborhood(one_triangle): - assert np.all( - one_triangle.neighborhood().triangles - == [ - [ - [-0.5, -0.4330127018922193], - [-1.0, 0.4330127018922193], - [0.0, 0.4330127018922193], - ], - [ - [0.0, -1.299038105676658], - [-0.5, -0.4330127018922193], - [0.5, -0.4330127018922193], - ], - [ - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - [-0.5, -0.4330127018922193], - ], - [ - [0.5, -0.4330127018922193], - [0.0, 0.4330127018922193], - [1.0, 0.4330127018922193], - ], - ] - ) - - -def test_upside_down_neighborhood(upside_down): - assert np.all( - upside_down.neighborhood().triangles - == [ - [ - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - [-0.5, -0.4330127018922193], - ], - [ - [0.5, -0.4330127018922193], - [0.0, 0.4330127018922193], - [1.0, 0.4330127018922193], - ], - [ - [0.5, 1.299038105676658], - [1.0, 0.4330127018922193], - [0.0, 0.4330127018922193], - ], - [ - [1.0, 0.4330127018922193], - [1.5, -0.4330127018922193], - [0.5, -0.4330127018922193], - ], - ] - ) - - -def _test_complicated(plot, one_triangle): - triangles = one_triangle.neighborhood().neighborhood() - up_sampled = triangles.up_sample() - - -def test_vertices(one_triangle): - assert np.all( - one_triangle.vertices - == [ - [-0.5, -0.4330127018922193], - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - ] - ) - - -def test_up_sampled_vertices(one_triangle): - assert np.all( - one_triangle.up_sample().vertices - == [ - [-0.5, -0.4330127018922193], - [-0.25, 0.0], - [0.0, -0.4330127018922193], - [0.0, 0.4330127018922193], - [0.25, 0.0], - [0.5, -0.4330127018922193], - ] - ) - - -def test_with_vertices(one_triangle): - triangle = one_triangle.with_vertices(np.array([[0, 0], [1, 0], [0.5, 1]])) - assert np.all(triangle.triangles == [[[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]]) - - -def _test_multiple_with_vertices(one_triangle, plot): - up_sampled = one_triangle.up_sample() - plot(up_sampled.with_vertices(2 * up_sampled.vertices).triangles.tolist()) - - -def test_for_indexes(two_triangles): - assert np.all( - two_triangles.for_indexes(np.array([0])).triangles - == [ - [ - [0.0, 0.4330127018922193], - [0.5, -0.4330127018922193], - [-0.5, -0.4330127018922193], - ] - ] - ) - - -def test_means(one_triangle): - assert np.all(one_triangle.means == [[0.0, -0.14433756729740643]]) - - -@pytest.mark.parametrize( - "x, y", - [ - (0.0, 0.0), - (-0.5, -HEIGHT_FACTOR / 2), - (0.5, -HEIGHT_FACTOR / 2), - (0.0, HEIGHT_FACTOR / 2), - ], -) -def test_containment(one_triangle, x, y): - assert one_triangle.containing_indices(Point(x, y)) == [0] - - -def test_triangles_touch(): - triangles = CoordinateArrayTriangles( - np.array([[0, 0], [2, 0]]), - ) - - assert max(triangles.triangles[0][:, 0]) == min(triangles.triangles[1][:, 0]) - - triangles = CoordinateArrayTriangles( - np.array([[0, 0], [0, 1]]), - ) - assert max(triangles.triangles[0][:, 1]) == min(triangles.triangles[1][:, 1]) - - -def test_from_grid_regression(): - triangles = CoordinateArrayTriangles.for_limits_and_scale( - x_min=-4.75, - x_max=4.75, - y_min=-4.75, - y_max=4.75, - scale=0.5, - ) - - x = triangles.vertices[:, 0] - assert min(x) <= -4.75 - assert max(x) >= 4.75 - - y = triangles.vertices[:, 1] - assert min(y) <= -4.75 - assert max(y) >= 4.75 diff --git a/test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py b/test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py deleted file mode 100644 index 1f37a1c90..000000000 --- a/test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py +++ /dev/null @@ -1,127 +0,0 @@ -from autoarray.numpy_wrapper import jit -import pytest - -from autoarray.structures.triangles.abstract import HEIGHT_FACTOR -from autoarray.structures.triangles.shape import Point - -try: - from jax import numpy as np - import jax - - jax.config.update("jax_log_compiles", True) - from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import ( - CoordinateArrayTriangles, - ) -except ImportError: - import numpy as np - from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles - - -@pytest.fixture -def one_triangle(): - return CoordinateArrayTriangles( - coordinates=np.array([[0, 0]]), - side_length=1.0, - ) - - -@jit -def full_routine(triangles): - neighborhood = triangles.neighborhood() - up_sampled = neighborhood.up_sample() - with_vertices = up_sampled.with_vertices(up_sampled.vertices) - indexes = with_vertices.containing_indices(Point(0.1, 0.1)) - return up_sampled.for_indexes(indexes) - - -# def test_full_routine(one_triangle, compare_with_nans): -# result = full_routine(one_triangle) -# -# assert compare_with_nans( -# result.triangles, -# np.array( -# [ -# [ -# [0.0, 0.4330126941204071], -# [0.25, 0.0], -# [-0.25, 0.0], -# ] -# ] -# ), -# ) - - -def test_neighborhood(one_triangle): - assert np.allclose( - np.array(jit(one_triangle.neighborhood)().triangles), - np.array( - [ - [ - [-0.5, -0.4330126941204071], - [-1.0, 0.4330126941204071], - [0.0, 0.4330126941204071], - ], - [ - [0.0, -1.299038052558899], - [-0.5, -0.4330126941204071], - [0.5, -0.4330126941204071], - ], - [ - [0.0, 0.4330126941204071], - [0.5, -0.4330126941204071], - [-0.5, -0.4330126941204071], - ], - [ - [0.5, -0.4330126941204071], - [0.0, 0.4330126941204071], - [1.0, 0.4330126941204071], - ], - ] - ), - ) - - -def test_up_sample(one_triangle): - up_sampled = jit(one_triangle.up_sample)() - assert np.allclose( - np.array(up_sampled.triangles), - np.array( - [ - [ - [[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]], - [ - [0.25, 0.0], - [0.5, -0.4330126941204071], - [0.0, -0.4330126941204071], - ], - [ - [-0.25, 0.0], - [0.0, -0.4330126941204071], - [-0.5, -0.4330126941204071], - ], - [[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]], - ] - ] - ), - ) - - -def test_means(one_triangle): - assert len(one_triangle.means) == 1 - - up_sampled = one_triangle.up_sample() - neighborhood = up_sampled.neighborhood() - assert np.count_nonzero(~np.isnan(neighborhood.means).any(axis=1)) == 10 - - -ONE_TRIANGLE_AREA = HEIGHT_FACTOR * 0.5 - - -def test_area(one_triangle): - assert one_triangle.area == ONE_TRIANGLE_AREA - assert one_triangle.up_sample().area == ONE_TRIANGLE_AREA - - neighborhood = one_triangle.neighborhood() - assert neighborhood.area == 4 * ONE_TRIANGLE_AREA - assert neighborhood.up_sample().area == 4 * ONE_TRIANGLE_AREA - assert neighborhood.neighborhood().area == 10 * ONE_TRIANGLE_AREA diff --git a/test_autoarray/structures/triangles/test_array_representation.py b/test_autoarray/structures/triangles/test_array_representation.py deleted file mode 100644 index 832c0793f..000000000 --- a/test_autoarray/structures/triangles/test_array_representation.py +++ /dev/null @@ -1,215 +0,0 @@ -import numpy as np -import pytest - -from autoarray.structures.triangles.array import ArrayTriangles -from autoarray.structures.triangles.shape import Point - - -@pytest.mark.parametrize( - "point, indices", - [ - ( - Point(0.1, 0.1), - np.array([0]), - ), - ( - Point(0.6, 0.6), - np.array([1]), - ), - ( - Point(0.5, 0.5), - np.array([0, 1]), - ), - ], -) -def test_contains_vertices( - triangles, - point, - indices, -): - containing_indices = triangles.containing_indices(point) - - assert (containing_indices == indices).all() - - -@pytest.mark.parametrize( - "indexes, vertices, indices", - [ - ( - np.array([0]), - np.array( - [ - [0.0, 0.0], - [0.0, 1.0], - [1.0, 0.0], - ] - ), - np.array( - [ - [0, 2, 1], - ] - ), - ), - ( - np.array([1]), - np.array( - [ - [0.0, 1.0], - [1.0, 0.0], - [1.0, 1.0], - ] - ), - np.array( - [ - [1, 0, 2], - ] - ), - ), - ( - np.array([0, 1]), - np.array( - [ - [0.0, 0.0], - [0.0, 1.0], - [1.0, 0.0], - [1.0, 1.0], - ] - ), - np.array( - [ - [0, 2, 1], - [2, 1, 3], - ] - ), - ), - ], -) -def test_for_indexes( - triangles, - indexes, - vertices, - indices, -): - containing = triangles.for_indexes(indexes) - - assert (containing.indices == indices).all() - assert (containing.vertices == vertices).all() - - -def test_up_sample(triangles): - up_sampled = triangles.up_sample() - - assert ( - up_sampled.vertices - == np.array( - [ - [0.0, 0.0], - [0.0, 0.5], - [0.0, 1.0], - [0.5, 0.0], - [0.5, 0.5], - [0.5, 1.0], - [1.0, 0.0], - [1.0, 0.5], - [1.0, 1.0], - ] - ) - ).all() - - assert ( - up_sampled.indices - == np.array( - [ - [6, 4, 3], - [2, 5, 4], - [2, 1, 4], - [8, 7, 5], - [3, 4, 1], - [4, 5, 7], - [0, 3, 1], - [6, 4, 7], - ] - ) - ).all() - - -@pytest.mark.parametrize( - "offset", - [-1, 0, 1], -) -def test_simple_neighborhood(offset): - triangles = ArrayTriangles( - indices=np.array( - [ - [0, 1, 2], - ] - ), - vertices=np.array( - [ - [0.0, 0.0], - [1.0, 0.0], - [0.0, 1.0], - ] - ) - + offset, - ) - assert ( - triangles.neighborhood().triangles - == ( - np.array( - [ - [[-1.0, 1.0], [0.0, 0.0], [0.0, 1.0]], - [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]], - [[0.0, 0.0], [1.0, -1.0], [1.0, 0.0]], - [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], - ] - ) - + offset - ) - ).all() - - -def test_neighborhood(triangles): - neighborhood = triangles.neighborhood() - - assert ( - neighborhood.vertices - == np.array( - [ - [-1.0, 1.0], - [0.0, 0.0], - [0.0, 1.0], - [0.0, 2.0], - [1.0, -1.0], - [1.0, 0.0], - [1.0, 1.0], - [2.0, 0.0], - ] - ) - ).all() - - assert ( - neighborhood.indices - == np.array( - [ - [0, 1, 2], - [1, 2, 5], - [1, 4, 5], - [2, 3, 6], - [2, 5, 6], - [5, 6, 7], - ] - ) - ).all() - - -def test_means(triangles): - means = triangles.means - assert means == pytest.approx( - np.array( - [ - [0.33333333, 0.33333333], - [0.66666667, 0.66666667], - ] - ) - ) diff --git a/test_autoarray/structures/triangles/test_coordinate.py b/test_autoarray/structures/triangles/test_coordinate.py new file mode 100644 index 000000000..2f37bf506 --- /dev/null +++ b/test_autoarray/structures/triangles/test_coordinate.py @@ -0,0 +1,420 @@ +from jax import numpy as np +import jax +import numpy as np + +jax.config.update("jax_log_compiles", True) +import pytest + +from autoarray.structures.triangles.abstract import HEIGHT_FACTOR +from autoarray.structures.triangles.shape import Point + +from autoarray.structures.triangles.coordinate_array import ( + CoordinateArrayTriangles, +) + + +def test__two(two_triangles): + + assert np.all(two_triangles.centres == np.array([[0, 0], [0.5, 0]])) + assert two_triangles.triangles == pytest.approx( + np.array( + [ + [ + [0.0, HEIGHT_FACTOR / 2], + [0.5, -HEIGHT_FACTOR / 2], + [-0.5, -HEIGHT_FACTOR / 2], + ], + [ + [0.5, -HEIGHT_FACTOR / 2], + [0.0, HEIGHT_FACTOR / 2], + [1.0, HEIGHT_FACTOR / 2], + ], + ] + ), + 1.0e-4, + ) + + +def test__trivial_triangles(one_triangle): + assert one_triangle.flip_array == np.array([1]) + assert np.all(one_triangle.centres == np.array([[0, 0]])) + assert one_triangle.triangles == pytest.approx( + np.array( + [ + [ + [0.0, HEIGHT_FACTOR / 2], + [0.5, -HEIGHT_FACTOR / 2], + [-0.5, -HEIGHT_FACTOR / 2], + ], + ] + ), + 1.0e-4, + ) + + +def test__above(): + triangles = CoordinateArrayTriangles( + coordinates=np.array([[0, 1]]), + side_length=1.0, + ) + assert triangles.up_sample().triangles == pytest.approx( + np.array( + [ + [ + [0.0, 0.43301270189221935], + [-0.25, 0.8660254037844386], + [0.25, 0.8660254037844386], + ], + [ + [0.25, 0.8660254037844388], + [0.0, 1.299038105676658], + [0.5, 1.299038105676658], + ], + [ + [-0.25, 0.8660254037844388], + [-0.5, 1.299038105676658], + [0.0, 1.299038105676658], + ], + [ + [0.0, 1.299038105676658], + [0.25, 0.8660254037844388], + [-0.25, 0.8660254037844388], + ], + ] + ), + 1.0e-4, + ) + + +@pytest.fixture +def upside_down(): + return CoordinateArrayTriangles( + coordinates=np.array([[1, 0]]), + side_length=1.0, + ) + + +def test_upside_down(upside_down): + assert np.all(upside_down.centres == np.array([[0.5, 0]])) + assert upside_down.triangles == pytest.approx( + np.array( + [ + [ + [0.5, -HEIGHT_FACTOR / 2], + [0.0, HEIGHT_FACTOR / 2], + [1.0, HEIGHT_FACTOR / 2], + ], + ] + ), + 1.0e-4, + ) + + +def test_up_sample(one_triangle): + up_sampled = one_triangle.up_sample() + assert up_sampled.side_length == 0.5 + assert up_sampled.triangles == pytest.approx( + np.array( + [ + [[0.0, -0.4330127018922193], [-0.25, 0.0], [0.25, 0.0]], + [[0.25, 0.0], [0.5, -0.4330127018922193], [0.0, -0.4330127018922193]], + [[-0.25, 0.0], [0.0, -0.4330127018922193], [-0.5, -0.4330127018922193]], + [[0.0, 0.4330127018922193], [0.25, 0.0], [-0.25, 0.0]], + ] + ), + 1.0e-4, + ) + + +def test_up_sample_upside_down(upside_down): + up_sampled = upside_down.up_sample() + assert up_sampled.side_length == 0.5 + assert up_sampled.triangles == pytest.approx( + np.array( + [ + [[0.5, -0.4330127018922193], [0.25, 0.0], [0.75, 0.0]], + [[0.75, 0.0], [0.5, 0.4330127018922193], [1.0, 0.4330127018922193]], + [[0.25, 0.0], [0.0, 0.4330127018922193], [0.5, 0.4330127018922193]], + [[0.5, 0.4330127018922193], [0.75, 0.0], [0.25, 0.0]], + ] + ), + 1.0e-4, + ) + + +def _test_up_sample_twice(one_triangle, plot): + plot(one_triangle) + one = one_triangle.up_sample() + two = one.up_sample() + three = two.up_sample() + plot(three, color="blue") + plot(two, color="green") + plot(one, color="red") + + +def test_neighborhood(one_triangle): + assert one_triangle.neighborhood().triangles == pytest.approx( + np.array( + [ + [ + [-0.5, -0.4330127018922193], + [-1.0, 0.4330127018922193], + [0.0, 0.4330127018922193], + ], + [ + [0.0, -1.299038105676658], + [-0.5, -0.4330127018922193], + [0.5, -0.4330127018922193], + ], + [ + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + [-0.5, -0.4330127018922193], + ], + [ + [0.5, -0.4330127018922193], + [0.0, 0.4330127018922193], + [1.0, 0.4330127018922193], + ], + ] + ), + 1.0e-4, + ) + + +def test_upside_down_neighborhood(upside_down): + assert upside_down.neighborhood().triangles == pytest.approx( + np.array( + [ + [ + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + [-0.5, -0.4330127018922193], + ], + [ + [0.5, -0.4330127018922193], + [0.0, 0.4330127018922193], + [1.0, 0.4330127018922193], + ], + [ + [0.5, 1.299038105676658], + [1.0, 0.4330127018922193], + [0.0, 0.4330127018922193], + ], + [ + [1.0, 0.4330127018922193], + [1.5, -0.4330127018922193], + [0.5, -0.4330127018922193], + ], + ] + ), + 1.0e-4, + ) + + +def _test_complicated(plot, one_triangle): + triangles = one_triangle.neighborhood().neighborhood() + up_sampled = triangles.up_sample() + + +def test_vertices(one_triangle): + assert one_triangle.vertices == pytest.approx( + np.array( + [ + [-0.5, -0.4330127018922193], + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + ] + ), + 1.0e-4, + ) + + +def test_up_sampled_vertices(one_triangle): + assert one_triangle.up_sample().vertices[0:6, :] == pytest.approx( + np.array( + [ + [-0.5, -0.4330127018922193], + [-0.25, 0.0], + [0.0, -0.4330127018922193], + [0.0, 0.4330127018922193], + [0.25, 0.0], + [0.5, -0.4330127018922193], + ] + ), + 1.0e-4, + ) + + +def test_with_vertices(one_triangle): + triangle = one_triangle.with_vertices(np.array([[0, 0], [1, 0], [0.5, 1]])) + assert triangle.triangles == pytest.approx( + np.array([[[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]]), 1.0e-4 + ) + + +def _test_multiple_with_vertices(one_triangle, plot): + up_sampled = one_triangle.up_sample() + plot(up_sampled.with_vertices(2 * up_sampled.vertices).triangles.tolist()) + + +def test_for_indexes(two_triangles): + assert two_triangles.for_indexes(np.array([0])).triangles == pytest.approx( + np.array( + [ + [ + [0.0, 0.4330127018922193], + [0.5, -0.4330127018922193], + [-0.5, -0.4330127018922193], + ] + ] + ), + 1.0e-4, + ) + + +def test_means(one_triangle): + assert one_triangle.means == pytest.approx( + np.array([[0.0, -0.14433756729740643]]), 1.0e-4 + ) + + +def test_triangles_touch(): + triangles = CoordinateArrayTriangles( + np.array([[0, 0], [2, 0]]), + ) + + assert max(triangles.triangles[0][:, 0]) == min(triangles.triangles[1][:, 0]) + + triangles = CoordinateArrayTriangles( + np.array([[0, 0], [0, 1]]), + ) + assert max(triangles.triangles[0][:, 1]) == min(triangles.triangles[1][:, 1]) + + +def test_from_grid_regression(): + triangles = CoordinateArrayTriangles.for_limits_and_scale( + x_min=-4.75, + x_max=4.75, + y_min=-4.75, + y_max=4.75, + scale=0.5, + ) + + x = triangles.vertices[:, 0] + assert min(x) <= -4.75 + assert max(x) >= 4.75 + + y = triangles.vertices[:, 1] + assert min(y) <= -4.75 + assert max(y) >= 4.75 + + +@pytest.fixture +def one_triangle(): + return CoordinateArrayTriangles( + coordinates=np.array([[0, 0]]), + side_length=1.0, + ) + + +@jax.jit +def full_routine(triangles): + neighborhood = triangles.neighborhood() + up_sampled = neighborhood.up_sample() + with_vertices = up_sampled.with_vertices(up_sampled.vertices) + indexes = with_vertices.containing_indices(Point(0.1, 0.1)) + return up_sampled.for_indexes(indexes) + + +# def test_full_routine(one_triangle, compare_with_nans): +# result = full_routine(one_triangle) +# +# assert compare_with_nans( +# result.triangles, +# np.array( +# [ +# [ +# [0.0, 0.4330126941204071], +# [0.25, 0.0], +# [-0.25, 0.0], +# ] +# ] +# ), +# ) + + +def test_neighborhood(one_triangle): + assert np.allclose( + np.array(jax.jit(one_triangle.neighborhood)().triangles), + np.array( + [ + [ + [-0.5, -0.4330126941204071], + [-1.0, 0.4330126941204071], + [0.0, 0.4330126941204071], + ], + [ + [0.0, -1.299038052558899], + [-0.5, -0.4330126941204071], + [0.5, -0.4330126941204071], + ], + [ + [0.0, 0.4330126941204071], + [0.5, -0.4330126941204071], + [-0.5, -0.4330126941204071], + ], + [ + [0.5, -0.4330126941204071], + [0.0, 0.4330126941204071], + [1.0, 0.4330126941204071], + ], + ] + ), + ) + + +def test_up_sample(one_triangle): + up_sampled = jax.jit(one_triangle.up_sample)() + assert np.allclose( + np.array(up_sampled.triangles), + np.array( + [ + [ + [[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]], + [ + [0.25, 0.0], + [0.5, -0.4330126941204071], + [0.0, -0.4330126941204071], + ], + [ + [-0.25, 0.0], + [0.0, -0.4330126941204071], + [-0.5, -0.4330126941204071], + ], + [[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]], + ] + ] + ), + ) + + +def test_means(one_triangle): + assert len(one_triangle.means) == 1 + + up_sampled = one_triangle.up_sample() + neighborhood = up_sampled.neighborhood() + assert np.count_nonzero(~np.isnan(neighborhood.means).any(axis=1)) == 10 + + +ONE_TRIANGLE_AREA = HEIGHT_FACTOR * 0.5 + + +def test_area(one_triangle): + assert one_triangle.area == ONE_TRIANGLE_AREA + assert one_triangle.up_sample().area == ONE_TRIANGLE_AREA + + neighborhood = one_triangle.neighborhood() + assert neighborhood.area == 4 * ONE_TRIANGLE_AREA + assert neighborhood.up_sample().area == 4 * ONE_TRIANGLE_AREA + assert neighborhood.neighborhood().area == 10 * ONE_TRIANGLE_AREA diff --git a/test_autoarray/structures/triangles/test_extended_source.py b/test_autoarray/structures/triangles/test_extended_source.py index 4491bd834..4ea2482af 100644 --- a/test_autoarray/structures/triangles/test_extended_source.py +++ b/test_autoarray/structures/triangles/test_extended_source.py @@ -49,7 +49,7 @@ def test_small_point(triangles, point, indices): radius=0.001, ) ) - assert containing_triangles.tolist() == indices + assert [i for i in containing_triangles.tolist() if i != -1] == indices @pytest.mark.parametrize( @@ -72,4 +72,4 @@ def test_large_circle( radius=radius, ) ) - assert containing_triangles.tolist() == indices + assert [i for i in containing_triangles.tolist() if i != -1] == indices diff --git a/test_autoarray/structures/triangles/test_jax.py b/test_autoarray/structures/triangles/test_jax.py index def239849..63e1b1293 100644 --- a/test_autoarray/structures/triangles/test_jax.py +++ b/test_autoarray/structures/triangles/test_jax.py @@ -1,19 +1,13 @@ -from autoarray.structures.triangles.shape import Point - -try: - from jax import numpy as np - import jax +from jax import numpy as np +import jax - jax.config.update("jax_log_compiles", True) - from autoarray.structures.triangles.array.jax_array import ArrayTriangles -except ImportError: - import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles +jax.config.update("jax_log_compiles", True) import pytest -pytest.importorskip("jax") +from autoarray.structures.triangles.shape import Point +from autoarray.structures.triangles.array import ArrayTriangles @pytest.fixture diff --git a/test_autoarray/structures/triangles/test_nan_triangles.py b/test_autoarray/structures/triangles/test_nan_triangles.py index 725cf5257..6dd420ad5 100644 --- a/test_autoarray/structures/triangles/test_nan_triangles.py +++ b/test_autoarray/structures/triangles/test_nan_triangles.py @@ -1,14 +1,7 @@ +from jax import numpy as np import pytest -try: - from jax import numpy as np - from autoarray.structures.triangles.array.jax_array import ArrayTriangles -except ImportError: - import numpy as np - from autoarray.structures.triangles.array import ArrayTriangles - - -pytest.importorskip("jax") +from autoarray.structures.triangles.array import ArrayTriangles @pytest.fixture diff --git a/test_autoarray/structures/triangles/coordinate/test_vertex_coordinates.py b/test_autoarray/structures/triangles/test_vertex_coordinates.py similarity index 100% rename from test_autoarray/structures/triangles/coordinate/test_vertex_coordinates.py rename to test_autoarray/structures/triangles/test_vertex_coordinates.py diff --git a/test_autoarray/test_decorators.py b/test_autoarray/test_decorators.py index 85ea8b3fe..e2652d27c 100644 --- a/test_autoarray/test_decorators.py +++ b/test_autoarray/test_decorators.py @@ -2,18 +2,9 @@ class MockClass: - def __init__(self, value, run_time_dict=None): + def __init__(self, value): self._value = value - self.run_time_dict = run_time_dict @property - @aa.profile_func def value(self): return self._value - - -def test__profile_decorator_times_decorated_function(): - cls = MockClass(value=1.0, run_time_dict={}) - cls.value - - assert "value_0" in cls.run_time_dict