diff --git a/autoarray/config/general.yaml b/autoarray/config/general.yaml index 7b9112e81..d74fc3d67 100644 --- a/autoarray/config/general.yaml +++ b/autoarray/config/general.yaml @@ -2,6 +2,8 @@ jax: use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy. fits: flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9. +psf: + use_fft_default: true # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution. inversion: check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same. use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion. diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index aa08cc3ca..9642eba2a 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -4,6 +4,7 @@ from typing import Optional, Union from autoconf import cached_property +from autoconf import instance from autoarray.dataset.abstract.dataset import AbstractDataset from autoarray.dataset.grids import GridsDataset @@ -29,7 +30,7 @@ def __init__( noise_covariance_matrix: Optional[np.ndarray] = None, over_sample_size_lp: Union[int, Array2D] = 4, over_sample_size_pixelization: Union[int, Array2D] = 4, - pad_for_psf: bool = False, + disable_fft_pad: bool = True, use_normalized_psf: Optional[bool] = True, check_noise_map: bool = True, ): @@ -76,10 +77,10 @@ def __init__( over_sample_size_pixelization How over sampling is performed for the grid which is associated with a pixelization, which is therefore passed into the calculations performed in the `inversion` module. - pad_for_psf - The PSF convolution may extend beyond the edges of the image mask, which can lead to edge effects in the - convolved image. If `True`, the image and noise-map are padded to ensure the PSF convolution does not - extend beyond the edge of the image. + disable_fft_pad + The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, which places the fewest zeros + around the image. If this is set to `True`, this optimal padding is not performed and the image is used + as-is. use_normalized_psf If `True`, the PSF kernel values are rescaled such that they sum to 1.0. This can be important for ensuring the PSF kernel does not change the overall normalization of the image when it is convolved with it. @@ -87,52 +88,48 @@ def __init__( If True, the noise-map is checked to ensure all values are above zero. """ - self.unmasked = None + self.disable_fft_pad = disable_fft_pad - self.pad_for_psf = pad_for_psf + if psf is not None: - if pad_for_psf and psf is not None: - try: - data.mask.derive_mask.blurring_from( - kernel_shape_native=psf.shape_native - ) - except exc.MaskException: - over_sample_size_lp = ( - over_sample_util.over_sample_size_convert_to_array_2d_from( - over_sample_size=over_sample_size_lp, mask=data.mask - ) - ) - over_sample_size_lp = ( - over_sample_size_lp.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) - ) + full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask) - over_sample_size_pixelization = ( - over_sample_util.over_sample_size_convert_to_array_2d_from( - over_sample_size=over_sample_size_pixelization, mask=data.mask - ) - ) - over_sample_size_pixelization = ( - over_sample_size_pixelization.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) + if psf is not None and not disable_fft_pad and data.mask.shape != fft_shape: + + # If using real-space convolution instead of FFT, enforce odd-odd shapes + if not psf.use_fft: + fft_shape = tuple(s + 1 if s % 2 == 0 else s for s in fft_shape) + + logger.info( + f"Imaging data has been trimmed or padded for FFT convolution.\n" + f" - Original shape : {data.mask.shape}\n" + f" - FFT shape : {fft_shape}\n" + f"Padding ensures accurate PSF convolution in Fourier space. " + f"Set `disable_fft_pad=True` in Imaging object to turn off automatic padding." + ) + + over_sample_size_lp = ( + over_sample_util.over_sample_size_convert_to_array_2d_from( + over_sample_size=over_sample_size_lp, mask=data.mask ) + ) + over_sample_size_lp = over_sample_size_lp.resized_from( + new_shape=fft_shape, mask_pad_value=1 + ) - data = data.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 + over_sample_size_pixelization = ( + over_sample_util.over_sample_size_convert_to_array_2d_from( + over_sample_size=over_sample_size_pixelization, mask=data.mask ) - if noise_map is not None: - noise_map = noise_map.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) - logger.info( - f"The image and noise map of the `Imaging` objected have been padded to the dimensions" - f"{data.shape}. This is because the blurring region around the mask (which defines where" - f"PSF flux may be convolved into the masked region) extended beyond the edge of the image." - f"" - f"This can be prevented by using a smaller mask, smaller PSF kernel size or manually padding" - f"the image and noise-map yourself." + ) + over_sample_size_pixelization = over_sample_size_pixelization.resized_from( + new_shape=fft_shape, mask_pad_value=1 + ) + + data = data.resized_from(new_shape=fft_shape, mask_pad_value=1) + if noise_map is not None: + noise_map = noise_map.resized_from( + new_shape=fft_shape, mask_pad_value=1 ) super().__init__( @@ -179,6 +176,9 @@ def __init__( normalize=use_normalized_psf, image_mask=image_mask, blurring_mask=blurring_mask, + mask_shape=mask_shape, + full_shape=full_shape, + fft_shape=fft_shape, ) self.psf = psf @@ -337,31 +337,34 @@ def from_fits( over_sample_size_pixelization=over_sample_size_pixelization, ) - def apply_mask(self, mask: Mask2D) -> "Imaging": + def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": """ Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other quantities one-by-one. - The original unmasked imaging data is stored as the `self.unmasked` attribute. This is used to ensure that if - the `apply_mask` function is called multiple times, every mask is always applied to the original unmasked - imaging dataset. + The `apply_mask` function cannot be called multiple times, if it is a mask may remove data, therefore + an exception is raised. If you wish to apply a new mask, reload the dataset from .fits files. Parameters ---------- mask The 2D mask that is applied to the image. """ - if self.data.mask.is_all_false: - unmasked_dataset = self - else: - unmasked_dataset = self.unmasked + invalid = np.logical_and(self.data.mask, np.logical_not(mask)) + + if np.any(invalid): + raise exc.DatasetException( + "The new mask overlaps with pixels that are already unmasked in the dataset. " + "You cannot apply a new mask on top of an existing one. " + "If you wish to apply a different mask, please reload the dataset from .fits files." + ) - data = Array2D(values=unmasked_dataset.data.native, mask=mask) + data = Array2D(values=self.data.native, mask=mask) - noise_map = Array2D(values=unmasked_dataset.noise_map.native, mask=mask) + noise_map = Array2D(values=self.noise_map.native, mask=mask) - if unmasked_dataset.noise_covariance_matrix is not None: - noise_covariance_matrix = unmasked_dataset.noise_covariance_matrix + if self.noise_covariance_matrix is not None: + noise_covariance_matrix = self.noise_covariance_matrix noise_covariance_matrix = np.delete( noise_covariance_matrix, mask.derive_indexes.masked_slim, 0 @@ -385,11 +388,9 @@ def apply_mask(self, mask: Mask2D) -> "Imaging": noise_covariance_matrix=noise_covariance_matrix, over_sample_size_lp=over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization, - pad_for_psf=True, + disable_fft_pad=disable_fft_pad, ) - dataset.unmasked = unmasked_dataset - logger.info( f"IMAGING - Data masked, contains a total of {mask.pixels_in_mask} image-pixels" ) @@ -400,6 +401,7 @@ def apply_noise_scaling( self, mask: Mask2D, noise_value: float = 1e8, + disable_fft_pad: bool = False, signal_to_noise_value: Optional[float] = None, should_zero_data: bool = True, ) -> "Imaging": @@ -455,18 +457,6 @@ def apply_noise_scaling( else: data = self.data.native.array - data_unmasked = Array2D.no_mask( - values=data, - shape_native=self.data.shape_native, - pixel_scales=self.data.pixel_scales, - ) - - noise_map_unmasked = Array2D.no_mask( - values=noise_map, - shape_native=self.noise_map.shape_native, - pixel_scales=self.noise_map.pixel_scales, - ) - data = Array2D(values=data, mask=self.data.mask) noise_map = Array2D(values=noise_map, mask=self.data.mask) @@ -478,15 +468,10 @@ def apply_noise_scaling( noise_covariance_matrix=self.noise_covariance_matrix, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, - pad_for_psf=False, + disable_fft_pad=disable_fft_pad, check_noise_map=False, ) - if self.unmasked is not None: - dataset.unmasked = self.unmasked - dataset.unmasked.data = data_unmasked - dataset.unmasked.noise_map = noise_map_unmasked - logger.info( f"IMAGING - Data noise scaling applied, a total of {mask.pixels_in_mask} pixels were scaled to large noise values." ) @@ -497,6 +482,7 @@ def apply_over_sampling( self, over_sample_size_lp: Union[int, Array2D] = None, over_sample_size_pixelization: Union[int, Array2D] = None, + disable_fft_pad: bool = False, ) -> "AbstractDataset": """ Apply new over sampling objects to the grid and grid pixelization of the dataset. @@ -526,7 +512,7 @@ def apply_over_sampling( over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization or self.over_sample_size_pixelization, - pad_for_psf=False, + disable_fft_pad=disable_fft_pad, check_noise_map=False, ) diff --git a/autoarray/dataset/imaging/simulator.py b/autoarray/dataset/imaging/simulator.py index 576dc6017..081498701 100644 --- a/autoarray/dataset/imaging/simulator.py +++ b/autoarray/dataset/imaging/simulator.py @@ -126,7 +126,7 @@ def via_image_from( pixel_scales=image.pixel_scales, ) - image = self.psf.convolved_array_from(array=image) + image = self.psf.convolved_image_from(image=image, blurring_image=None) image = image + background_sky_map @@ -169,12 +169,16 @@ def via_image_from( image = Array2D(values=image, mask=mask) dataset = Imaging( - data=image, psf=self.psf, noise_map=noise_map, check_noise_map=False + data=image, + psf=self.psf, + noise_map=noise_map, + check_noise_map=False, + disable_fft_pad=True, ) if over_sample_size is not None: dataset = dataset.apply_over_sampling( - over_sample_size_lp=over_sample_size.native + over_sample_size_lp=over_sample_size.native, disable_fft_pad=True ) return dataset diff --git a/autoarray/geometry/geometry_1d.py b/autoarray/geometry/geometry_1d.py index a6380ca75..bfaad808b 100644 --- a/autoarray/geometry/geometry_1d.py +++ b/autoarray/geometry/geometry_1d.py @@ -1,11 +1,6 @@ from __future__ import annotations import logging -import numpy as np -from typing import TYPE_CHECKING, List, Tuple, Union - -if TYPE_CHECKING: - from autoarray.structures.grids.uniform_1d import Grid1D - from autoarray.mask.mask_2d import Mask2D +from typing import Tuple from autoarray import type as ty diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 9167af6f9..1d94c464e 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -92,7 +92,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: return [ ( - self.psf.convolve_mapping_matrix( + self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_obj.mapping_matrix, mask=self.mask ) if linear_obj.operated_mapping_matrix_override is None @@ -134,7 +134,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: if linear_func.operated_mapping_matrix_override is not None: operated_mapping_matrix = linear_func.operated_mapping_matrix_override else: - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_func.mapping_matrix, mask=self.mask, ) @@ -215,7 +215,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: mapper_operated_mapping_matrix_dict = {} for mapper in self.cls_list_from(cls=AbstractMapper): - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=mapper.mapping_matrix, mask=self.mask, ) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 698750a22..169d45f9d 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -73,7 +73,7 @@ def _data_vector_mapper(self) -> np.ndarray: mapper = mapper_list[i] param_range = mapper_param_range_list[i] - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=mapper.mapping_matrix, mask=self.mask ) @@ -132,7 +132,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_i = mapper_list[i] mapper_param_range_i = mapper_param_range_list[i] - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=mapper_i.mapping_matrix, mask=self.mask ) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 58a3ccc63..ed87179e5 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -518,8 +518,13 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: reconstruction=np.array(reconstruction), ) - mapped_reconstructed_image = self.psf.convolve_image_no_blurring( - image=mapped_reconstructed_image, mask=self.mask + mapped_reconstructed_image = Array2D( + values=mapped_reconstructed_image, mask=self.mask + ) + + mapped_reconstructed_image = self.psf.convolved_image_from( + image=mapped_reconstructed_image, + blurring_image=None, ).array mapped_reconstructed_image = Array2D( diff --git a/autoarray/mask/derive/mask_2d.py b/autoarray/mask/derive/mask_2d.py index 19e005362..aa28068e1 100644 --- a/autoarray/mask/derive/mask_2d.py +++ b/autoarray/mask/derive/mask_2d.py @@ -1,6 +1,5 @@ from __future__ import annotations import logging -import copy import numpy as np from typing import TYPE_CHECKING, Tuple @@ -10,7 +9,6 @@ from autoarray import exc from autoarray.mask.derive.indexes_2d import DeriveIndexes2D -from autoarray.structures.arrays import array_2d_util from autoarray.mask import mask_2d_util logging.basicConfig() diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 0f4fe30f9..bf6fc64c8 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -653,7 +653,9 @@ def unmasked_blurred_array_from(self, padded_array, psf, image_shape) -> Array2D The 1D unmasked image which is blurred. """ - blurred_image = psf.convolved_array_from(array=padded_array) + blurred_image = psf.convolved_image_from( + image=padded_array, blurring_image=None + ) return self.trimmed_array_from( padded_array=blurred_image, image_shape=image_shape diff --git a/autoarray/operators/mock/mock_psf.py b/autoarray/operators/mock/mock_psf.py index e89d2b732..44fdf847f 100644 --- a/autoarray/operators/mock/mock_psf.py +++ b/autoarray/operators/mock/mock_psf.py @@ -2,5 +2,5 @@ class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix - def convolve_mapping_matrix(self, mapping_matrix, mask): + def convolved_mapping_matrix_from(self, mapping_matrix, mask): return self.operated_mapping_matrix diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 2c9698398..0fb5450a2 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -1,9 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from autoarray import Mask2D + import jax import jax.numpy as jnp import numpy as np from pathlib import Path +import scipy from typing import List, Optional, Tuple, Union +import warnings +from autoconf import conf from autoconf.fitsable import header_obj_from from autoarray.structures.arrays.uniform_2d import AbstractArray2D @@ -11,9 +21,7 @@ from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.header import Header -from autoarray import exc from autoarray import type as ty -from autoarray.structures.arrays import array_2d_util class Kernel2D(AbstractArray2D): @@ -26,26 +34,83 @@ def __init__( store_native: bool = False, image_mask=None, blurring_mask=None, + mask_shape=None, + full_shape=None, + fft_shape=None, + use_fft: Optional[bool] = None, *args, **kwargs, ): """ - An array of values, which are paired to a uniform 2D mask of pixels. Each entry - on the array corresponds to a value at the centre of a pixel in an unmasked pixel. See the ``Array2D`` class - for a full description of how Arrays work. + A 2D convolution kernel stored as an array of values paired to a uniform 2D mask. + + The ``Kernel2D`` is a subclass of ``Array2D`` with additional methods for performing + point spread function (PSF) convolution of images or mapping matrices. Each entry of + the kernel corresponds to a PSF value at the centre of a pixel in the unmasked grid. - The ``Kernel2D`` class is an ``Array2D`` but with additioonal methods that allow it to be convolved with data. + Two convolution modes are supported: + + - **Real-space convolution**: performed directly via sliding-window summation or + ``jax.scipy.signal.convolve``. This is exact but can be slow for large kernels. + - **FFT convolution**: performed by transforming both the kernel and the input image + into Fourier space, multiplying, and transforming back. This is typically faster + for kernels larger than ~5×5, but requires careful zero-padding. + + When using FFT convolution, the input image and mask are automatically padded such + that the FFT avoids circular wrap-around artefacts. This padding is computed from the + kernel size via :meth:`fft_shape_from`. The padded shape is stored in ``fft_shape``. + If FFT convolution is attempted without precomputing and applying this padding, + an exception is raised to avoid silent shape mismatches. Parameters ---------- values - The values of the array. + The raw 2D kernel values. Can be normalised to sum to unity if ``normalize=True``. mask - The 2D mask associated with the array, defining the pixels each array value is paired with and - originates from. + The 2D mask associated with the kernel, defining the pixels each kernel value is + paired with. + header + Optional metadata (e.g. FITS header) associated with the kernel. normalize - If True, the Kernel2D's array values are normalized such that they sum to 1.0. + If True, the kernel values are rescaled such that they sum to 1.0. + store_native + If True, the kernel is stored in its full native 2D form as an attribute + ``stored_native`` for re-use (e.g. when convolving repeatedly). + image_mask + Optional mask defining the unmasked image pixels when performing convolution. + If not provided, defaults to the supplied ``mask``. + blurring_mask + Optional mask defining the "blurring region": pixels outside the image mask + into which PSF flux can spread. Used to construct blurring images and + blurring mapping matrices. + mask_shape + The shape of the (unpadded) mask region. Used when cropping back results after + FFT convolution. + full_shape + The unpadded image + kernel shape (``image_shape + kernel_shape - 1``). + fft_shape + The padded shape used in FFT convolution, typically computed via + ``scipy.fft.next_fast_len`` for efficiency. Must be precomputed before calling + FFT convolution methods. + use_fft + If True, convolution is performed in Fourier space with zero-padding. + If False, convolution is performed in real space. + If None, a default choice is made: real space for small kernels, + FFT for large kernels. + *args, **kwargs + Passed to the ``Array2D`` constructor. + + Notes + ----- + - FFT padding can be disabled globally with ``disable_fft_pad=True`` when + constructing ``Imaging`` objects, in which case convolution will either + use real space or proceed without padding. + - Blurring masks ensure that PSF flux spilling outside the main image mask + is included correctly. Omitting them may lead to underestimated PSF wings. + - For unit tests with tiny kernels, FFT and real-space convolution may differ + slightly due to edge and truncation effects. """ + super().__init__( values=values, mask=mask, @@ -77,6 +142,28 @@ def __init__( slim_to_native_blurring[:, 1], ) + self.fft_shape = fft_shape + + self.mask_shape = None + self.full_shape = None + self.fft_psf = None + + if self.fft_shape is not None: + + self.mask_shape = mask_shape + self.full_shape = full_shape + self.fft_psf = jnp.fft.rfft2(self.native.array, s=self.fft_shape) + self.fft_psf_mapping = jnp.expand_dims(self.fft_psf, 2) + + self._use_fft = use_fft + + @property + def use_fft(self): + if self._use_fft is None: + return conf.instance["general"]["psf"]["use_fft_default"] + + return self._use_fft + @classmethod def no_mask( cls, @@ -88,6 +175,10 @@ def no_mask( normalize: bool = False, image_mask=None, blurring_mask=None, + mask_shape=None, + full_shape=None, + fft_shape=None, + use_fft: Optional[bool] = None, ): """ Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically @@ -122,6 +213,10 @@ def no_mask( normalize=normalize, image_mask=image_mask, blurring_mask=blurring_mask, + mask_shape=mask_shape, + full_shape=full_shape, + fft_shape=fft_shape, + use_fft=use_fft, ) @classmethod @@ -391,27 +486,407 @@ def from_fits( header=Header(header_sci_obj=header_sci_obj, header_hdu_obj=header_hdu_obj), ) + def fft_shape_from( + self, mask: np.ndarray + ) -> Union[Tuple[int, int], Tuple[int, int], Tuple[int, int]]: + """ + Compute the padded shapes required for FFT-based convolution with this kernel. + + FFT convolution requires the input image and kernel to be zero-padded so that + the convolution is equivalent to linear convolution (not circular) and to avoid + wrap-around artefacts. This method inspects the mask and the kernel shape to + determine three key shapes: + + - ``mask_shape``: the rectangular bounding-box region of the mask that encloses + all unmasked (False) pixels, padded by half the kernel size in each direction. + This is the minimal region that must be retained for convolution. + - ``full_shape``: the "linear convolution shape", equal to + ``mask_shape + kernel_shape - 1``. This is the minimal padded size required + for an exact linear convolution. + - ``fft_shape``: the FFT-efficient padded shape, obtained by rounding each + dimension of ``full_shape`` up to the next fast length for real FFTs + (via ``scipy.fft.next_fast_len``). Using this ensures efficient FFT execution. + + Parameters + ---------- + mask + A 2D mask where False indicates unmasked pixels (valid data) and True + indicates masked pixels. The bounding-box of the False region is used + to compute the convolution region. + + Returns + ------- + full_shape + The unpadded linear convolution shape (mask region + kernel − 1). + fft_shape + The FFT-friendly padded shape for efficient convolution. + mask_shape + The rectangular mask region size including kernel padding. + """ + + ys, xs = np.where(~mask) + y_min, y_max = ys.min(), ys.max() + x_min, x_max = xs.min(), xs.max() + + (pad_y, pad_x) = self.shape_native + + mask_shape = ( + (y_max + pad_y // 2) - (y_min - pad_y // 2), + (x_max + pad_x // 2) - (x_min - pad_x // 2), + ) + + full_shape = tuple(s1 + s2 - 1 for s1, s2 in zip(mask_shape, self.shape_native)) + fft_shape = tuple(scipy.fft.next_fast_len(s, real=True) for s in full_shape) + + return full_shape, fft_shape, mask_shape + + @property + def normalized(self) -> "Kernel2D": + """ + Normalize the Kernel2D such that its data_vector values sum to unity. + """ + return Kernel2D(values=self, mask=self.mask, normalize=True) + + def mapping_matrix_native_from( + self, + mapping_matrix: jnp.ndarray, + mask: "Mask2D", + blurring_mapping_matrix: Optional[jnp.ndarray] = None, + blurring_mask: Optional["Mask2D"] = None, + ) -> jnp.ndarray: + """ + Expand a slim mapping matrix (image-plane) and optional blurring mapping matrix + into a full native 3D cube (ny, nx, n_src). + + This is primarily used for real-space convolution, where the pixel-to-source + mapping must be represented on the full image grid. + + Parameters + ---------- + mapping_matrix : ndarray (N_pix, N_src) + Slim mapping matrix for unmasked image pixels, mapping each image pixel + to source-plane pixels. + mask : Mask2D + Mask defining which image pixels are unmasked. Used to expand the slim + mapping matrix into a native grid. + blurring_mapping_matrix : ndarray (N_blur, N_src), optional + Mapping matrix for blurring pixels outside the main mask (e.g. light + spilling in from outside). If provided, it is also scattered into the + native cube. + blurring_mask : Mask2D, optional + Mask defining the blurring region pixels. Must be provided if + `blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple` + is not already cached. + + Returns + ------- + ndarray (ny, nx, N_src) + Native 3D mapping matrix cube with dimensions (image_y, image_x, sources). + Contains contributions from both the main mapping matrix and, if provided, + the blurring mapping matrix. + """ + slim_to_native_tuple = self.slim_to_native_tuple + if slim_to_native_tuple is None: + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=mapping_matrix.shape[0] + ) + + n_src = mapping_matrix.shape[1] + + # Allocate full native grid (ny, nx, n_src) + mapping_matrix_native = jnp.zeros( + mask.shape + (n_src,), dtype=mapping_matrix.dtype + ) + + # Scatter main mapping matrix into native cube + mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set( + mapping_matrix + ) + + # Optionally scatter blurring mapping matrix + if blurring_mapping_matrix is not None: + slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple + + if slim_to_native_blurring_tuple is None: + if blurring_mask is None: + raise ValueError( + "blurring_mask must be provided if blurring_mapping_matrix is given " + "and slim_to_native_blurring_tuple is None." + ) + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_mask.array), + size=blurring_mapping_matrix.shape[0], + ) + + mapping_matrix_native = mapping_matrix_native.at[ + slim_to_native_blurring_tuple + ].set(blurring_mapping_matrix) + + return mapping_matrix_native + + def convolved_image_from(self, image, blurring_image, jax_method="direct"): + """ + Convolve an input masked image with this PSF. + + This method chooses between an FFT-based convolution (default if + ``self.use_fft=True``) or a direct real-space convolution, depending on + how the Kernel2D was configured. + + In the FFT branch: + - The input image (and optional blurring image) are resized / padded to + match the FFT-friendly padded shape (``fft_shape``) associated with this kernel. + - The PSF and image are transformed to Fourier space via ``jax.numpy.fft.rfft2``. + - Convolution is performed as elementwise multiplication. + - The result is inverse-transformed and cropped back to the masked region. + + Padding ensures that the FFT implements *linear* convolution, not circular, + and avoids wrap-around artefacts. The required padding is determined by + ``fft_shape_from(mask)``. If no precomputed shapes exist, they are computed + on the fly. For reproducible behaviour, precompute and set + ``fft_shape``, ``full_shape``, and ``mask_shape`` on the kernel. + + If ``use_fft=False``, convolution falls back to + :meth:`Kernel2D.convolved_image_via_real_space_from`. + + Parameters + ---------- + image + Masked 2D image array to convolve. + blurring_image + Masked image containing flux from outside the mask core that blurs + into the masked region after convolution. If ``None``, only the direct + image is convolved, which may be numerically incorrect if the mask + excludes PSF wings. + jax_method : {"direct", "fft"} + Backend passed to ``jax.scipy.signal.convolve`` when in real-space mode. + Ignored for FFT convolutions. + + Returns + ------- + Array2D + The convolved image in slim (1D masked) format. + """ + + if not self.use_fft: + return self.convolved_image_via_real_space_from( + image=image, blurring_image=blurring_image, jax_method=jax_method + ) + + if self.fft_shape is None: + + full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=image.mask) + fft_psf = jnp.fft.rfft2(self.stored_native.array, s=fft_shape) + + image = image.resized_from(new_shape=fft_shape, mask_pad_value=1) + if blurring_image is not None: + blurring_image = blurring_image.resized_from( + new_shape=fft_shape, mask_pad_value=1 + ) + + else: + + fft_shape = self.fft_shape + full_shape = self.full_shape + mask_shape = self.mask_shape + fft_psf = self.fft_psf + + slim_to_native_tuple = self.slim_to_native_tuple + slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple + + if slim_to_native_tuple is None: + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(image.mask.array), size=image.shape[0] + ) + + # start with native image padded with zeros + image_both_native = jnp.zeros(image.mask.shape, dtype=image.dtype) + image_both_native = image_both_native.at[slim_to_native_tuple].set( + jnp.asarray(image.array) + ) + + # add blurring contribution if provided + if blurring_image is not None: + if slim_to_native_blurring_tuple is None: + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_image.mask.array), + size=blurring_image.shape[0], + ) + image_both_native = image_both_native.at[slim_to_native_blurring_tuple].set( + jnp.asarray(blurring_image.array) + ) + else: + warnings.warn( + "No blurring_image provided. Only the direct image will be convolved. " + "This may change the correctness of the PSF convolution." + ) + + # FFT the combined image + fft_image_native = jnp.fft.rfft2(image_both_native, s=fft_shape, axes=(0, 1)) + + # Multiply by PSF in Fourier space and invert + blurred_image_full = jnp.fft.irfft2( + fft_psf * fft_image_native, s=fft_shape, axes=(0, 1) + ) + + # Crop back to mask_shape + start_indices = tuple( + (full_size - out_size) // 2 + for full_size, out_size in zip(full_shape, mask_shape) + ) + out_shape_full = mask_shape + blurred_image_native = jax.lax.dynamic_slice( + blurred_image_full, start_indices, out_shape_full + ) + + return Array2D( + values=blurred_image_native[slim_to_native_tuple], mask=image.mask + ) + + def convolved_mapping_matrix_from( + self, + mapping_matrix, + mask, + blurring_mapping_matrix=None, + blurring_mask: Optional[Mask2D] = None, + jax_method="direct", + ): + """ + Convolve a source-plane mapping matrix with this PSF. + + A mapping matrix maps image-plane unmasked pixels to source-plane pixels. + This method performs the equivalent operation of PSF convolution on the + mapping matrix, so that model visibilities / images can be computed via + matrix multiplication instead of explicit convolution. + + If ``use_fft=True``, convolution is performed in Fourier space: + - The mapping matrix is scattered into a 3D native cube + (ny, nx, n_src). + - An FFT of this cube is multiplied by the precomputed FFT of the PSF. + - The inverse FFT is taken and cropped to the mask region. + - The slim (masked 1D) representation is returned. + + If ``use_fft=False``, convolution falls back to + :meth:`Kernel2D.convolved_mapping_matrix_via_real_space_from`. + + Notes + ----- + - FFT convolution requires that ``self.fft_shape`` and related padding + attributes are precomputed. If not, a ``ValueError`` is raised with the + expected vs actual shapes. This ensures the mapping matrix is padded + consistently with the PSF. + - The optional ``blurring_mapping_matrix`` plays the same role as + ``blurring_image`` in :meth:`convolved_image_from`, accounting for PSF flux + that falls into the masked region from outside. + + Parameters + ---------- + mapping_matrix : ndarray of shape (N_pix, N_src) + Slim mapping matrix from unmasked pixels to source pixels. + mask : Mask2D + Associated mask defining the image grid. + blurring_mapping_matrix : ndarray of shape (N_blur, N_src), optional + Mapping matrix for the blurring region, outside the mask core. + jax_method : str + Backend passed to real-space convolution if ``use_fft=False``. + + Returns + ------- + ndarray of shape (N_pix, N_src) + Convolved mapping matrix in slim form. + """ + if not self.use_fft: + return self.convolved_mapping_matrix_via_real_space_from( + mapping_matrix=mapping_matrix, + mask=mask, + blurring_mapping_matrix=blurring_mapping_matrix, + blurring_mask=blurring_mask, + jax_method=jax_method, + ) + + if self.fft_shape is None: + + full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask) + + raise ValueError( + f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n" + f"Expected mapping matrix padded to match FFT shape of PSF.\n" + f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, " + f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}." + ) + + else: + + fft_shape = self.fft_shape + full_shape = self.full_shape + mask_shape = self.mask_shape + fft_psf_mapping = self.fft_psf_mapping + + slim_to_native_tuple = self.slim_to_native_tuple + + if slim_to_native_tuple is None: + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=mapping_matrix.shape[0] + ) + + mapping_matrix_native = self.mapping_matrix_native_from( + mapping_matrix=mapping_matrix, + mask=mask, + blurring_mapping_matrix=blurring_mapping_matrix, + blurring_mask=blurring_mask, + ) + + # FFT convolution + fft_mapping_matrix_native = jnp.fft.rfft2( + mapping_matrix_native, s=fft_shape, axes=(0, 1) + ) + blurred_mapping_matrix_full = jnp.fft.irfft2( + fft_psf_mapping * fft_mapping_matrix_native, + s=fft_shape, + axes=(0, 1), + ) + + # crop back + start_indices = tuple( + (full_size - out_size) // 2 + for full_size, out_size in zip(full_shape, mask_shape) + ) + (0,) + out_shape_full = mask_shape + (blurred_mapping_matrix_full.shape[2],) + blurred_mapping_matrix_native = jax.lax.dynamic_slice( + blurred_mapping_matrix_full, start_indices, out_shape_full + ) + + # return slim form + return blurred_mapping_matrix_native[slim_to_native_tuple] + def rescaled_with_odd_dimensions_from( self, rescale_factor: float, normalize: bool = False ) -> "Kernel2D": """ - If the PSF kernel has one or two even-sized dimensions, return a PSF object where the kernel has odd-sized - dimensions (odd-sized dimensions are required for 2D convolution). + Return a version of this kernel rescaled so both dimensions are odd-sized. - The PSF can be scaled to larger / smaller sizes than the input size, if the rescale factor uses values that - deviate furher from 1.0. + Odd-sized kernels are often required for real space convolution operations + (e.g. centered PSFs in imaging pipelines). If the kernel has one or two + even-sized dimensions, they are rescaled (via interpolation) and padded + so that both dimensions are odd. - Kernels are rescald using the scikit-image routine rescale, which performs rescaling via an interpolation - routine. This may lead to loss of accuracy in the PSF kernel and it is advised that users, where possible, - create their PSF on an odd-sized array using their data reduction pipelines that remove this approximation. + The kernel can also be scaled larger or smaller by changing + ``rescale_factor``. Rescaling uses ``skimage.transform.rescale`` / + ``resize``, which interpolate pixel values and may introduce small + inaccuracies compared to native instrument PSFs. Where possible, users + should generate odd-sized PSFs directly from data reduction. Parameters ---------- rescale_factor - The factor by which the kernel is rescaled. If this has a value of 1.0, the kernel is rescaled to the - closest odd-sized dimensions (e.g. 20 -> 19). Higher / lower values scale to higher / lower dimensions. + Factor by which the kernel is rescaled. If 1.0, only adjusts size to + nearest odd dimensions. Values > 1 enlarge, < 1 shrink the kernel. normalize - Whether the PSF should be normalized after being rescaled. + If True, the returned kernel is normalized to sum to 1.0. + + Returns + ------- + Kernel2D + Rescaled kernel with odd-sized dimensions. """ from skimage.transform import resize, rescale @@ -477,237 +952,134 @@ def rescaled_with_odd_dimensions_from( values=kernel_rescaled, pixel_scales=pixel_scales, normalize=normalize ) - @property - def normalized(self) -> "Kernel2D": - """ - Normalize the Kernel2D such that its data_vector values sum to unity. - """ - return Kernel2D(values=self, mask=self.mask, normalize=True) - - def convolved_array_from(self, array: Array2D) -> Array2D: - """ - Convolve an array with this Kernel2D - - Parameters - ---------- - image - An array representing the image the Kernel2D is convolved with. - - Returns - ------- - convolved_image - An array representing the image after convolution. - - Raises - ------ - KernelException if either Kernel2D psf dimension is odd + def convolved_image_via_real_space_from( + self, + image: np.ndarray, + blurring_image: Optional[np.ndarray] = None, + jax_method: str = "direct", + ): """ - import scipy.signal - - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - - array_2d = array.native - - convolved_array_2d = scipy.signal.convolve2d( - array_2d.array, self.native.array, mode="same" - ) - - convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=array_2d.mask, - array_2d_native=convolved_array_2d, - ) + Convolve an input masked image with this PSF in real space. - return Array2D(values=convolved_array_1d, mask=array_2d.mask) + This is the direct method (non-FFT) where convolution is explicitly + performed using ``jax.scipy.signal.convolve`` with the kernel in native + space. - def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D: - """ - Convolve an array with this Kernel2D + Unlike FFT convolution, this does not require padding shapes, but it is + typically much slower for large kernels (> ~5x5). Parameters ---------- image - An array representing the image the Kernel2D is convolved with. + Masked image array to convolve. + blurring_image + Blurring contribution from outside the mask core. If None, only the + direct image is convolved (which may be numerically incorrect). + jax_method + Method flag for JAX convolution backend (default "direct"). Returns ------- - convolved_image - An array representing the image after convolution. - - Raises - ------ - KernelException if either Kernel2D psf dimension is odd - """ - import scipy.signal - - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - - convolved_array_2d = scipy.signal.convolve2d( - array.array, self.native.array, mode="same" - ) - - convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=mask, - array_2d_native=convolved_array_2d, - ) - - return Array2D(values=convolved_array_1d, mask=mask) - - def convolve_image(self, image, blurring_image, jax_method="direct"): - """ - For a given 1D array and blurring array, convolve the two using this psf. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. + Array2D + Convolved image in slim format. """ slim_to_native_tuple = self.slim_to_native_tuple slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple if slim_to_native_tuple is None: - slim_to_native_tuple = jnp.nonzero( jnp.logical_not(image.mask.array), size=image.shape[0] ) - if slim_to_native_blurring_tuple is None: - - slim_to_native_blurring_tuple = jnp.nonzero( - jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] - ) - - # make sure dtype matches what you want - expanded_array_native = jnp.zeros( - image.mask.shape, dtype=jnp.asarray(image.array).dtype - ) + # start with native array padded with zeros + image_native = jnp.zeros(image.mask.shape, dtype=jnp.asarray(image.array).dtype) - # set using a tuple of index arrays - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + # set image pixels + image_native = image_native.at[slim_to_native_tuple].set( jnp.asarray(image.array) ) - expanded_array_native = expanded_array_native.at[ - slim_to_native_blurring_tuple - ].set(jnp.asarray(blurring_image.array)) - - kernel = self.stored_native.array - - convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method - ) - - convolved_array_1d = convolve_native[slim_to_native_tuple] - - return Array2D(values=convolved_array_1d, mask=image.mask) - - def convolve_image_no_blurring(self, image, mask, jax_method="direct"): - """ - For a given 1D array and blurring array, convolve the two using this psf. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. - """ - - slim_to_native_tuple = self.slim_to_native_tuple - - if slim_to_native_tuple is None: - - slim_to_native_tuple = jnp.nonzero( - jnp.logical_not(mask.array), size=image.shape[0] - ) - # make sure dtype matches what you want - expanded_array_native = jnp.zeros(mask.shape) - - # set using a tuple of index arrays - if isinstance(image, np.ndarray) or isinstance(image, jnp.ndarray): - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( - image + # add blurring contribution if provided + if blurring_image is not None: + if slim_to_native_blurring_tuple is None: + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_image.mask.array), + size=blurring_image.shape[0], + ) + image_native = image_native.at[slim_to_native_blurring_tuple].set( + jnp.asarray(blurring_image.array) ) else: - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( - jnp.asarray(image.array) + warnings.warn( + "No blurring_image provided. Only the direct image will be convolved. " + "This may change the correctness of the PSF convolution." ) + # perform real-space convolution kernel = self.stored_native.array - convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method + image_native, kernel, mode="same", method=jax_method ) convolved_array_1d = convolve_native[slim_to_native_tuple] - return Array2D(values=convolved_array_1d, mask=mask) + return Array2D(values=convolved_array_1d, mask=image.mask) - def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct"): + def convolved_mapping_matrix_via_real_space_from( + self, + mapping_matrix: np.ndarray, + mask, + blurring_mapping_matrix: Optional[np.ndarray] = None, + blurring_mask: Optional[Mask2D] = None, + jax_method: str = "direct", + ): """ - For a given 1D array and blurring array, convolve the two using this psf. + Convolve a source-plane mapping matrix with this PSF in real space. + + Equivalent to :meth:`convolved_mapping_matrix_from`, but using explicit + real-space convolution rather than FFTs. This avoids FFT padding issues + but is slower for large kernels. + + The mapping matrix is expanded into a native cube (ny, nx, n_src), + convolved with the kernel (broadcast along the source axis), + and reduced back to slim form. Parameters ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. + mapping_matrix + Slim mapping matrix from unmasked pixels to source pixels. + mask + Mask defining the pixelization grid. + blurring_mapping_matrix : ndarray (N_blur, N_src), optional + Mapping matrix for blurring region pixels outside the mask core. jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. + Backend passed to JAX convolution. + + Returns + ------- + ndarray (N_pix, N_src) + Convolved mapping matrix in slim form. """ slim_to_native_tuple = self.slim_to_native_tuple if slim_to_native_tuple is None: - slim_to_native_tuple = jnp.nonzero( - jnp.logical_not(mask.array), size=image.shape[0] + jnp.logical_not(mask.array), size=mapping_matrix.shape[0] ) - # make sure dtype matches what you want - expanded_array_native = jnp.zeros(mask.shape) - - # set using a tuple of index arrays - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( - image + mapping_matrix_native = self.mapping_matrix_native_from( + mapping_matrix=mapping_matrix, + mask=mask, + blurring_mapping_matrix=blurring_mapping_matrix, + blurring_mask=blurring_mask, ) - + # 6) Real-space convolution, broadcast kernel over source axis kernel = self.stored_native.array - - convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method + blurred_mapping_matrix_native = jax.scipy.signal.convolve( + mapping_matrix_native, kernel[..., None], mode="same", method=jax_method ) - convolved_array_1d = convolve_native[slim_to_native_tuple] - - return Array2D(values=convolved_array_1d, mask=mask) - - def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"): - """For a given 1D array and blurring array, convolve the two using this psf. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - """ - return jax.vmap( - self.convolve_image_no_blurring_for_mapping, in_axes=(1, None, None) - )(mapping_matrix, mask, jax_method).T + # return slim form + return blurred_mapping_matrix_native[slim_to_native_tuple] diff --git a/test_autoarray/config/general.yaml b/test_autoarray/config/general.yaml index 6f331d141..66d3354fd 100644 --- a/test_autoarray/config/general.yaml +++ b/test_autoarray/config/general.yaml @@ -2,6 +2,8 @@ analysis: n_cores: 1 fits: flip_for_ds9: false +psf: + use_fft_default: false # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution. Real space used for unit tests. grid: remove_projected_centre: false adapt: diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index ead9e51f5..2e07a4cc3 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -33,22 +33,24 @@ def make_test_data_path(): return test_data_path -def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map(): +def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map_for_fft(): image = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) noise_map = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) psf = aa.Kernel2D.ones(shape_native=(3, 3), pixel_scales=1.0) - dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=False) + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, disable_fft_pad=True) assert dataset.data.shape_native == (3, 3) assert dataset.noise_map.shape_native == (3, 3) - dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=True) + dataset = aa.Imaging( + data=image, noise_map=noise_map, psf=psf, disable_fft_pad=False + ) - assert dataset.data.shape_native == (5, 5) - assert dataset.noise_map.shape_native == (5, 5) + assert dataset.data.shape_native == (6, 6) + assert dataset.noise_map.shape_native == (6, 6) assert dataset.data.mask[0, 0] == True - assert dataset.data.mask[1, 1] == False + assert dataset.data.mask[2, 2] == False def test__noise_covariance_input__noise_map_uses_diag(): @@ -126,7 +128,7 @@ def test__output_to_fits(imaging_7x7, test_data_path): def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): - masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7) + masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7, disable_fft_pad=True) assert (masked_imaging_7x7.data.slim == np.ones(9)).all() @@ -263,5 +265,5 @@ def test__psf_not_odd_x_odd_kernel__raises_error(): psf = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) dataset = aa.Imaging( - data=image, noise_map=noise_map, psf=psf, pad_for_psf=False + data=image, noise_map=noise_map, psf=psf, disable_fft_pad=True ) diff --git a/test_autoarray/fit/test_fit_dataset.py b/test_autoarray/fit/test_fit_dataset.py index 1f8ec6380..4363c7707 100644 --- a/test_autoarray/fit/test_fit_dataset.py +++ b/test_autoarray/fit/test_fit_dataset.py @@ -64,7 +64,10 @@ def test__figure_of_merit__with_noise_covariance_matrix_in_dataset( assert fit.chi_squared != pytest.approx(chi_squared, 1.0e-4) -def test__grid_offset_via_data_model(masked_imaging_7x7, model_image_7x7): +def test__grid_offset_via_data_model(imaging_7x7, mask_2d_7x7, model_image_7x7): + + masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7, disable_fft_pad=True) + fit = aa.m.MockFitImaging( dataset=masked_imaging_7x7, use_mask_in_fit=False, diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index cafb8722b..546ffe763 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -203,7 +203,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): mapping_matrix = mapper.mapping_matrix - blurred_mapping_matrix = psf.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask ) @@ -290,7 +290,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): w_tilde=w_tilde, mapping_matrix=mapping_matrix ) - blurred_mapping_matrix = psf.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask ) @@ -370,7 +370,7 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): pix_pixels=pixelization.pixels, ) - blurred_mapping_matrix = psf.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask, ) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 6fa4e7295..1b1420633 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -290,66 +290,7 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) -def test__convolved_array_from(): - - array_2d = aa.Array2D.no_mask( - [ - [0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0], - ], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [ - [1.0, 1.0, 0.0, 0.0], - [2.0, 1.0, 1.0, 1.0], - [3.0, 3.0, 2.0, 2.0], - [0.0, 0.0, 1.0, 3.0], - ] - ) - ).all() - - array_2d = aa.Array2D.no_mask( - values=[ - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [ - [2.0, 1.0, 0.0, 0.0], - [3.0, 3.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 2.0, 2.0], - ] - ) - ).all() - - -def test__convolve_image(): +def test__convolved_image_from(): mask = aa.Mask2D.circular( shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 @@ -368,7 +309,7 @@ def test__convolve_image(): values=blurred_image_via_scipy.native, mask=mask ) - # Now reproduce this data using the convolve_image function + # Now reproduce this data using the convolved_image_from function image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) @@ -381,7 +322,7 @@ def test__convolve_image(): blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) - blurred_masked_im_1 = kernel.convolve_image( + blurred_masked_im_1 = kernel.convolved_image_from( image=masked_image, blurring_image=blurring_image ) @@ -390,7 +331,7 @@ def test__convolve_image(): ) -def test__convolve_image_no_blurring(): +def test__convolve_imaged_from__no_blurring(): # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. mask = aa.Mask2D.circular( @@ -420,8 +361,8 @@ def test__convolve_image_no_blurring(): masked_image = aa.Array2D(values=image.native, mask=mask) - blurred_masked_im_1 = kernel.convolve_image_no_blurring( - image=masked_image, mask=mask + blurred_masked_im_1 = kernel.convolved_image_from( + image=masked_image, blurring_image=None ) assert blurred_masked_image_via_scipy == pytest.approx( @@ -429,7 +370,7 @@ def test__convolve_image_no_blurring(): ) -def test__convolve_mapping_matrix(): +def test__convolved_mapping_matrix_from(): mask = aa.Mask2D( mask=np.array( [ @@ -473,7 +414,7 @@ def test__convolve_mapping_matrix(): ] ) - blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) + blurred_mapping = kernel.convolved_mapping_matrix_from(mapping, mask) assert ( blurred_mapping @@ -531,7 +472,7 @@ def test__convolve_mapping_matrix(): ] ) - blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) + blurred_mapping = kernel.convolved_mapping_matrix_from(mapping, mask) assert blurred_mapping == pytest.approx( np.array( @@ -556,3 +497,36 @@ def test__convolve_mapping_matrix(): ), abs=1e-4, ) + + +def test__convolve_imaged_from__via_fft__sizes_not_precomputed__compare_numerical_value(): + + # ------------------------------- + # Case 1: direct image convolution + # ------------------------------- + mask = aa.Mask2D.circular( + shape_native=(20, 20), pixel_scales=(1.0, 1.0), radius=5.0 + ) + + image = aa.Array2D.no_mask(values=np.arange(400).reshape(20, 20), pixel_scales=1.0) + masked_image = aa.Array2D(values=image.native, mask=mask) + + kernel_fft = aa.Kernel2D.no_mask( + values=np.arange(49).reshape(7, 7), + pixel_scales=1.0, + use_fft=True, + normalize=True, + ) + + blurring_mask = mask.derive_mask.blurring_from( + kernel_shape_native=kernel_fft.shape_native + ) + blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) + + blurred_fft = kernel_fft.convolved_image_from( + image=masked_image, blurring_image=blurring_image + ) + + assert blurred_fft.native.array[13, 13] == pytest.approx( + 207.49999999999, rel=1e-6, abs=1e-6 + )