From a7d887eb30ec7514148696bf9a352fff2d54be3b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 6 Oct 2025 20:37:59 +0100 Subject: [PATCH 01/19] rename function --- autoarray/structures/arrays/kernel_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index b2992278a..cfce414db 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -553,7 +553,7 @@ def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D: return Array2D(values=convolved_array_1d, mask=mask) - def convolve_image(self, image, blurring_image, jax_method="direct"): + def convolve_image_via_real_space(self, image, blurring_image, jax_method="direct"): """ For a given 1D array and blurring array, convolve the two using this psf. From 6495882937cddfc7d9d2a26a3cba199b46162ad9 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 8 Oct 2025 16:04:47 +0100 Subject: [PATCH 02/19] failed --- autoarray/dataset/imaging/dataset.py | 68 ++++++++++++++-------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index aa08cc3ca..8dfb8a644 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -92,48 +92,46 @@ def __init__( self.pad_for_psf = pad_for_psf if pad_for_psf and psf is not None: - try: - data.mask.derive_mask.blurring_from( - kernel_shape_native=psf.shape_native - ) - except exc.MaskException: - over_sample_size_lp = ( - over_sample_util.over_sample_size_convert_to_array_2d_from( - over_sample_size=over_sample_size_lp, mask=data.mask - ) + + pad_shape = (300, 300) + + over_sample_size_lp = ( + over_sample_util.over_sample_size_convert_to_array_2d_from( + over_sample_size=over_sample_size_lp, mask=data.mask ) - over_sample_size_lp = ( - over_sample_size_lp.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) + ) + over_sample_size_lp = ( + over_sample_size_lp.resized_from( + new_shape=pad_shape, mask_pad_value=1 ) + ) - over_sample_size_pixelization = ( - over_sample_util.over_sample_size_convert_to_array_2d_from( - over_sample_size=over_sample_size_pixelization, mask=data.mask - ) + over_sample_size_pixelization = ( + over_sample_util.over_sample_size_convert_to_array_2d_from( + over_sample_size=over_sample_size_pixelization, mask=data.mask ) - over_sample_size_pixelization = ( - over_sample_size_pixelization.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) + ) + over_sample_size_pixelization = ( + over_sample_size_pixelization.resized_from( + new_shape=pad_shape, mask_pad_value=1 ) + ) - data = data.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) - if noise_map is not None: - noise_map = noise_map.padded_before_convolution_from( - kernel_shape=psf.shape_native, mask_pad_value=1 - ) - logger.info( - f"The image and noise map of the `Imaging` objected have been padded to the dimensions" - f"{data.shape}. This is because the blurring region around the mask (which defines where" - f"PSF flux may be convolved into the masked region) extended beyond the edge of the image." - f"" - f"This can be prevented by using a smaller mask, smaller PSF kernel size or manually padding" - f"the image and noise-map yourself." + data = data.resized_from( + new_shape=pad_shape, mask_pad_value=1 + ) + if noise_map is not None: + noise_map = noise_map.resized_from( + new_shape=pad_shape, mask_pad_value=1 ) + logger.info( + f"The image and noise map of the `Imaging` objected have been padded to the dimensions" + f"{data.shape_native}. This is because the blurring region around the mask (which defines where" + f"PSF flux may be convolved into the masked region) extended beyond the edge of the image." + f"" + f"This can be prevented by using a smaller mask, smaller PSF kernel size or manually padding" + f"the image and noise-map yourself." + ) super().__init__( data=data, From e7fb78ce97cb5c6e533d57e77cebcbbf2fdfcffa Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 10:05:03 +0100 Subject: [PATCH 03/19] update FFT padding scheme to be optimal --- autoarray/dataset/imaging/dataset.py | 24 ++- .../inversion/inversion/inversion_util.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 197 ++++++++++++++++-- 3 files changed, 203 insertions(+), 20 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 8dfb8a644..db2396295 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -1,6 +1,7 @@ import logging import numpy as np from pathlib import Path +import scipy from typing import Optional, Union from autoconf import cached_property @@ -91,9 +92,17 @@ def __init__( self.pad_for_psf = pad_for_psf - if pad_for_psf and psf is not None: + if pad_for_psf: + full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask) + + print(data.mask.shape, full_shape, fft_shape, mask_shape, psf.shape_native) - pad_shape = (300, 300) + else: + full_shape = psf.full_shape + fft_shape = psf.fft_shape + mask_shape = psf.mask_shape + + if pad_for_psf and psf is not None: over_sample_size_lp = ( over_sample_util.over_sample_size_convert_to_array_2d_from( @@ -102,7 +111,7 @@ def __init__( ) over_sample_size_lp = ( over_sample_size_lp.resized_from( - new_shape=pad_shape, mask_pad_value=1 + new_shape=fft_shape, mask_pad_value=1 ) ) @@ -113,16 +122,16 @@ def __init__( ) over_sample_size_pixelization = ( over_sample_size_pixelization.resized_from( - new_shape=pad_shape, mask_pad_value=1 + new_shape=fft_shape, mask_pad_value=1 ) ) data = data.resized_from( - new_shape=pad_shape, mask_pad_value=1 + new_shape=fft_shape, mask_pad_value=1 ) if noise_map is not None: noise_map = noise_map.resized_from( - new_shape=pad_shape, mask_pad_value=1 + new_shape=fft_shape, mask_pad_value=1 ) logger.info( f"The image and noise map of the `Imaging` objected have been padded to the dimensions" @@ -177,6 +186,9 @@ def __init__( normalize=use_normalized_psf, image_mask=image_mask, blurring_mask=blurring_mask, + mask_shape=mask_shape, + full_shape=full_shape, + fft_shape=fft_shape ) self.psf = psf diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 95e216c9e..15ec4cf75 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -93,7 +93,7 @@ def curvature_matrix_via_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ - array = mapping_matrix / noise_map[:, None] + array = mapping_matrix / jnp.expand_dims(noise_map.array, 1) curvature_matrix = jnp.dot(array.T, array) if add_to_curvature_diag and len(no_regularization_index_list) > 0: diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index c3e123942..e46e16158 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpy as np from pathlib import Path +import scipy from typing import List, Optional, Tuple, Union from autoconf.fitsable import header_obj_from @@ -26,6 +27,9 @@ def __init__( store_native: bool = False, image_mask=None, blurring_mask=None, + mask_shape=None, + full_shape=None, + fft_shape=None, *args, **kwargs, ): @@ -77,6 +81,19 @@ def __init__( slim_to_native_blurring[:, 1], ) + self.fft_shape = fft_shape + + self.mask_shape = None + self.full_shape = None + self.fft_psf = None + + if self.fft_shape is not None: + + self.mask_shape = mask_shape + self.full_shape = full_shape + self.fft_psf = jnp.fft.rfft2(self.native.array, s=self.fft_shape) + self.fft_psf_mapping = jnp.expand_dims(self.fft_psf, 2) + @classmethod def no_mask( cls, @@ -88,6 +105,9 @@ def no_mask( normalize: bool = False, image_mask=None, blurring_mask=None, + mask_shape=None, + full_shape=None, + fft_shape=None ): """ Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically @@ -122,6 +142,9 @@ def no_mask( normalize=normalize, image_mask=image_mask, blurring_mask=blurring_mask, + mask_shape=mask_shape, + full_shape=full_shape, + fft_shape=fft_shape ) @classmethod @@ -391,6 +414,21 @@ def from_fits( header=Header(header_sci_obj=header_sci_obj, header_hdu_obj=header_hdu_obj), ) + def fft_shape_from(self, mask): + + ys, xs = np.where(~mask) + y_min, y_max = ys.min(), ys.max() + x_min, x_max = xs.min(), xs.max() + + (pad_y, pad_x) = self.shape_native + + mask_shape = ((y_max + pad_y // 2) - (y_min - pad_y // 2), (x_max + pad_x // 2) - (x_min - pad_x // 2)) + + full_shape = tuple(s1 + s2 - 1 for s1, s2 in zip(mask_shape, self.shape_native)) + fft_shape = tuple(scipy.fft.next_fast_len(s, real=True) for s in full_shape) + + return full_shape, fft_shape, mask_shape + def rescaled_with_odd_dimensions_from( self, rescale_factor: float, normalize: bool = False ) -> "Kernel2D": @@ -554,7 +592,7 @@ def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D: return Array2D(values=convolved_array_1d, mask=mask) - def convolve_image_via_real_space(self, image, blurring_image, jax_method="direct"): + def convolve_image(self, image, blurring_image, jax_method="direct"): """ For a given 1D array and blurring array, convolve the two using this psf. @@ -587,27 +625,105 @@ def convolve_image_via_real_space(self, image, blurring_image, jax_method="direc ) # make sure dtype matches what you want - expanded_array_native = jnp.zeros( - image.mask.shape, dtype=jnp.asarray(image.array).dtype + image_both_native = jnp.zeros( + image.mask.shape, dtype=image.dtype ) # set using a tuple of index arrays - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + image_both_native = image_both_native.at[slim_to_native_tuple].set( jnp.asarray(image.array) ) - expanded_array_native = expanded_array_native.at[ + image_both_native = image_both_native.at[ slim_to_native_blurring_tuple ].set(jnp.asarray(blurring_image.array)) - kernel = self.stored_native.array + # FFT the combined image + fft_image_native = jnp.fft.rfft2(image_both_native, s=self.fft_shape, axes=(0, 1)) - convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method + # Multiply by PSF in Fourier space and invert + blurred_image_full = jnp.fft.irfft2(self.fft_psf * fft_image_native, s=self.fft_shape, axes=(0, 1)) + + # Crop back to mask_shape + start_indices = tuple((full_size - out_size) // 2 for full_size, out_size in zip(self.full_shape, self.mask_shape)) + out_shape_full = self.mask_shape + blurred_image_native = jax.lax.dynamic_slice(blurred_image_full, start_indices, out_shape_full) + + return Array2D(values=blurred_image_native[slim_to_native_tuple], mask=image.mask) + + def convolve_mapping_matrix( + self, + mapping_matrix, + mask, + blurring_mapping_matrix=None, + jax_method="direct", + ): + """ + Convolve a source-pixel mapping matrix with this PSF in Fourier space. + Also supports a blurring mapping matrix, which is added in the same way as blurring_image. + + Parameters + ---------- + mapping_matrix : (N_masked_pixels, N_src) + Mapping matrix of unmasked pixels to source pixels. + mask : Mask + Mask object with slim-to-native mapping. + blurring_mapping_matrix : (N_blurring_pixels, N_src) or None + Mapping matrix for the blurring grid (outside the mask core). + If provided, this is scattered into native space and added to the main mapping matrix. + jax_method : str + Currently unused, placeholder for different convolution backends. + + Returns + ------- + (N_masked_pixels, N_src) + Blurred mapping matrix in slim form (only unmasked pixels). + """ + + slim_to_native_tuple = self.slim_to_native_tuple + if slim_to_native_tuple is None: + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=mask.shape[0] + ) + + n_src = mapping_matrix.shape[1] + + # allocate full native + source dimension + mapping_matrix_native = jnp.zeros( + mask.shape + (n_src,), dtype=mapping_matrix.dtype ) - convolved_array_1d = convolve_native[slim_to_native_tuple] + # scatter main mapping matrix + mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set( + mapping_matrix + ) - return Array2D(values=convolved_array_1d, mask=image.mask) + # optionally scatter blurring mapping matrix + if blurring_mapping_matrix is not None: + slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple + mapping_matrix_native = mapping_matrix_native.at[ + slim_to_native_blurring_tuple + ].set(blurring_mapping_matrix) + + # FFT convolution + fft_mapping_matrix_native = jnp.fft.rfft2( + mapping_matrix_native, s=self.fft_shape, axes=(0, 1) + ) + blurred_mapping_matrix_full = jnp.fft.irfft2( + self.fft_psf_mapping * fft_mapping_matrix_native, s=self.fft_shape, axes=(0, 1) + ) + + # crop back + start_indices = tuple( + (full_size - out_size) // 2 + for full_size, out_size in zip(self.full_shape, self.mask_shape) + ) + (0,) + out_shape_full = self.mask_shape + (blurred_mapping_matrix_full.shape[2],) + blurred_mapping_matrix_native = jax.lax.dynamic_slice( + blurred_mapping_matrix_full, start_indices, out_shape_full + ) + + # return slim form + return blurred_mapping_matrix_native[slim_to_native_tuple] def convolve_image_no_blurring(self, image, mask, jax_method="direct"): """ @@ -657,7 +773,7 @@ def convolve_image_no_blurring(self, image, mask, jax_method="direct"): return Array2D(values=convolved_array_1d, mask=mask) - def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct"): + def convolve_image_no_blurring_for_mapping_via_real_space(self, image, mask, jax_method="direct"): """ For a given 1D array and blurring array, convolve the two using this psf. @@ -700,7 +816,62 @@ def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct return Array2D(values=convolved_array_1d, mask=mask) - def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"): + def convolve_image_via_real_space(self, image, blurring_image, jax_method="direct"): + """ + For a given 1D array and blurring array, convolve the two using this psf. + + Parameters + ---------- + image + 1D array of the values which are to be blurred with the psf's PSF. + blurring_image + 1D array of the blurring values which blur into the array after PSF convolution. + jax_method + If JAX is enabled this keyword will indicate what method is used for the PSF + convolution. Can be either `direct` to calculate it in real space or `fft` + to calculated it via a fast Fourier transform. `fft` is typically faster for + kernels that are more than about 5x5. Default is `fft`. + """ + + slim_to_native_tuple = self.slim_to_native_tuple + slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple + + if slim_to_native_tuple is None: + + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(image.mask.array), size=image.shape[0] + ) + + if slim_to_native_blurring_tuple is None: + + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] + ) + + # make sure dtype matches what you want + expanded_array_native = jnp.zeros( + image.mask.shape, dtype=jnp.asarray(image.array).dtype + ) + + # set using a tuple of index arrays + expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + jnp.asarray(image.array) + ) + expanded_array_native = expanded_array_native.at[ + slim_to_native_blurring_tuple + ].set(jnp.asarray(blurring_image.array)) + + kernel = self.stored_native.array + + convolve_native = jax.scipy.signal.convolve( + expanded_array_native, kernel, mode="same", method=jax_method + ) + + convolved_array_1d = convolve_native[slim_to_native_tuple] + + return Array2D(values=convolved_array_1d, mask=image.mask) + + def convolve_mapping_matrix_via_real_space(self, mapping_matrix, mask, jax_method="direct"): """For a given 1D array and blurring array, convolve the two using this psf. Parameters @@ -709,5 +880,5 @@ def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"): 1D array of the values which are to be blurred with the psf's PSF. """ return jax.vmap( - self.convolve_image_no_blurring_for_mapping, in_axes=(1, None, None) + self.convolve_image_no_blurring_for_mapping_via_real_space, in_axes=(1, None, None) )(mapping_matrix, mask, jax_method).T From 7e9d26a54a4fb991380f331d254023c707d72c99 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 17:15:00 +0100 Subject: [PATCH 04/19] padding built into FFT --- autoarray/dataset/imaging/dataset.py | 44 ++++++++++++---------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index db2396295..36962f5af 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -30,7 +30,7 @@ def __init__( noise_covariance_matrix: Optional[np.ndarray] = None, over_sample_size_lp: Union[int, Array2D] = 4, over_sample_size_pixelization: Union[int, Array2D] = 4, - pad_for_psf: bool = False, + disable_fft_pad : bool = True, use_normalized_psf: Optional[bool] = True, check_noise_map: bool = True, ): @@ -77,10 +77,10 @@ def __init__( over_sample_size_pixelization How over sampling is performed for the grid which is associated with a pixelization, which is therefore passed into the calculations performed in the `inversion` module. - pad_for_psf - The PSF convolution may extend beyond the edges of the image mask, which can lead to edge effects in the - convolved image. If `True`, the image and noise-map are padded to ensure the PSF convolution does not - extend beyond the edge of the image. + disable_fft_pad + The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, which places the fewest zeros + around the image. If this is set to `True`, this optimal padding is not performed and the image is used + as-is. use_normalized_psf If `True`, the PSF kernel values are rescaled such that they sum to 1.0. This can be important for ensuring the PSF kernel does not change the overall normalization of the image when it is convolved with it. @@ -90,19 +90,21 @@ def __init__( self.unmasked = None - self.pad_for_psf = pad_for_psf + self.disable_fft_pad = disable_fft_pad - if pad_for_psf: - full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask) + if psf is not None: - print(data.mask.shape, full_shape, fft_shape, mask_shape, psf.shape_native) + full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask) - else: - full_shape = psf.full_shape - fft_shape = psf.fft_shape - mask_shape = psf.mask_shape + if psf is not None and not disable_fft_pad and data.mask.shape != fft_shape: - if pad_for_psf and psf is not None: + logger.info( + f"Imaging data has been trimmed or padded for FFT convolution.\n" + f" - Original shape : {data.mask.shape}\n" + f" - FFT shape : {fft_shape}\n" + f"Padding ensures accurate PSF convolution in Fourier space. " + f"Set `disable_fft_pad=True` in Imaging object to turn off automatic padding." + ) over_sample_size_lp = ( over_sample_util.over_sample_size_convert_to_array_2d_from( @@ -133,14 +135,6 @@ def __init__( noise_map = noise_map.resized_from( new_shape=fft_shape, mask_pad_value=1 ) - logger.info( - f"The image and noise map of the `Imaging` objected have been padded to the dimensions" - f"{data.shape_native}. This is because the blurring region around the mask (which defines where" - f"PSF flux may be convolved into the masked region) extended beyond the edge of the image." - f"" - f"This can be prevented by using a smaller mask, smaller PSF kernel size or manually padding" - f"the image and noise-map yourself." - ) super().__init__( data=data, @@ -395,7 +389,7 @@ def apply_mask(self, mask: Mask2D) -> "Imaging": noise_covariance_matrix=noise_covariance_matrix, over_sample_size_lp=over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization, - pad_for_psf=True, + disable_fft_pad=False, ) dataset.unmasked = unmasked_dataset @@ -488,7 +482,7 @@ def apply_noise_scaling( noise_covariance_matrix=self.noise_covariance_matrix, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, - pad_for_psf=False, + disable_fft_pad=False, check_noise_map=False, ) @@ -536,7 +530,7 @@ def apply_over_sampling( over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization or self.over_sample_size_pixelization, - pad_for_psf=False, + disable_fft_pad=False, check_noise_map=False, ) From c95af763776a5818d0ddbc6d0165b55fa58efe16 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 17:23:48 +0100 Subject: [PATCH 05/19] support both types of convolution with override --- .../inversion/inversion/imaging/w_tilde.py | 8 +- autoarray/structures/arrays/kernel_2d.py | 212 ++++++------------ 2 files changed, 79 insertions(+), 141 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index 58a3ccc63..dbce9233b 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -518,8 +518,12 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: reconstruction=np.array(reconstruction), ) - mapped_reconstructed_image = self.psf.convolve_image_no_blurring( - image=mapped_reconstructed_image, mask=self.mask + mapped_reconstructed_image = Array2D( + values=mapped_reconstructed_image, mask=self.mask + ) + + mapped_reconstructed_image = self.psf.convolve_image( + image=mapped_reconstructed_image.native, ).array mapped_reconstructed_image = Array2D( diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index e46e16158..4ee1d262b 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -4,6 +4,7 @@ from pathlib import Path import scipy from typing import List, Optional, Tuple, Union +import warnings from autoconf.fitsable import header_obj_from @@ -30,6 +31,7 @@ def __init__( mask_shape=None, full_shape=None, fft_shape=None, + use_real_space : bool = False *args, **kwargs, ): @@ -94,6 +96,8 @@ def __init__( self.fft_psf = jnp.fft.rfft2(self.native.array, s=self.fft_shape) self.fft_psf_mapping = jnp.expand_dims(self.fft_psf, 2) + self.use_real_space = use_real_space + @classmethod def no_mask( cls, @@ -609,53 +613,66 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): kernels that are more than about 5x5. Default is `fft`. """ + if self.use_real_space: + return self.convolve_image_via_real_space( + image=image, blurring_image=blurring_image, jax_method=jax_method + ) + slim_to_native_tuple = self.slim_to_native_tuple slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple if slim_to_native_tuple is None: - slim_to_native_tuple = jnp.nonzero( jnp.logical_not(image.mask.array), size=image.shape[0] ) - if slim_to_native_blurring_tuple is None: - - slim_to_native_blurring_tuple = jnp.nonzero( - jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] - ) - - # make sure dtype matches what you want - image_both_native = jnp.zeros( - image.mask.shape, dtype=image.dtype - ) - - # set using a tuple of index arrays + # start with native image padded with zeros + image_both_native = jnp.zeros(image.mask.shape, dtype=image.dtype) image_both_native = image_both_native.at[slim_to_native_tuple].set( jnp.asarray(image.array) ) - image_both_native = image_both_native.at[ - slim_to_native_blurring_tuple - ].set(jnp.asarray(blurring_image.array)) + + # add blurring contribution if provided + if blurring_image is not None: + if slim_to_native_blurring_tuple is None: + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] + ) + image_both_native = image_both_native.at[slim_to_native_blurring_tuple].set( + jnp.asarray(blurring_image.array) + ) + else: + warnings.warn( + "No blurring_image provided. Only the direct image will be convolved. " + "This may change the correctness of the PSF convolution." + ) # FFT the combined image fft_image_native = jnp.fft.rfft2(image_both_native, s=self.fft_shape, axes=(0, 1)) # Multiply by PSF in Fourier space and invert - blurred_image_full = jnp.fft.irfft2(self.fft_psf * fft_image_native, s=self.fft_shape, axes=(0, 1)) + blurred_image_full = jnp.fft.irfft2( + self.fft_psf * fft_image_native, s=self.fft_shape, axes=(0, 1) + ) # Crop back to mask_shape - start_indices = tuple((full_size - out_size) // 2 for full_size, out_size in zip(self.full_shape, self.mask_shape)) + start_indices = tuple( + (full_size - out_size) // 2 + for full_size, out_size in zip(self.full_shape, self.mask_shape) + ) out_shape_full = self.mask_shape - blurred_image_native = jax.lax.dynamic_slice(blurred_image_full, start_indices, out_shape_full) + blurred_image_native = jax.lax.dynamic_slice( + blurred_image_full, start_indices, out_shape_full + ) return Array2D(values=blurred_image_native[slim_to_native_tuple], mask=image.mask) def convolve_mapping_matrix( - self, - mapping_matrix, - mask, - blurring_mapping_matrix=None, - jax_method="direct", + self, + mapping_matrix, + mask, + blurring_mapping_matrix=None, + jax_method="direct", ): """ Convolve a source-pixel mapping matrix with this PSF in Fourier space. @@ -678,6 +695,8 @@ def convolve_mapping_matrix( (N_masked_pixels, N_src) Blurred mapping matrix in slim form (only unmasked pixels). """ + if self.use_real_space: + return self.convolve_image_no_blurring_for_mapping_via_real_space(image=mapping_matrix, mask=mask, jax_method=jax_method) slim_to_native_tuple = self.slim_to_native_tuple if slim_to_native_tuple is None: @@ -725,144 +744,59 @@ def convolve_mapping_matrix( # return slim form return blurred_mapping_matrix_native[slim_to_native_tuple] - def convolve_image_no_blurring(self, image, mask, jax_method="direct"): + def convolve_image_via_real_space(self, image, blurring_image=None, jax_method="direct"): """ - For a given 1D array and blurring array, convolve the two using this psf. + Convolve an input image with this PSF in real space. Parameters ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. - """ - - slim_to_native_tuple = self.slim_to_native_tuple - - if slim_to_native_tuple is None: - - slim_to_native_tuple = jnp.nonzero( - jnp.logical_not(mask.array), size=image.shape[0] - ) - - # make sure dtype matches what you want - expanded_array_native = jnp.zeros(mask.shape) - - # set using a tuple of index arrays - if isinstance(image, np.ndarray) or isinstance(image, jnp.ndarray): - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( - image - ) - else: - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( - jnp.asarray(image.array) - ) - - kernel = self.stored_native.array - - convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method - ) - - convolved_array_1d = convolve_native[slim_to_native_tuple] - - return Array2D(values=convolved_array_1d, mask=mask) - - def convolve_image_no_blurring_for_mapping_via_real_space(self, image, mask, jax_method="direct"): - """ - For a given 1D array and blurring array, convolve the two using this psf. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. - """ - - slim_to_native_tuple = self.slim_to_native_tuple - - if slim_to_native_tuple is None: - - slim_to_native_tuple = jnp.nonzero( - jnp.logical_not(mask.array), size=image.shape[0] - ) - - # make sure dtype matches what you want - expanded_array_native = jnp.zeros(mask.shape) - - # set using a tuple of index arrays - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( - image - ) - - kernel = self.stored_native.array - - convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method - ) - - convolved_array_1d = convolve_native[slim_to_native_tuple] - - return Array2D(values=convolved_array_1d, mask=mask) - - def convolve_image_via_real_space(self, image, blurring_image, jax_method="direct"): - """ - For a given 1D array and blurring array, convolve the two using this psf. - - Parameters - ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. + image : Array2D + 1D array of values to be blurred with the PSF. + blurring_image : Array2D or None, optional + 1D array of blurring values which convolve into the image. + If None, only the direct image is convolved. A warning is raised + because omitting the blurring image may change the correctness of + the convolution result. + jax_method : {"direct", "fft"} + Method passed to `jax.scipy.signal.convolve`. Default is "direct". """ slim_to_native_tuple = self.slim_to_native_tuple slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple if slim_to_native_tuple is None: - slim_to_native_tuple = jnp.nonzero( jnp.logical_not(image.mask.array), size=image.shape[0] ) - if slim_to_native_blurring_tuple is None: - - slim_to_native_blurring_tuple = jnp.nonzero( - jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] - ) - - # make sure dtype matches what you want + # start with native array padded with zeros expanded_array_native = jnp.zeros( image.mask.shape, dtype=jnp.asarray(image.array).dtype ) - # set using a tuple of index arrays + # set image pixels expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( jnp.asarray(image.array) ) - expanded_array_native = expanded_array_native.at[ - slim_to_native_blurring_tuple - ].set(jnp.asarray(blurring_image.array)) - kernel = self.stored_native.array + # add blurring contribution if provided + if blurring_image is not None: + if slim_to_native_blurring_tuple is None: + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_image.mask.array), + size=blurring_image.shape[0] + ) + expanded_array_native = expanded_array_native.at[slim_to_native_blurring_tuple].set( + jnp.asarray(blurring_image.array) + ) + else: + warnings.warn( + "No blurring_image provided. Only the direct image will be convolved. " + "This may change the correctness of the PSF convolution." + ) + # perform real-space convolution + kernel = self.stored_native.array convolve_native = jax.scipy.signal.convolve( expanded_array_native, kernel, mode="same", method=jax_method ) From 94df6eb89d77a7114e6018a3d11adc8b2bbd8579 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 17:51:09 +0100 Subject: [PATCH 06/19] convolve_mapping_matrix_via_real_space --- autoarray/dataset/imaging/dataset.py | 20 +-- autoarray/structures/arrays/kernel_2d.py | 128 ++++++++++++++---- test_autoarray/config/general.yaml | 2 + .../structures/arrays/test_kernel_2d.py | 4 +- 4 files changed, 111 insertions(+), 43 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 36962f5af..ece559b67 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -30,7 +30,7 @@ def __init__( noise_covariance_matrix: Optional[np.ndarray] = None, over_sample_size_lp: Union[int, Array2D] = 4, over_sample_size_pixelization: Union[int, Array2D] = 4, - disable_fft_pad : bool = True, + disable_fft_pad: bool = True, use_normalized_psf: Optional[bool] = True, check_noise_map: bool = True, ): @@ -111,10 +111,8 @@ def __init__( over_sample_size=over_sample_size_lp, mask=data.mask ) ) - over_sample_size_lp = ( - over_sample_size_lp.resized_from( - new_shape=fft_shape, mask_pad_value=1 - ) + over_sample_size_lp = over_sample_size_lp.resized_from( + new_shape=fft_shape, mask_pad_value=1 ) over_sample_size_pixelization = ( @@ -122,15 +120,11 @@ def __init__( over_sample_size=over_sample_size_pixelization, mask=data.mask ) ) - over_sample_size_pixelization = ( - over_sample_size_pixelization.resized_from( - new_shape=fft_shape, mask_pad_value=1 - ) - ) - - data = data.resized_from( + over_sample_size_pixelization = over_sample_size_pixelization.resized_from( new_shape=fft_shape, mask_pad_value=1 ) + + data = data.resized_from(new_shape=fft_shape, mask_pad_value=1) if noise_map is not None: noise_map = noise_map.resized_from( new_shape=fft_shape, mask_pad_value=1 @@ -182,7 +176,7 @@ def __init__( blurring_mask=blurring_mask, mask_shape=mask_shape, full_shape=full_shape, - fft_shape=fft_shape + fft_shape=fft_shape, ) self.psf = psf diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 4ee1d262b..6f6dbc7d3 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -6,6 +6,7 @@ from typing import List, Optional, Tuple, Union import warnings +from autoconf import conf from autoconf.fitsable import header_obj_from from autoarray.structures.arrays.uniform_2d import AbstractArray2D @@ -31,7 +32,7 @@ def __init__( mask_shape=None, full_shape=None, fft_shape=None, - use_real_space : bool = False + use_fft: Optional[bool] = None, *args, **kwargs, ): @@ -96,7 +97,14 @@ def __init__( self.fft_psf = jnp.fft.rfft2(self.native.array, s=self.fft_shape) self.fft_psf_mapping = jnp.expand_dims(self.fft_psf, 2) - self.use_real_space = use_real_space + self._use_fft = use_fft + + @property + def use_fft(self): + if self._use_fft is None: + return conf.instance["general"]["psf"]["use_fft_default"] + + return self._use_fft @classmethod def no_mask( @@ -111,7 +119,7 @@ def no_mask( blurring_mask=None, mask_shape=None, full_shape=None, - fft_shape=None + fft_shape=None, ): """ Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically @@ -148,7 +156,7 @@ def no_mask( blurring_mask=blurring_mask, mask_shape=mask_shape, full_shape=full_shape, - fft_shape=fft_shape + fft_shape=fft_shape, ) @classmethod @@ -426,7 +434,10 @@ def fft_shape_from(self, mask): (pad_y, pad_x) = self.shape_native - mask_shape = ((y_max + pad_y // 2) - (y_min - pad_y // 2), (x_max + pad_x // 2) - (x_min - pad_x // 2)) + mask_shape = ( + (y_max + pad_y // 2) - (y_min - pad_y // 2), + (x_max + pad_x // 2) - (x_min - pad_x // 2), + ) full_shape = tuple(s1 + s2 - 1 for s1, s2 in zip(mask_shape, self.shape_native)) fft_shape = tuple(scipy.fft.next_fast_len(s, real=True) for s in full_shape) @@ -613,7 +624,7 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): kernels that are more than about 5x5. Default is `fft`. """ - if self.use_real_space: + if not self.use_fft: return self.convolve_image_via_real_space( image=image, blurring_image=blurring_image, jax_method=jax_method ) @@ -636,7 +647,8 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): if blurring_image is not None: if slim_to_native_blurring_tuple is None: slim_to_native_blurring_tuple = jnp.nonzero( - jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0] + jnp.logical_not(blurring_image.mask.array), + size=blurring_image.shape[0], ) image_both_native = image_both_native.at[slim_to_native_blurring_tuple].set( jnp.asarray(blurring_image.array) @@ -648,7 +660,9 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): ) # FFT the combined image - fft_image_native = jnp.fft.rfft2(image_both_native, s=self.fft_shape, axes=(0, 1)) + fft_image_native = jnp.fft.rfft2( + image_both_native, s=self.fft_shape, axes=(0, 1) + ) # Multiply by PSF in Fourier space and invert blurred_image_full = jnp.fft.irfft2( @@ -665,7 +679,9 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): blurred_image_full, start_indices, out_shape_full ) - return Array2D(values=blurred_image_native[slim_to_native_tuple], mask=image.mask) + return Array2D( + values=blurred_image_native[slim_to_native_tuple], mask=image.mask + ) def convolve_mapping_matrix( self, @@ -695,8 +711,13 @@ def convolve_mapping_matrix( (N_masked_pixels, N_src) Blurred mapping matrix in slim form (only unmasked pixels). """ - if self.use_real_space: - return self.convolve_image_no_blurring_for_mapping_via_real_space(image=mapping_matrix, mask=mask, jax_method=jax_method) + if not self.use_fft: + return self.convolve_mapping_matrix_via_real_space( + mapping_matrix=mapping_matrix, + mask=mask, + blurring_mapping_matrix=blurring_mapping_matrix, + jax_method=jax_method, + ) slim_to_native_tuple = self.slim_to_native_tuple if slim_to_native_tuple is None: @@ -728,7 +749,9 @@ def convolve_mapping_matrix( mapping_matrix_native, s=self.fft_shape, axes=(0, 1) ) blurred_mapping_matrix_full = jnp.fft.irfft2( - self.fft_psf_mapping * fft_mapping_matrix_native, s=self.fft_shape, axes=(0, 1) + self.fft_psf_mapping * fft_mapping_matrix_native, + s=self.fft_shape, + axes=(0, 1), ) # crop back @@ -744,7 +767,9 @@ def convolve_mapping_matrix( # return slim form return blurred_mapping_matrix_native[slim_to_native_tuple] - def convolve_image_via_real_space(self, image, blurring_image=None, jax_method="direct"): + def convolve_image_via_real_space( + self, image, blurring_image=None, jax_method="direct" + ): """ Convolve an input image with this PSF in real space. @@ -784,11 +809,11 @@ def convolve_image_via_real_space(self, image, blurring_image=None, jax_method=" if slim_to_native_blurring_tuple is None: slim_to_native_blurring_tuple = jnp.nonzero( jnp.logical_not(blurring_image.mask.array), - size=blurring_image.shape[0] + size=blurring_image.shape[0], ) - expanded_array_native = expanded_array_native.at[slim_to_native_blurring_tuple].set( - jnp.asarray(blurring_image.array) - ) + expanded_array_native = expanded_array_native.at[ + slim_to_native_blurring_tuple + ].set(jnp.asarray(blurring_image.array)) else: warnings.warn( "No blurring_image provided. Only the direct image will be convolved. " @@ -805,14 +830,63 @@ def convolve_image_via_real_space(self, image, blurring_image=None, jax_method=" return Array2D(values=convolved_array_1d, mask=image.mask) - def convolve_mapping_matrix_via_real_space(self, mapping_matrix, mask, jax_method="direct"): - """For a given 1D array and blurring array, convolve the two using this psf. + def convolve_mapping_matrix_via_real_space( + self, mapping_matrix, mask, blurring_mapping_matrix=None, jax_method="direct" + ): + # 1) Indices of unmasked (image) pixels — no `size=` to avoid wrong lengths + ys, xs = self.slim_to_native_tuple or jnp.nonzero(jnp.logical_not(mask.array)) + n_pix, n_src = mapping_matrix.shape + + # Sanity check + if ys.shape[0] != n_pix: + raise ValueError( + f"Mapping rows ({n_pix}) != unmasked pixels ({ys.shape[0]}). " + "Make sure you’re using the image (not blurring) index tuple." + ) - Parameters - ---------- - image - 1D array of the values which are to be blurred with the psf's PSF. - """ - return jax.vmap( - self.convolve_image_no_blurring_for_mapping_via_real_space, in_axes=(1, None, None) - )(mapping_matrix, mask, jax_method).T + # 2) Allocate native cube (ny, nx, n_src) + mapping_matrix_native = jnp.zeros( + mask.shape + (n_src,), dtype=mapping_matrix.dtype + ) + + # 3) Build index grids with identical shape (n_pix, n_src) + ys_exp = jnp.broadcast_to(ys[:, None], (n_pix, n_src)) + xs_exp = jnp.broadcast_to(xs[:, None], (n_pix, n_src)) + src_exp = jnp.broadcast_to(jnp.arange(n_src)[None, :], (n_pix, n_src)) + + # 4) Scatter all at once (values also shape (n_pix, n_src)) + mapping_matrix_native = mapping_matrix_native.at[(ys_exp, xs_exp, src_exp)].set( + mapping_matrix + ) + + # 5) Optional blurring mapping matrix + if blurring_mapping_matrix is not None: + ys_b, xs_b = self.slim_to_native_blurring_tuple or jnp.nonzero( + jnp.logical_not( + mask.array + ) # use the correct blurring grid mask here if different + ) + n_blur, n_src_b = blurring_mapping_matrix.shape + if n_src_b != n_src: + raise ValueError( + "blurring_mapping_matrix columns must match mapping_matrix columns (n_src)." + ) + + ys_b_exp = jnp.broadcast_to(ys_b[:, None], (n_blur, n_src)) + xs_b_exp = jnp.broadcast_to(xs_b[:, None], (n_blur, n_src)) + src_b_exp = jnp.broadcast_to(jnp.arange(n_src)[None, :], (n_blur, n_src)) + + mapping_matrix_native = mapping_matrix_native.at[ + (ys_b_exp, xs_b_exp, src_b_exp) + ].set(blurring_mapping_matrix) + + # 6) Real-space convolution, broadcast kernel over source axis + kernel = self.stored_native.array + convolved_native = jax.scipy.signal.convolve( + mapping_matrix_native, kernel[..., None], mode="same", method=jax_method + ) + + # 7) Pull back to slim (n_pix, n_src) + blurred_mapping_matrix = convolved_native[ys, xs, :] + + return blurred_mapping_matrix diff --git a/test_autoarray/config/general.yaml b/test_autoarray/config/general.yaml index 6f331d141..2af37eda3 100644 --- a/test_autoarray/config/general.yaml +++ b/test_autoarray/config/general.yaml @@ -2,6 +2,8 @@ analysis: n_cores: 1 fits: flip_for_ds9: false +psf: + use_fft_default: false # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution. grid: remove_projected_centre: false adapt: diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 6fa4e7295..213e2f8fa 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -420,9 +420,7 @@ def test__convolve_image_no_blurring(): masked_image = aa.Array2D(values=image.native, mask=mask) - blurred_masked_im_1 = kernel.convolve_image_no_blurring( - image=masked_image, mask=mask - ) + blurred_masked_im_1 = kernel.convolve_image(image=masked_image, blurring_image=None) assert blurred_masked_image_via_scipy == pytest.approx( blurred_masked_im_1.array, 1e-4 From 4eb18c8281c4dc8ea99b1758593c1148b6348e1e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 18:36:24 +0100 Subject: [PATCH 07/19] test on convolve_image --- autoarray/structures/arrays/kernel_2d.py | 59 ++++++++++++++++--- .../structures/arrays/test_kernel_2d.py | 24 ++++++++ 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 6f6dbc7d3..b5f9eafc2 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -19,6 +19,7 @@ from autoarray.structures.arrays import array_2d_util + class Kernel2D(AbstractArray2D): def __init__( self, @@ -120,6 +121,7 @@ def no_mask( mask_shape=None, full_shape=None, fft_shape=None, + use_fft: Optional[bool] = None, ): """ Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically @@ -157,6 +159,7 @@ def no_mask( mask_shape=mask_shape, full_shape=full_shape, fft_shape=fft_shape, + use_fft=use_fft, ) @classmethod @@ -629,6 +632,24 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): image=image, blurring_image=blurring_image, jax_method=jax_method ) + if self.fft_shape is None: + + full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=image.mask) + fft_psf = jnp.fft.rfft2(self.stored_native.array, s=fft_shape) + + image = image.resized_from(new_shape=fft_shape, mask_pad_value=1) + if blurring_image is not None: + blurring_image = blurring_image.resized_from( + new_shape=fft_shape, mask_pad_value=1 + ) + + else: + + fft_shape = self.fft_shape + full_shape = self.full_shape + mask_shape = self.mask_shape + fft_psf = self.fft_psf + slim_to_native_tuple = self.slim_to_native_tuple slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple @@ -661,20 +682,20 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): # FFT the combined image fft_image_native = jnp.fft.rfft2( - image_both_native, s=self.fft_shape, axes=(0, 1) + image_both_native, s=fft_shape, axes=(0, 1) ) # Multiply by PSF in Fourier space and invert blurred_image_full = jnp.fft.irfft2( - self.fft_psf * fft_image_native, s=self.fft_shape, axes=(0, 1) + fft_psf * fft_image_native, s=fft_shape, axes=(0, 1) ) # Crop back to mask_shape start_indices = tuple( (full_size - out_size) // 2 - for full_size, out_size in zip(self.full_shape, self.mask_shape) + for full_size, out_size in zip(full_shape, mask_shape) ) - out_shape_full = self.mask_shape + out_shape_full = mask_shape blurred_image_native = jax.lax.dynamic_slice( blurred_image_full, start_indices, out_shape_full ) @@ -719,7 +740,27 @@ def convolve_mapping_matrix( jax_method=jax_method, ) + if self.fft_shape is None: + + full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask) + + raise ValueError( + f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n" + f"Expected mapping matrix padded to match FFT shape of PSF.\n" + f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, " + f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}." + ) + + else: + + fft_shape = self.fft_shape + full_shape = self.full_shape + mask_shape = self.mask_shape + fft_psf = self.fft_psf + fft_psf_mapping = self.fft_psf_mapping + slim_to_native_tuple = self.slim_to_native_tuple + if slim_to_native_tuple is None: slim_to_native_tuple = jnp.nonzero( jnp.logical_not(mask.array), size=mask.shape[0] @@ -746,20 +787,20 @@ def convolve_mapping_matrix( # FFT convolution fft_mapping_matrix_native = jnp.fft.rfft2( - mapping_matrix_native, s=self.fft_shape, axes=(0, 1) + mapping_matrix_native, s=fft_shape, axes=(0, 1) ) blurred_mapping_matrix_full = jnp.fft.irfft2( - self.fft_psf_mapping * fft_mapping_matrix_native, - s=self.fft_shape, + fft_psf_mapping * fft_mapping_matrix_native, + s=fft_shape, axes=(0, 1), ) # crop back start_indices = tuple( (full_size - out_size) // 2 - for full_size, out_size in zip(self.full_shape, self.mask_shape) + for full_size, out_size in zip(full_shape, mask_shape) ) + (0,) - out_shape_full = self.mask_shape + (blurred_mapping_matrix_full.shape[2],) + out_shape_full = mask_shape + (blurred_mapping_matrix_full.shape[2],) blurred_mapping_matrix_native = jax.lax.dynamic_slice( blurred_mapping_matrix_full, start_indices, out_shape_full ) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 213e2f8fa..6718cfcbf 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -554,3 +554,27 @@ def test__convolve_mapping_matrix(): ), abs=1e-4, ) + +def test__convolve_image__via_fft__sizes_not_precomputed__compare_numerical_value(): + + # ------------------------------- + # Case 1: direct image convolution + # ------------------------------- + mask = aa.Mask2D.circular( + shape_native=(20, 20), pixel_scales=(1.0, 1.0), radius=5.0 + ) + + image = aa.Array2D.no_mask(values=np.arange(400).reshape(20, 20), pixel_scales=1.0) + masked_image = aa.Array2D(values=image.native, mask=mask) + + kernel_fft = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0, use_fft=True, normalize=True) + + blurring_mask = mask.derive_mask.blurring_from(kernel_shape_native=kernel_fft.shape_native) + blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) + + blurred_fft = kernel_fft.convolve_image(image=masked_image, blurring_image=blurring_image) + + assert blurred_fft.native.array[13, 13] == pytest.approx(207.49999999999, rel=1e-6, abs=1e-6) + + + From 8acb11d11e41e1eca3dbf04ca25f9030a47786c8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 18:44:39 +0100 Subject: [PATCH 08/19] docstrings --- autoarray/structures/arrays/kernel_2d.py | 435 ++++++++++++++++------- 1 file changed, 308 insertions(+), 127 deletions(-) diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index b5f9eafc2..7a1b0c8a9 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -38,22 +38,75 @@ def __init__( **kwargs, ): """ - An array of values, which are paired to a uniform 2D mask of pixels. Each entry - on the array corresponds to a value at the centre of a pixel in an unmasked pixel. See the ``Array2D`` class - for a full description of how Arrays work. + A 2D convolution kernel stored as an array of values paired to a uniform 2D mask. - The ``Kernel2D`` class is an ``Array2D`` but with additioonal methods that allow it to be convolved with data. + The ``Kernel2D`` is a subclass of ``Array2D`` with additional methods for performing + point spread function (PSF) convolution of images or mapping matrices. Each entry of + the kernel corresponds to a PSF value at the centre of a pixel in the unmasked grid. + + Two convolution modes are supported: + + - **Real-space convolution**: performed directly via sliding-window summation or + ``jax.scipy.signal.convolve``. This is exact but can be slow for large kernels. + - **FFT convolution**: performed by transforming both the kernel and the input image + into Fourier space, multiplying, and transforming back. This is typically faster + for kernels larger than ~5×5, but requires careful zero-padding. + + When using FFT convolution, the input image and mask are automatically padded such + that the FFT avoids circular wrap-around artefacts. This padding is computed from the + kernel size via :meth:`fft_shape_from`. The padded shape is stored in ``fft_shape``. + If FFT convolution is attempted without precomputing and applying this padding, + an exception is raised to avoid silent shape mismatches. Parameters ---------- values - The values of the array. + The raw 2D kernel values. Can be normalised to sum to unity if ``normalize=True``. mask - The 2D mask associated with the array, defining the pixels each array value is paired with and - originates from. + The 2D mask associated with the kernel, defining the pixels each kernel value is + paired with. + header + Optional metadata (e.g. FITS header) associated with the kernel. normalize - If True, the Kernel2D's array values are normalized such that they sum to 1.0. + If True, the kernel values are rescaled such that they sum to 1.0. + store_native + If True, the kernel is stored in its full native 2D form as an attribute + ``stored_native`` for re-use (e.g. when convolving repeatedly). + image_mask + Optional mask defining the unmasked image pixels when performing convolution. + If not provided, defaults to the supplied ``mask``. + blurring_mask + Optional mask defining the "blurring region": pixels outside the image mask + into which PSF flux can spread. Used to construct blurring images and + blurring mapping matrices. + mask_shape + The shape of the (unpadded) mask region. Used when cropping back results after + FFT convolution. + full_shape + The unpadded image + kernel shape (``image_shape + kernel_shape - 1``). + fft_shape + The padded shape used in FFT convolution, typically computed via + ``scipy.fft.next_fast_len`` for efficiency. Must be precomputed before calling + FFT convolution methods. + use_fft + If True, convolution is performed in Fourier space with zero-padding. + If False, convolution is performed in real space. + If None, a default choice is made: real space for small kernels, + FFT for large kernels. + *args, **kwargs + Passed to the ``Array2D`` constructor. + + Notes + ----- + - FFT padding can be disabled globally with ``disable_fft_pad=True`` when + constructing ``Imaging`` objects, in which case convolution will either + use real space or proceed without padding. + - Blurring masks ensure that PSF flux spilling outside the main image mask + is included correctly. Omitting them may lead to underestimated PSF wings. + - For unit tests with tiny kernels, FFT and real-space convolution may differ + slightly due to edge and truncation effects. """ + super().__init__( values=values, mask=mask, @@ -429,7 +482,41 @@ def from_fits( header=Header(header_sci_obj=header_sci_obj, header_hdu_obj=header_hdu_obj), ) - def fft_shape_from(self, mask): + def fft_shape_from(self, mask : np.ndarray) -> Union[Tuple[int, int], Tuple[int, int], Tuple[int, int]]: + """ + Compute the padded shapes required for FFT-based convolution with this kernel. + + FFT convolution requires the input image and kernel to be zero-padded so that + the convolution is equivalent to linear convolution (not circular) and to avoid + wrap-around artefacts. This method inspects the mask and the kernel shape to + determine three key shapes: + + - ``mask_shape``: the rectangular bounding-box region of the mask that encloses + all unmasked (False) pixels, padded by half the kernel size in each direction. + This is the minimal region that must be retained for convolution. + - ``full_shape``: the "linear convolution shape", equal to + ``mask_shape + kernel_shape - 1``. This is the minimal padded size required + for an exact linear convolution. + - ``fft_shape``: the FFT-efficient padded shape, obtained by rounding each + dimension of ``full_shape`` up to the next fast length for real FFTs + (via ``scipy.fft.next_fast_len``). Using this ensures efficient FFT execution. + + Parameters + ---------- + mask + A 2D mask where False indicates unmasked pixels (valid data) and True + indicates masked pixels. The bounding-box of the False region is used + to compute the convolution region. + + Returns + ------- + full_shape + The unpadded linear convolution shape (mask region + kernel − 1). + fft_shape + The FFT-friendly padded shape for efficient convolution. + mask_shape + The rectangular mask region size including kernel padding. + """ ys, xs = np.where(~mask) y_min, y_max = ys.min(), ys.max() @@ -447,92 +534,6 @@ def fft_shape_from(self, mask): return full_shape, fft_shape, mask_shape - def rescaled_with_odd_dimensions_from( - self, rescale_factor: float, normalize: bool = False - ) -> "Kernel2D": - """ - If the PSF kernel has one or two even-sized dimensions, return a PSF object where the kernel has odd-sized - dimensions (odd-sized dimensions are required for 2D convolution). - - The PSF can be scaled to larger / smaller sizes than the input size, if the rescale factor uses values that - deviate furher from 1.0. - - Kernels are rescald using the scikit-image routine rescale, which performs rescaling via an interpolation - routine. This may lead to loss of accuracy in the PSF kernel and it is advised that users, where possible, - create their PSF on an odd-sized array using their data reduction pipelines that remove this approximation. - - Parameters - ---------- - rescale_factor - The factor by which the kernel is rescaled. If this has a value of 1.0, the kernel is rescaled to the - closest odd-sized dimensions (e.g. 20 -> 19). Higher / lower values scale to higher / lower dimensions. - normalize - Whether the PSF should be normalized after being rescaled. - """ - - from skimage.transform import resize, rescale - - try: - kernel_rescaled = rescale( - self.native.array, - rescale_factor, - anti_aliasing=False, - mode="constant", - channel_axis=None, - ) - except TypeError: - kernel_rescaled = rescale( - self.native.array, - rescale_factor, - anti_aliasing=False, - mode="constant", - ) - - if kernel_rescaled.shape[0] % 2 == 0 and kernel_rescaled.shape[1] % 2 == 0: - kernel_rescaled = resize( - kernel_rescaled, - output_shape=( - kernel_rescaled.shape[0] + 1, - kernel_rescaled.shape[1] + 1, - ), - anti_aliasing=False, - mode="constant", - ) - - elif kernel_rescaled.shape[0] % 2 == 0 and kernel_rescaled.shape[1] % 2 != 0: - kernel_rescaled = resize( - kernel_rescaled, - output_shape=(kernel_rescaled.shape[0] + 1, kernel_rescaled.shape[1]), - anti_aliasing=False, - mode="constant", - ) - - elif kernel_rescaled.shape[0] % 2 != 0 and kernel_rescaled.shape[1] % 2 == 0: - kernel_rescaled = resize( - kernel_rescaled, - output_shape=(kernel_rescaled.shape[0], kernel_rescaled.shape[1] + 1), - anti_aliasing=False, - mode="constant", - ) - - if self.pixel_scales is not None: - pixel_scale_factors = ( - self.mask.shape[0] / kernel_rescaled.shape[0], - self.mask.shape[1] / kernel_rescaled.shape[1], - ) - - pixel_scales = ( - self.pixel_scales[0] * pixel_scale_factors[0], - self.pixel_scales[1] * pixel_scale_factors[1], - ) - - else: - pixel_scales = None - - return Kernel2D.no_mask( - values=kernel_rescaled, pixel_scales=pixel_scales, normalize=normalize - ) - @property def normalized(self) -> "Kernel2D": """ @@ -612,19 +613,45 @@ def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D: def convolve_image(self, image, blurring_image, jax_method="direct"): """ - For a given 1D array and blurring array, convolve the two using this psf. + Convolve an input masked image with this PSF. + + This method chooses between an FFT-based convolution (default if + ``self.use_fft=True``) or a direct real-space convolution, depending on + how the Kernel2D was configured. + + In the FFT branch: + - The input image (and optional blurring image) are resized / padded to + match the FFT-friendly padded shape (``fft_shape``) associated with this kernel. + - The PSF and image are transformed to Fourier space via ``jax.numpy.fft.rfft2``. + - Convolution is performed as elementwise multiplication. + - The result is inverse-transformed and cropped back to the masked region. + + Padding ensures that the FFT implements *linear* convolution, not circular, + and avoids wrap-around artefacts. The required padding is determined by + ``fft_shape_from(mask)``. If no precomputed shapes exist, they are computed + on the fly. For reproducible behaviour, precompute and set + ``fft_shape``, ``full_shape``, and ``mask_shape`` on the kernel. + + If ``use_fft=False``, convolution falls back to + :meth:`Kernel2D.convolve_image_via_real_space`. Parameters ---------- image - 1D array of the values which are to be blurred with the psf's PSF. + Masked 2D image array to convolve. blurring_image - 1D array of the blurring values which blur into the array after PSF convolution. - jax_method - If JAX is enabled this keyword will indicate what method is used for the PSF - convolution. Can be either `direct` to calculate it in real space or `fft` - to calculated it via a fast Fourier transform. `fft` is typically faster for - kernels that are more than about 5x5. Default is `fft`. + Masked image containing flux from outside the mask core that blurs + into the masked region after convolution. If ``None``, only the direct + image is convolved, which may be numerically incorrect if the mask + excludes PSF wings. + jax_method : {"direct", "fft"} + Backend passed to ``jax.scipy.signal.convolve`` when in real-space mode. + Ignored for FFT convolutions. + + Returns + ------- + Array2D + The convolved image in slim (1D masked) format. """ if not self.use_fft: @@ -712,25 +739,48 @@ def convolve_mapping_matrix( jax_method="direct", ): """ - Convolve a source-pixel mapping matrix with this PSF in Fourier space. - Also supports a blurring mapping matrix, which is added in the same way as blurring_image. + Convolve a source-plane mapping matrix with this PSF. + + A mapping matrix maps image-plane unmasked pixels to source-plane pixels. + This method performs the equivalent operation of PSF convolution on the + mapping matrix, so that model visibilities / images can be computed via + matrix multiplication instead of explicit convolution. + + If ``use_fft=True``, convolution is performed in Fourier space: + - The mapping matrix is scattered into a 3D native cube + (ny, nx, n_src). + - An FFT of this cube is multiplied by the precomputed FFT of the PSF. + - The inverse FFT is taken and cropped to the mask region. + - The slim (masked 1D) representation is returned. + + If ``use_fft=False``, convolution falls back to + :meth:`Kernel2D.convolve_mapping_matrix_via_real_space`. + + Notes + ----- + - FFT convolution requires that ``self.fft_shape`` and related padding + attributes are precomputed. If not, a ``ValueError`` is raised with the + expected vs actual shapes. This ensures the mapping matrix is padded + consistently with the PSF. + - The optional ``blurring_mapping_matrix`` plays the same role as + ``blurring_image`` in :meth:`convolve_image`, accounting for PSF flux + that falls into the masked region from outside. Parameters ---------- - mapping_matrix : (N_masked_pixels, N_src) - Mapping matrix of unmasked pixels to source pixels. - mask : Mask - Mask object with slim-to-native mapping. - blurring_mapping_matrix : (N_blurring_pixels, N_src) or None - Mapping matrix for the blurring grid (outside the mask core). - If provided, this is scattered into native space and added to the main mapping matrix. + mapping_matrix : ndarray of shape (N_pix, N_src) + Slim mapping matrix from unmasked pixels to source pixels. + mask : Mask2D + Associated mask defining the image grid. + blurring_mapping_matrix : ndarray of shape (N_blur, N_src), optional + Mapping matrix for the blurring region, outside the mask core. jax_method : str - Currently unused, placeholder for different convolution backends. + Backend passed to real-space convolution if ``use_fft=False``. Returns ------- - (N_masked_pixels, N_src) - Blurred mapping matrix in slim form (only unmasked pixels). + ndarray of shape (N_pix, N_src) + Convolved mapping matrix in slim form. """ if not self.use_fft: return self.convolve_mapping_matrix_via_real_space( @@ -808,23 +858,127 @@ def convolve_mapping_matrix( # return slim form return blurred_mapping_matrix_native[slim_to_native_tuple] + def rescaled_with_odd_dimensions_from( + self, rescale_factor: float, normalize: bool = False + ) -> "Kernel2D": + """ + Return a version of this kernel rescaled so both dimensions are odd-sized. + + Odd-sized kernels are often required for real space convolution operations + (e.g. centered PSFs in imaging pipelines). If the kernel has one or two + even-sized dimensions, they are rescaled (via interpolation) and padded + so that both dimensions are odd. + + The kernel can also be scaled larger or smaller by changing + ``rescale_factor``. Rescaling uses ``skimage.transform.rescale`` / + ``resize``, which interpolate pixel values and may introduce small + inaccuracies compared to native instrument PSFs. Where possible, users + should generate odd-sized PSFs directly from data reduction. + + Parameters + ---------- + rescale_factor + Factor by which the kernel is rescaled. If 1.0, only adjusts size to + nearest odd dimensions. Values > 1 enlarge, < 1 shrink the kernel. + normalize + If True, the returned kernel is normalized to sum to 1.0. + + Returns + ------- + Kernel2D + Rescaled kernel with odd-sized dimensions. + """ + + from skimage.transform import resize, rescale + + try: + kernel_rescaled = rescale( + self.native.array, + rescale_factor, + anti_aliasing=False, + mode="constant", + channel_axis=None, + ) + except TypeError: + kernel_rescaled = rescale( + self.native.array, + rescale_factor, + anti_aliasing=False, + mode="constant", + ) + + if kernel_rescaled.shape[0] % 2 == 0 and kernel_rescaled.shape[1] % 2 == 0: + kernel_rescaled = resize( + kernel_rescaled, + output_shape=( + kernel_rescaled.shape[0] + 1, + kernel_rescaled.shape[1] + 1, + ), + anti_aliasing=False, + mode="constant", + ) + + elif kernel_rescaled.shape[0] % 2 == 0 and kernel_rescaled.shape[1] % 2 != 0: + kernel_rescaled = resize( + kernel_rescaled, + output_shape=(kernel_rescaled.shape[0] + 1, kernel_rescaled.shape[1]), + anti_aliasing=False, + mode="constant", + ) + + elif kernel_rescaled.shape[0] % 2 != 0 and kernel_rescaled.shape[1] % 2 == 0: + kernel_rescaled = resize( + kernel_rescaled, + output_shape=(kernel_rescaled.shape[0], kernel_rescaled.shape[1] + 1), + anti_aliasing=False, + mode="constant", + ) + + if self.pixel_scales is not None: + pixel_scale_factors = ( + self.mask.shape[0] / kernel_rescaled.shape[0], + self.mask.shape[1] / kernel_rescaled.shape[1], + ) + + pixel_scales = ( + self.pixel_scales[0] * pixel_scale_factors[0], + self.pixel_scales[1] * pixel_scale_factors[1], + ) + + else: + pixel_scales = None + + return Kernel2D.no_mask( + values=kernel_rescaled, pixel_scales=pixel_scales, normalize=normalize + ) + def convolve_image_via_real_space( - self, image, blurring_image=None, jax_method="direct" + self, image : np.ndarray, blurring_image : Optional[np.ndarray] = None, jax_method : str = "direct" ): """ - Convolve an input image with this PSF in real space. + Convolve an input masked image with this PSF in real space. + + This is the direct method (non-FFT) where convolution is explicitly + performed using ``jax.scipy.signal.convolve`` with the kernel in native + space. + + Unlike FFT convolution, this does not require padding shapes, but it is + typically much slower for large kernels (> ~5x5). Parameters ---------- - image : Array2D - 1D array of values to be blurred with the PSF. - blurring_image : Array2D or None, optional - 1D array of blurring values which convolve into the image. - If None, only the direct image is convolved. A warning is raised - because omitting the blurring image may change the correctness of - the convolution result. - jax_method : {"direct", "fft"} - Method passed to `jax.scipy.signal.convolve`. Default is "direct". + image + Masked image array to convolve. + blurring_image + Blurring contribution from outside the mask core. If None, only the + direct image is convolved (which may be numerically incorrect). + jax_method + Method flag for JAX convolution backend (default "direct"). + + Returns + ------- + Array2D + Convolved image in slim format. """ slim_to_native_tuple = self.slim_to_native_tuple @@ -872,8 +1026,35 @@ def convolve_image_via_real_space( return Array2D(values=convolved_array_1d, mask=image.mask) def convolve_mapping_matrix_via_real_space( - self, mapping_matrix, mask, blurring_mapping_matrix=None, jax_method="direct" + self, mapping_matrix : np.ndarray, mask, blurring_mapping_matrix : Optional[np.ndarray] = None, jax_method : str = "direct" ): + """ + Convolve a source-plane mapping matrix with this PSF in real space. + + Equivalent to :meth:`convolve_mapping_matrix`, but using explicit + real-space convolution rather than FFTs. This avoids FFT padding issues + but is slower for large kernels. + + The mapping matrix is expanded into a native cube (ny, nx, n_src), + convolved with the kernel (broadcast along the source axis), + and reduced back to slim form. + + Parameters + ---------- + mapping_matrix + Slim mapping matrix from unmasked pixels to source pixels. + mask + Mask defining the pixelization grid. + blurring_mapping_matrix : ndarray (N_blur, N_src), optional + Mapping matrix for blurring region pixels outside the mask core. + jax_method + Backend passed to JAX convolution. + + Returns + ------- + ndarray (N_pix, N_src) + Convolved mapping matrix in slim form. + """ # 1) Indices of unmasked (image) pixels — no `size=` to avoid wrong lengths ys, xs = self.slim_to_native_tuple or jnp.nonzero(jnp.logical_not(mask.array)) n_pix, n_src = mapping_matrix.shape From 5571d17ebab14c1dec5ce30287b348397695da39 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 20:14:35 +0100 Subject: [PATCH 09/19] remove unused functions --- .../inversion/inversion/imaging/w_tilde.py | 3 +- .../inversion/inversion/inversion_util.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 78 +------------------ 3 files changed, 7 insertions(+), 76 deletions(-) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index dbce9233b..f77cae100 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -523,7 +523,8 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: ) mapped_reconstructed_image = self.psf.convolve_image( - image=mapped_reconstructed_image.native, + image=mapped_reconstructed_image, + blurring_image=None, ).array mapped_reconstructed_image = Array2D( diff --git a/autoarray/inversion/inversion/inversion_util.py b/autoarray/inversion/inversion/inversion_util.py index 15ec4cf75..95e216c9e 100644 --- a/autoarray/inversion/inversion/inversion_util.py +++ b/autoarray/inversion/inversion/inversion_util.py @@ -93,7 +93,7 @@ def curvature_matrix_via_mapping_matrix_from( noise_map Flattened 1D array of the noise-map used by the inversion during the fit. """ - array = mapping_matrix / jnp.expand_dims(noise_map.array, 1) + array = mapping_matrix / noise_map[:, None] curvature_matrix = jnp.dot(array.T, array) if add_to_curvature_diag and len(no_regularization_index_list) > 0: diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 7a1b0c8a9..31684f734 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -541,76 +541,6 @@ def normalized(self) -> "Kernel2D": """ return Kernel2D(values=self, mask=self.mask, normalize=True) - def convolved_array_from(self, array: Array2D) -> Array2D: - """ - Convolve an array with this Kernel2D - - Parameters - ---------- - image - An array representing the image the Kernel2D is convolved with. - - Returns - ------- - convolved_image - An array representing the image after convolution. - - Raises - ------ - KernelException if either Kernel2D psf dimension is odd - """ - import scipy.signal - - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - - array_2d = array.native - - convolved_array_2d = scipy.signal.convolve2d( - array_2d.array, self.native.array, mode="same" - ) - - convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=array_2d.mask, - array_2d_native=convolved_array_2d, - ) - - return Array2D(values=convolved_array_1d, mask=array_2d.mask) - - def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D: - """ - Convolve an array with this Kernel2D - - Parameters - ---------- - image - An array representing the image the Kernel2D is convolved with. - - Returns - ------- - convolved_image - An array representing the image after convolution. - - Raises - ------ - KernelException if either Kernel2D psf dimension is odd - """ - import scipy.signal - - if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") - - convolved_array_2d = scipy.signal.convolve2d( - array.array, self.native.array, mode="same" - ) - - convolved_array_1d = array_2d_util.array_2d_slim_from( - mask_2d=mask, - array_2d_native=convolved_array_2d, - ) - - return Array2D(values=convolved_array_1d, mask=mask) - def convolve_image(self, image, blurring_image, jax_method="direct"): """ Convolve an input masked image with this PSF. @@ -990,12 +920,12 @@ def convolve_image_via_real_space( ) # start with native array padded with zeros - expanded_array_native = jnp.zeros( + image_native = jnp.zeros( image.mask.shape, dtype=jnp.asarray(image.array).dtype ) # set image pixels - expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set( + image_native = image_native.at[slim_to_native_tuple].set( jnp.asarray(image.array) ) @@ -1006,7 +936,7 @@ def convolve_image_via_real_space( jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0], ) - expanded_array_native = expanded_array_native.at[ + image_native = image_native.at[ slim_to_native_blurring_tuple ].set(jnp.asarray(blurring_image.array)) else: @@ -1018,7 +948,7 @@ def convolve_image_via_real_space( # perform real-space convolution kernel = self.stored_native.array convolve_native = jax.scipy.signal.convolve( - expanded_array_native, kernel, mode="same", method=jax_method + image_native, kernel, mode="same", method=jax_method ) convolved_array_1d = convolve_native[slim_to_native_tuple] From f5df10c5199f75f5c2c70c8a6f6ddf209bb03fad Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 20:21:26 +0100 Subject: [PATCH 10/19] disable some dataset tests by moving disable fft around --- autoarray/dataset/imaging/dataset.py | 10 ++++++---- autoarray/dataset/imaging/simulator.py | 2 +- autoarray/mask/mask_2d.py | 2 +- test_autoarray/dataset/imaging/test_dataset.py | 16 ++++++++-------- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index ece559b67..d475608bb 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -335,7 +335,7 @@ def from_fits( over_sample_size_pixelization=over_sample_size_pixelization, ) - def apply_mask(self, mask: Mask2D) -> "Imaging": + def apply_mask(self, mask: Mask2D, disable_fft_pad : bool = False) -> "Imaging": """ Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other quantities one-by-one. @@ -383,7 +383,7 @@ def apply_mask(self, mask: Mask2D) -> "Imaging": noise_covariance_matrix=noise_covariance_matrix, over_sample_size_lp=over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization, - disable_fft_pad=False, + disable_fft_pad=disable_fft_pad, ) dataset.unmasked = unmasked_dataset @@ -398,6 +398,7 @@ def apply_noise_scaling( self, mask: Mask2D, noise_value: float = 1e8, + disable_fft_pad : bool = False, signal_to_noise_value: Optional[float] = None, should_zero_data: bool = True, ) -> "Imaging": @@ -476,7 +477,7 @@ def apply_noise_scaling( noise_covariance_matrix=self.noise_covariance_matrix, over_sample_size_lp=self.over_sample_size_lp, over_sample_size_pixelization=self.over_sample_size_pixelization, - disable_fft_pad=False, + disable_fft_pad=disable_fft_pad, check_noise_map=False, ) @@ -495,6 +496,7 @@ def apply_over_sampling( self, over_sample_size_lp: Union[int, Array2D] = None, over_sample_size_pixelization: Union[int, Array2D] = None, + disable_fft_pad : bool = False, ) -> "AbstractDataset": """ Apply new over sampling objects to the grid and grid pixelization of the dataset. @@ -524,7 +526,7 @@ def apply_over_sampling( over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp, over_sample_size_pixelization=over_sample_size_pixelization or self.over_sample_size_pixelization, - disable_fft_pad=False, + disable_fft_pad=disable_fft_pad, check_noise_map=False, ) diff --git a/autoarray/dataset/imaging/simulator.py b/autoarray/dataset/imaging/simulator.py index 576dc6017..2ac057df7 100644 --- a/autoarray/dataset/imaging/simulator.py +++ b/autoarray/dataset/imaging/simulator.py @@ -126,7 +126,7 @@ def via_image_from( pixel_scales=image.pixel_scales, ) - image = self.psf.convolved_array_from(array=image) + image = self.psf.convolve_image(image=image, blurring_image=None) image = image + background_sky_map diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 0f4fe30f9..0e46bc71d 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -653,7 +653,7 @@ def unmasked_blurred_array_from(self, padded_array, psf, image_shape) -> Array2D The 1D unmasked image which is blurred. """ - blurred_image = psf.convolved_array_from(array=padded_array) + blurred_image = psf.convolve_image(image=padded_array, blurring_image=None) return self.trimmed_array_from( padded_array=blurred_image, image_shape=image_shape diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index ead9e51f5..c8f61e75f 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -33,22 +33,22 @@ def make_test_data_path(): return test_data_path -def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map(): +def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map_for_fft(): image = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) noise_map = aa.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0) psf = aa.Kernel2D.ones(shape_native=(3, 3), pixel_scales=1.0) - dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=False) + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, disable_fft_pad=True) assert dataset.data.shape_native == (3, 3) assert dataset.noise_map.shape_native == (3, 3) - dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, pad_for_psf=True) + dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, disable_fft_pad=False) - assert dataset.data.shape_native == (5, 5) - assert dataset.noise_map.shape_native == (5, 5) + assert dataset.data.shape_native == (6, 6) + assert dataset.noise_map.shape_native == (6, 6) assert dataset.data.mask[0, 0] == True - assert dataset.data.mask[1, 1] == False + assert dataset.data.mask[2, 2] == False def test__noise_covariance_input__noise_map_uses_diag(): @@ -126,7 +126,7 @@ def test__output_to_fits(imaging_7x7, test_data_path): def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3): - masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7) + masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7, disable_fft_pad=True) assert (masked_imaging_7x7.data.slim == np.ones(9)).all() @@ -263,5 +263,5 @@ def test__psf_not_odd_x_odd_kernel__raises_error(): psf = aa.Kernel2D.no_mask(values=[[0.0, 1.0], [1.0, 2.0]], pixel_scales=1.0) dataset = aa.Imaging( - data=image, noise_map=noise_map, psf=psf, pad_for_psf=False + data=image, noise_map=noise_map, psf=psf, disable_fft_pad=True ) From 913020a102ae970edc994c48b8b3a12012408a46 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 20:23:55 +0100 Subject: [PATCH 11/19] all unit tests pass --- test_autoarray/config/general.yaml | 2 +- test_autoarray/fit/test_fit_dataset.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test_autoarray/config/general.yaml b/test_autoarray/config/general.yaml index 2af37eda3..66d3354fd 100644 --- a/test_autoarray/config/general.yaml +++ b/test_autoarray/config/general.yaml @@ -3,7 +3,7 @@ analysis: fits: flip_for_ds9: false psf: - use_fft_default: false # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution. + use_fft_default: false # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution. Real space used for unit tests. grid: remove_projected_centre: false adapt: diff --git a/test_autoarray/fit/test_fit_dataset.py b/test_autoarray/fit/test_fit_dataset.py index 1f8ec6380..4363c7707 100644 --- a/test_autoarray/fit/test_fit_dataset.py +++ b/test_autoarray/fit/test_fit_dataset.py @@ -64,7 +64,10 @@ def test__figure_of_merit__with_noise_covariance_matrix_in_dataset( assert fit.chi_squared != pytest.approx(chi_squared, 1.0e-4) -def test__grid_offset_via_data_model(masked_imaging_7x7, model_image_7x7): +def test__grid_offset_via_data_model(imaging_7x7, mask_2d_7x7, model_image_7x7): + + masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7, disable_fft_pad=True) + fit = aa.m.MockFitImaging( dataset=masked_imaging_7x7, use_mask_in_fit=False, From 8f0e9ebc25a5c5eaf6af300cb3f7d68a423f259e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 9 Oct 2025 20:26:46 +0100 Subject: [PATCH 12/19] renamed methods to _fromAPI --- autoarray/dataset/imaging/dataset.py | 6 +-- autoarray/dataset/imaging/simulator.py | 2 +- .../inversion/inversion/imaging/abstract.py | 6 +-- .../inversion/inversion/imaging/mapping.py | 4 +- .../inversion/inversion/imaging/w_tilde.py | 2 +- autoarray/mask/mask_2d.py | 4 +- autoarray/operators/mock/mock_psf.py | 2 +- autoarray/structures/arrays/kernel_2d.py | 50 ++++++++++--------- .../dataset/imaging/test_dataset.py | 4 +- .../imaging/test_inversion_imaging_util.py | 6 +-- .../structures/arrays/test_kernel_2d.py | 35 ++++++++----- 11 files changed, 70 insertions(+), 51 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index d475608bb..a5a0097ec 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -335,7 +335,7 @@ def from_fits( over_sample_size_pixelization=over_sample_size_pixelization, ) - def apply_mask(self, mask: Mask2D, disable_fft_pad : bool = False) -> "Imaging": + def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": """ Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other quantities one-by-one. @@ -398,7 +398,7 @@ def apply_noise_scaling( self, mask: Mask2D, noise_value: float = 1e8, - disable_fft_pad : bool = False, + disable_fft_pad: bool = False, signal_to_noise_value: Optional[float] = None, should_zero_data: bool = True, ) -> "Imaging": @@ -496,7 +496,7 @@ def apply_over_sampling( self, over_sample_size_lp: Union[int, Array2D] = None, over_sample_size_pixelization: Union[int, Array2D] = None, - disable_fft_pad : bool = False, + disable_fft_pad: bool = False, ) -> "AbstractDataset": """ Apply new over sampling objects to the grid and grid pixelization of the dataset. diff --git a/autoarray/dataset/imaging/simulator.py b/autoarray/dataset/imaging/simulator.py index 2ac057df7..536edfec6 100644 --- a/autoarray/dataset/imaging/simulator.py +++ b/autoarray/dataset/imaging/simulator.py @@ -126,7 +126,7 @@ def via_image_from( pixel_scales=image.pixel_scales, ) - image = self.psf.convolve_image(image=image, blurring_image=None) + image = self.psf.convolved_image_from(image=image, blurring_image=None) image = image + background_sky_map diff --git a/autoarray/inversion/inversion/imaging/abstract.py b/autoarray/inversion/inversion/imaging/abstract.py index 9167af6f9..1d94c464e 100644 --- a/autoarray/inversion/inversion/imaging/abstract.py +++ b/autoarray/inversion/inversion/imaging/abstract.py @@ -92,7 +92,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]: return [ ( - self.psf.convolve_mapping_matrix( + self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_obj.mapping_matrix, mask=self.mask ) if linear_obj.operated_mapping_matrix_override is None @@ -134,7 +134,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict: if linear_func.operated_mapping_matrix_override is not None: operated_mapping_matrix = linear_func.operated_mapping_matrix_override else: - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=linear_func.mapping_matrix, mask=self.mask, ) @@ -215,7 +215,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict: mapper_operated_mapping_matrix_dict = {} for mapper in self.cls_list_from(cls=AbstractMapper): - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=mapper.mapping_matrix, mask=self.mask, ) diff --git a/autoarray/inversion/inversion/imaging/mapping.py b/autoarray/inversion/inversion/imaging/mapping.py index 698750a22..169d45f9d 100644 --- a/autoarray/inversion/inversion/imaging/mapping.py +++ b/autoarray/inversion/inversion/imaging/mapping.py @@ -73,7 +73,7 @@ def _data_vector_mapper(self) -> np.ndarray: mapper = mapper_list[i] param_range = mapper_param_range_list[i] - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=mapper.mapping_matrix, mask=self.mask ) @@ -132,7 +132,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]: mapper_i = mapper_list[i] mapper_param_range_i = mapper_param_range_list[i] - operated_mapping_matrix = self.psf.convolve_mapping_matrix( + operated_mapping_matrix = self.psf.convolved_mapping_matrix_from( mapping_matrix=mapper_i.mapping_matrix, mask=self.mask ) diff --git a/autoarray/inversion/inversion/imaging/w_tilde.py b/autoarray/inversion/inversion/imaging/w_tilde.py index f77cae100..ed87179e5 100644 --- a/autoarray/inversion/inversion/imaging/w_tilde.py +++ b/autoarray/inversion/inversion/imaging/w_tilde.py @@ -522,7 +522,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]: values=mapped_reconstructed_image, mask=self.mask ) - mapped_reconstructed_image = self.psf.convolve_image( + mapped_reconstructed_image = self.psf.convolved_image_from( image=mapped_reconstructed_image, blurring_image=None, ).array diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 0e46bc71d..bf6fc64c8 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -653,7 +653,9 @@ def unmasked_blurred_array_from(self, padded_array, psf, image_shape) -> Array2D The 1D unmasked image which is blurred. """ - blurred_image = psf.convolve_image(image=padded_array, blurring_image=None) + blurred_image = psf.convolved_image_from( + image=padded_array, blurring_image=None + ) return self.trimmed_array_from( padded_array=blurred_image, image_shape=image_shape diff --git a/autoarray/operators/mock/mock_psf.py b/autoarray/operators/mock/mock_psf.py index e89d2b732..44fdf847f 100644 --- a/autoarray/operators/mock/mock_psf.py +++ b/autoarray/operators/mock/mock_psf.py @@ -2,5 +2,5 @@ class MockPSF: def __init__(self, operated_mapping_matrix=None): self.operated_mapping_matrix = operated_mapping_matrix - def convolve_mapping_matrix(self, mapping_matrix, mask): + def convolved_mapping_matrix_from(self, mapping_matrix, mask): return self.operated_mapping_matrix diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 31684f734..31b238e3b 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -19,7 +19,6 @@ from autoarray.structures.arrays import array_2d_util - class Kernel2D(AbstractArray2D): def __init__( self, @@ -482,7 +481,9 @@ def from_fits( header=Header(header_sci_obj=header_sci_obj, header_hdu_obj=header_hdu_obj), ) - def fft_shape_from(self, mask : np.ndarray) -> Union[Tuple[int, int], Tuple[int, int], Tuple[int, int]]: + def fft_shape_from( + self, mask: np.ndarray + ) -> Union[Tuple[int, int], Tuple[int, int], Tuple[int, int]]: """ Compute the padded shapes required for FFT-based convolution with this kernel. @@ -541,7 +542,7 @@ def normalized(self) -> "Kernel2D": """ return Kernel2D(values=self, mask=self.mask, normalize=True) - def convolve_image(self, image, blurring_image, jax_method="direct"): + def convolved_image_from(self, image, blurring_image, jax_method="direct"): """ Convolve an input masked image with this PSF. @@ -563,7 +564,7 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): ``fft_shape``, ``full_shape``, and ``mask_shape`` on the kernel. If ``use_fft=False``, convolution falls back to - :meth:`Kernel2D.convolve_image_via_real_space`. + :meth:`Kernel2D.convolved_image_via_real_space_from`. Parameters ---------- @@ -585,7 +586,7 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): """ if not self.use_fft: - return self.convolve_image_via_real_space( + return self.convolved_image_via_real_space_from( image=image, blurring_image=blurring_image, jax_method=jax_method ) @@ -638,9 +639,7 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): ) # FFT the combined image - fft_image_native = jnp.fft.rfft2( - image_both_native, s=fft_shape, axes=(0, 1) - ) + fft_image_native = jnp.fft.rfft2(image_both_native, s=fft_shape, axes=(0, 1)) # Multiply by PSF in Fourier space and invert blurred_image_full = jnp.fft.irfft2( @@ -661,7 +660,7 @@ def convolve_image(self, image, blurring_image, jax_method="direct"): values=blurred_image_native[slim_to_native_tuple], mask=image.mask ) - def convolve_mapping_matrix( + def convolved_mapping_matrix_from( self, mapping_matrix, mask, @@ -684,7 +683,7 @@ def convolve_mapping_matrix( - The slim (masked 1D) representation is returned. If ``use_fft=False``, convolution falls back to - :meth:`Kernel2D.convolve_mapping_matrix_via_real_space`. + :meth:`Kernel2D.convolved_mapping_matrix_via_real_space_from`. Notes ----- @@ -693,7 +692,7 @@ def convolve_mapping_matrix( expected vs actual shapes. This ensures the mapping matrix is padded consistently with the PSF. - The optional ``blurring_mapping_matrix`` plays the same role as - ``blurring_image`` in :meth:`convolve_image`, accounting for PSF flux + ``blurring_image`` in :meth:`convolved_image_from`, accounting for PSF flux that falls into the masked region from outside. Parameters @@ -713,7 +712,7 @@ def convolve_mapping_matrix( Convolved mapping matrix in slim form. """ if not self.use_fft: - return self.convolve_mapping_matrix_via_real_space( + return self.convolved_mapping_matrix_via_real_space_from( mapping_matrix=mapping_matrix, mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, @@ -882,8 +881,11 @@ def rescaled_with_odd_dimensions_from( values=kernel_rescaled, pixel_scales=pixel_scales, normalize=normalize ) - def convolve_image_via_real_space( - self, image : np.ndarray, blurring_image : Optional[np.ndarray] = None, jax_method : str = "direct" + def convolved_image_via_real_space_from( + self, + image: np.ndarray, + blurring_image: Optional[np.ndarray] = None, + jax_method: str = "direct", ): """ Convolve an input masked image with this PSF in real space. @@ -920,9 +922,7 @@ def convolve_image_via_real_space( ) # start with native array padded with zeros - image_native = jnp.zeros( - image.mask.shape, dtype=jnp.asarray(image.array).dtype - ) + image_native = jnp.zeros(image.mask.shape, dtype=jnp.asarray(image.array).dtype) # set image pixels image_native = image_native.at[slim_to_native_tuple].set( @@ -936,9 +936,9 @@ def convolve_image_via_real_space( jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0], ) - image_native = image_native.at[ - slim_to_native_blurring_tuple - ].set(jnp.asarray(blurring_image.array)) + image_native = image_native.at[slim_to_native_blurring_tuple].set( + jnp.asarray(blurring_image.array) + ) else: warnings.warn( "No blurring_image provided. Only the direct image will be convolved. " @@ -955,13 +955,17 @@ def convolve_image_via_real_space( return Array2D(values=convolved_array_1d, mask=image.mask) - def convolve_mapping_matrix_via_real_space( - self, mapping_matrix : np.ndarray, mask, blurring_mapping_matrix : Optional[np.ndarray] = None, jax_method : str = "direct" + def convolved_mapping_matrix_via_real_space_from( + self, + mapping_matrix: np.ndarray, + mask, + blurring_mapping_matrix: Optional[np.ndarray] = None, + jax_method: str = "direct", ): """ Convolve a source-plane mapping matrix with this PSF in real space. - Equivalent to :meth:`convolve_mapping_matrix`, but using explicit + Equivalent to :meth:`convolved_mapping_matrix_from`, but using explicit real-space convolution rather than FFTs. This avoids FFT padding issues but is slower for large kernels. diff --git a/test_autoarray/dataset/imaging/test_dataset.py b/test_autoarray/dataset/imaging/test_dataset.py index c8f61e75f..2e07a4cc3 100644 --- a/test_autoarray/dataset/imaging/test_dataset.py +++ b/test_autoarray/dataset/imaging/test_dataset.py @@ -43,7 +43,9 @@ def test__psf_and_mask_hit_edge__automatically_pads_image_and_noise_map_for_fft( assert dataset.data.shape_native == (3, 3) assert dataset.noise_map.shape_native == (3, 3) - dataset = aa.Imaging(data=image, noise_map=noise_map, psf=psf, disable_fft_pad=False) + dataset = aa.Imaging( + data=image, noise_map=noise_map, psf=psf, disable_fft_pad=False + ) assert dataset.data.shape_native == (6, 6) assert dataset.noise_map.shape_native == (6, 6) diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index cafb8722b..546ffe763 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -203,7 +203,7 @@ def test__data_vector_via_w_tilde_data_two_methods_agree(): mapping_matrix = mapper.mapping_matrix - blurred_mapping_matrix = psf.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask ) @@ -290,7 +290,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree(): w_tilde=w_tilde, mapping_matrix=mapping_matrix ) - blurred_mapping_matrix = psf.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask ) @@ -370,7 +370,7 @@ def test__curvature_matrix_via_w_tilde_preload_two_methods_agree(): pix_pixels=pixelization.pixels, ) - blurred_mapping_matrix = psf.convolve_mapping_matrix( + blurred_mapping_matrix = psf.convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask, ) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 6718cfcbf..00f4407ef 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -368,7 +368,7 @@ def test__convolve_image(): values=blurred_image_via_scipy.native, mask=mask ) - # Now reproduce this data using the convolve_image function + # Now reproduce this data using the convolved_image_from function image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0) kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0) @@ -381,7 +381,7 @@ def test__convolve_image(): blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) - blurred_masked_im_1 = kernel.convolve_image( + blurred_masked_im_1 = kernel.convolved_image_from( image=masked_image, blurring_image=blurring_image ) @@ -420,7 +420,9 @@ def test__convolve_image_no_blurring(): masked_image = aa.Array2D(values=image.native, mask=mask) - blurred_masked_im_1 = kernel.convolve_image(image=masked_image, blurring_image=None) + blurred_masked_im_1 = kernel.convolved_image_from( + image=masked_image, blurring_image=None + ) assert blurred_masked_image_via_scipy == pytest.approx( blurred_masked_im_1.array, 1e-4 @@ -471,7 +473,7 @@ def test__convolve_mapping_matrix(): ] ) - blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) + blurred_mapping = kernel.convolved_mapping_matrix_from(mapping, mask) assert ( blurred_mapping @@ -529,7 +531,7 @@ def test__convolve_mapping_matrix(): ] ) - blurred_mapping = kernel.convolve_mapping_matrix(mapping, mask) + blurred_mapping = kernel.convolved_mapping_matrix_from(mapping, mask) assert blurred_mapping == pytest.approx( np.array( @@ -555,6 +557,7 @@ def test__convolve_mapping_matrix(): abs=1e-4, ) + def test__convolve_image__via_fft__sizes_not_precomputed__compare_numerical_value(): # ------------------------------- @@ -567,14 +570,22 @@ def test__convolve_image__via_fft__sizes_not_precomputed__compare_numerical_valu image = aa.Array2D.no_mask(values=np.arange(400).reshape(20, 20), pixel_scales=1.0) masked_image = aa.Array2D(values=image.native, mask=mask) - kernel_fft = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0, use_fft=True, normalize=True) + kernel_fft = aa.Kernel2D.no_mask( + values=np.arange(49).reshape(7, 7), + pixel_scales=1.0, + use_fft=True, + normalize=True, + ) - blurring_mask = mask.derive_mask.blurring_from(kernel_shape_native=kernel_fft.shape_native) + blurring_mask = mask.derive_mask.blurring_from( + kernel_shape_native=kernel_fft.shape_native + ) blurring_image = aa.Array2D(values=image.native, mask=blurring_mask) - blurred_fft = kernel_fft.convolve_image(image=masked_image, blurring_image=blurring_image) - - assert blurred_fft.native.array[13, 13] == pytest.approx(207.49999999999, rel=1e-6, abs=1e-6) - - + blurred_fft = kernel_fft.convolved_image_from( + image=masked_image, blurring_image=blurring_image + ) + assert blurred_fft.native.array[13, 13] == pytest.approx( + 207.49999999999, rel=1e-6, abs=1e-6 + ) From f7bda00665fe5560fdb3223acae2a17aaf15b4c4 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 13 Oct 2025 20:20:35 +0100 Subject: [PATCH 13/19] remove unmasked dataset and replace with exception --- autoarray/dataset/imaging/dataset.py | 43 +++++++--------------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index a5a0097ec..5bab5624f 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -88,8 +88,6 @@ def __init__( If True, the noise-map is checked to ensure all values are above zero. """ - self.unmasked = None - self.disable_fft_pad = disable_fft_pad if psf is not None: @@ -340,26 +338,26 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other quantities one-by-one. - The original unmasked imaging data is stored as the `self.unmasked` attribute. This is used to ensure that if - the `apply_mask` function is called multiple times, every mask is always applied to the original unmasked - imaging dataset. + The `apply_mask` function cannot be called multiple times, if it is a mask may remove data, therefore + an exception is raised. If you wish to apply a new mask, reload the dataset from .fits files. Parameters ---------- mask The 2D mask that is applied to the image. """ - if self.data.mask.is_all_false: - unmasked_dataset = self - else: - unmasked_dataset = self.unmasked + if not self.data.mask.is_all_false: + raise exc.DatasetException( + "The mask has already been applied to the dataset, therefore a new mask cannot be applied. " + "If you wish to apply a new mask, please reload the dataset from .fits files." + ) - data = Array2D(values=unmasked_dataset.data.native, mask=mask) + data = Array2D(values=self.data.native, mask=mask) - noise_map = Array2D(values=unmasked_dataset.noise_map.native, mask=mask) + noise_map = Array2D(values=self.noise_map.native, mask=mask) - if unmasked_dataset.noise_covariance_matrix is not None: - noise_covariance_matrix = unmasked_dataset.noise_covariance_matrix + if self..noise_covariance_matrix is not None: + noise_covariance_matrix = self..noise_covariance_matrix noise_covariance_matrix = np.delete( noise_covariance_matrix, mask.derive_indexes.masked_slim, 0 @@ -386,8 +384,6 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": disable_fft_pad=disable_fft_pad, ) - dataset.unmasked = unmasked_dataset - logger.info( f"IMAGING - Data masked, contains a total of {mask.pixels_in_mask} image-pixels" ) @@ -454,18 +450,6 @@ def apply_noise_scaling( else: data = self.data.native.array - data_unmasked = Array2D.no_mask( - values=data, - shape_native=self.data.shape_native, - pixel_scales=self.data.pixel_scales, - ) - - noise_map_unmasked = Array2D.no_mask( - values=noise_map, - shape_native=self.noise_map.shape_native, - pixel_scales=self.noise_map.pixel_scales, - ) - data = Array2D(values=data, mask=self.data.mask) noise_map = Array2D(values=noise_map, mask=self.data.mask) @@ -481,11 +465,6 @@ def apply_noise_scaling( check_noise_map=False, ) - if self.unmasked is not None: - dataset.unmasked = self.unmasked - dataset.unmasked.data = data_unmasked - dataset.unmasked.noise_map = noise_map_unmasked - logger.info( f"IMAGING - Data noise scaling applied, a total of {mask.pixels_in_mask} pixels were scaled to large noise values." ) From fe909062bae90f92940e7d48ab7da163f18772fe Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 13 Oct 2025 20:21:14 +0100 Subject: [PATCH 14/19] fix typo --- autoarray/dataset/imaging/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 5bab5624f..c05cfc3ad 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -356,7 +356,7 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": noise_map = Array2D(values=self.noise_map.native, mask=mask) - if self..noise_covariance_matrix is not None: + if self.noise_covariance_matrix is not None: noise_covariance_matrix = self..noise_covariance_matrix noise_covariance_matrix = np.delete( From 01f911e3358931e529f5670663a564f6aaa7ac71 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 13 Oct 2025 20:21:44 +0100 Subject: [PATCH 15/19] fix typo --- autoarray/dataset/imaging/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index c05cfc3ad..268e1f344 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -357,7 +357,7 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": noise_map = Array2D(values=self.noise_map.native, mask=mask) if self.noise_covariance_matrix is not None: - noise_covariance_matrix = self..noise_covariance_matrix + noise_covariance_matrix = self.noise_covariance_matrix noise_covariance_matrix = np.delete( noise_covariance_matrix, mask.derive_indexes.masked_slim, 0 From 44dbadd439e1bc9aa7903579d33c4eca7be461da Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 14 Oct 2025 08:45:23 +0100 Subject: [PATCH 16/19] remove unused unit test --- .../structures/arrays/test_kernel_2d.py | 67 ++----------------- 1 file changed, 4 insertions(+), 63 deletions(-) diff --git a/test_autoarray/structures/arrays/test_kernel_2d.py b/test_autoarray/structures/arrays/test_kernel_2d.py index 00f4407ef..1b1420633 100644 --- a/test_autoarray/structures/arrays/test_kernel_2d.py +++ b/test_autoarray/structures/arrays/test_kernel_2d.py @@ -290,66 +290,7 @@ def test__from_as_gaussian_via_alma_fits_header_parameters__identical_to_astropy assert kernel_astropy == pytest.approx(kernel_2d.native._array, abs=1e-4) -def test__convolved_array_from(): - - array_2d = aa.Array2D.no_mask( - [ - [0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0], - ], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [ - [1.0, 1.0, 0.0, 0.0], - [2.0, 1.0, 1.0, 1.0], - [3.0, 3.0, 2.0, 2.0], - [0.0, 0.0, 1.0, 3.0], - ] - ) - ).all() - - array_2d = aa.Array2D.no_mask( - values=[ - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ], - pixel_scales=1.0, - ) - - kernel_2d = aa.Kernel2D.no_mask( - values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0 - ) - - blurred_array_2d = kernel_2d.convolved_array_from(array_2d) - - assert ( - blurred_array_2d.native - == np.array( - [ - [2.0, 1.0, 0.0, 0.0], - [3.0, 3.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 2.0, 2.0], - ] - ) - ).all() - - -def test__convolve_image(): +def test__convolved_image_from(): mask = aa.Mask2D.circular( shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0 @@ -390,7 +331,7 @@ def test__convolve_image(): ) -def test__convolve_image_no_blurring(): +def test__convolve_imaged_from__no_blurring(): # Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array. mask = aa.Mask2D.circular( @@ -429,7 +370,7 @@ def test__convolve_image_no_blurring(): ) -def test__convolve_mapping_matrix(): +def test__convolved_mapping_matrix_from(): mask = aa.Mask2D( mask=np.array( [ @@ -558,7 +499,7 @@ def test__convolve_mapping_matrix(): ) -def test__convolve_image__via_fft__sizes_not_precomputed__compare_numerical_value(): +def test__convolve_imaged_from__via_fft__sizes_not_precomputed__compare_numerical_value(): # ------------------------------- # Case 1: direct image convolution From 8270894b1e9d960fbeb2170377d7977517d6512c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 14 Oct 2025 09:50:59 +0100 Subject: [PATCH 17/19] mapping mative native from to avoid repeated code --- autoarray/dataset/imaging/dataset.py | 1 - autoarray/geometry/geometry_1d.py | 7 +- autoarray/mask/derive/mask_2d.py | 2 - autoarray/structures/arrays/kernel_2d.py | 175 ++++++++++++++--------- 4 files changed, 107 insertions(+), 78 deletions(-) diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 268e1f344..1422165aa 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -1,7 +1,6 @@ import logging import numpy as np from pathlib import Path -import scipy from typing import Optional, Union from autoconf import cached_property diff --git a/autoarray/geometry/geometry_1d.py b/autoarray/geometry/geometry_1d.py index a6380ca75..bfaad808b 100644 --- a/autoarray/geometry/geometry_1d.py +++ b/autoarray/geometry/geometry_1d.py @@ -1,11 +1,6 @@ from __future__ import annotations import logging -import numpy as np -from typing import TYPE_CHECKING, List, Tuple, Union - -if TYPE_CHECKING: - from autoarray.structures.grids.uniform_1d import Grid1D - from autoarray.mask.mask_2d import Mask2D +from typing import Tuple from autoarray import type as ty diff --git a/autoarray/mask/derive/mask_2d.py b/autoarray/mask/derive/mask_2d.py index 19e005362..aa28068e1 100644 --- a/autoarray/mask/derive/mask_2d.py +++ b/autoarray/mask/derive/mask_2d.py @@ -1,6 +1,5 @@ from __future__ import annotations import logging -import copy import numpy as np from typing import TYPE_CHECKING, Tuple @@ -10,7 +9,6 @@ from autoarray import exc from autoarray.mask.derive.indexes_2d import DeriveIndexes2D -from autoarray.structures.arrays import array_2d_util from autoarray.mask import mask_2d_util logging.basicConfig() diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 31b238e3b..0fb5450a2 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from autoarray import Mask2D + import jax import jax.numpy as jnp import numpy as np @@ -14,9 +21,7 @@ from autoarray.structures.grids.uniform_2d import Grid2D from autoarray.structures.header import Header -from autoarray import exc from autoarray import type as ty -from autoarray.structures.arrays import array_2d_util class Kernel2D(AbstractArray2D): @@ -542,6 +547,83 @@ def normalized(self) -> "Kernel2D": """ return Kernel2D(values=self, mask=self.mask, normalize=True) + def mapping_matrix_native_from( + self, + mapping_matrix: jnp.ndarray, + mask: "Mask2D", + blurring_mapping_matrix: Optional[jnp.ndarray] = None, + blurring_mask: Optional["Mask2D"] = None, + ) -> jnp.ndarray: + """ + Expand a slim mapping matrix (image-plane) and optional blurring mapping matrix + into a full native 3D cube (ny, nx, n_src). + + This is primarily used for real-space convolution, where the pixel-to-source + mapping must be represented on the full image grid. + + Parameters + ---------- + mapping_matrix : ndarray (N_pix, N_src) + Slim mapping matrix for unmasked image pixels, mapping each image pixel + to source-plane pixels. + mask : Mask2D + Mask defining which image pixels are unmasked. Used to expand the slim + mapping matrix into a native grid. + blurring_mapping_matrix : ndarray (N_blur, N_src), optional + Mapping matrix for blurring pixels outside the main mask (e.g. light + spilling in from outside). If provided, it is also scattered into the + native cube. + blurring_mask : Mask2D, optional + Mask defining the blurring region pixels. Must be provided if + `blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple` + is not already cached. + + Returns + ------- + ndarray (ny, nx, N_src) + Native 3D mapping matrix cube with dimensions (image_y, image_x, sources). + Contains contributions from both the main mapping matrix and, if provided, + the blurring mapping matrix. + """ + slim_to_native_tuple = self.slim_to_native_tuple + if slim_to_native_tuple is None: + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=mapping_matrix.shape[0] + ) + + n_src = mapping_matrix.shape[1] + + # Allocate full native grid (ny, nx, n_src) + mapping_matrix_native = jnp.zeros( + mask.shape + (n_src,), dtype=mapping_matrix.dtype + ) + + # Scatter main mapping matrix into native cube + mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set( + mapping_matrix + ) + + # Optionally scatter blurring mapping matrix + if blurring_mapping_matrix is not None: + slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple + + if slim_to_native_blurring_tuple is None: + if blurring_mask is None: + raise ValueError( + "blurring_mask must be provided if blurring_mapping_matrix is given " + "and slim_to_native_blurring_tuple is None." + ) + slim_to_native_blurring_tuple = jnp.nonzero( + jnp.logical_not(blurring_mask.array), + size=blurring_mapping_matrix.shape[0], + ) + + mapping_matrix_native = mapping_matrix_native.at[ + slim_to_native_blurring_tuple + ].set(blurring_mapping_matrix) + + return mapping_matrix_native + def convolved_image_from(self, image, blurring_image, jax_method="direct"): """ Convolve an input masked image with this PSF. @@ -665,6 +747,7 @@ def convolved_mapping_matrix_from( mapping_matrix, mask, blurring_mapping_matrix=None, + blurring_mask: Optional[Mask2D] = None, jax_method="direct", ): """ @@ -716,6 +799,7 @@ def convolved_mapping_matrix_from( mapping_matrix=mapping_matrix, mask=mask, blurring_mapping_matrix=blurring_mapping_matrix, + blurring_mask=blurring_mask, jax_method=jax_method, ) @@ -735,35 +819,22 @@ def convolved_mapping_matrix_from( fft_shape = self.fft_shape full_shape = self.full_shape mask_shape = self.mask_shape - fft_psf = self.fft_psf fft_psf_mapping = self.fft_psf_mapping slim_to_native_tuple = self.slim_to_native_tuple if slim_to_native_tuple is None: slim_to_native_tuple = jnp.nonzero( - jnp.logical_not(mask.array), size=mask.shape[0] + jnp.logical_not(mask.array), size=mapping_matrix.shape[0] ) - n_src = mapping_matrix.shape[1] - - # allocate full native + source dimension - mapping_matrix_native = jnp.zeros( - mask.shape + (n_src,), dtype=mapping_matrix.dtype - ) - - # scatter main mapping matrix - mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set( - mapping_matrix + mapping_matrix_native = self.mapping_matrix_native_from( + mapping_matrix=mapping_matrix, + mask=mask, + blurring_mapping_matrix=blurring_mapping_matrix, + blurring_mask=blurring_mask, ) - # optionally scatter blurring mapping matrix - if blurring_mapping_matrix is not None: - slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple - mapping_matrix_native = mapping_matrix_native.at[ - slim_to_native_blurring_tuple - ].set(blurring_mapping_matrix) - # FFT convolution fft_mapping_matrix_native = jnp.fft.rfft2( mapping_matrix_native, s=fft_shape, axes=(0, 1) @@ -960,6 +1031,7 @@ def convolved_mapping_matrix_via_real_space_from( mapping_matrix: np.ndarray, mask, blurring_mapping_matrix: Optional[np.ndarray] = None, + blurring_mask: Optional[Mask2D] = None, jax_method: str = "direct", ): """ @@ -989,60 +1061,25 @@ def convolved_mapping_matrix_via_real_space_from( ndarray (N_pix, N_src) Convolved mapping matrix in slim form. """ - # 1) Indices of unmasked (image) pixels — no `size=` to avoid wrong lengths - ys, xs = self.slim_to_native_tuple or jnp.nonzero(jnp.logical_not(mask.array)) - n_pix, n_src = mapping_matrix.shape - - # Sanity check - if ys.shape[0] != n_pix: - raise ValueError( - f"Mapping rows ({n_pix}) != unmasked pixels ({ys.shape[0]}). " - "Make sure you’re using the image (not blurring) index tuple." - ) - - # 2) Allocate native cube (ny, nx, n_src) - mapping_matrix_native = jnp.zeros( - mask.shape + (n_src,), dtype=mapping_matrix.dtype - ) - # 3) Build index grids with identical shape (n_pix, n_src) - ys_exp = jnp.broadcast_to(ys[:, None], (n_pix, n_src)) - xs_exp = jnp.broadcast_to(xs[:, None], (n_pix, n_src)) - src_exp = jnp.broadcast_to(jnp.arange(n_src)[None, :], (n_pix, n_src)) - - # 4) Scatter all at once (values also shape (n_pix, n_src)) - mapping_matrix_native = mapping_matrix_native.at[(ys_exp, xs_exp, src_exp)].set( - mapping_matrix - ) + slim_to_native_tuple = self.slim_to_native_tuple - # 5) Optional blurring mapping matrix - if blurring_mapping_matrix is not None: - ys_b, xs_b = self.slim_to_native_blurring_tuple or jnp.nonzero( - jnp.logical_not( - mask.array - ) # use the correct blurring grid mask here if different + if slim_to_native_tuple is None: + slim_to_native_tuple = jnp.nonzero( + jnp.logical_not(mask.array), size=mapping_matrix.shape[0] ) - n_blur, n_src_b = blurring_mapping_matrix.shape - if n_src_b != n_src: - raise ValueError( - "blurring_mapping_matrix columns must match mapping_matrix columns (n_src)." - ) - - ys_b_exp = jnp.broadcast_to(ys_b[:, None], (n_blur, n_src)) - xs_b_exp = jnp.broadcast_to(xs_b[:, None], (n_blur, n_src)) - src_b_exp = jnp.broadcast_to(jnp.arange(n_src)[None, :], (n_blur, n_src)) - - mapping_matrix_native = mapping_matrix_native.at[ - (ys_b_exp, xs_b_exp, src_b_exp) - ].set(blurring_mapping_matrix) + mapping_matrix_native = self.mapping_matrix_native_from( + mapping_matrix=mapping_matrix, + mask=mask, + blurring_mapping_matrix=blurring_mapping_matrix, + blurring_mask=blurring_mask, + ) # 6) Real-space convolution, broadcast kernel over source axis kernel = self.stored_native.array - convolved_native = jax.scipy.signal.convolve( + blurred_mapping_matrix_native = jax.scipy.signal.convolve( mapping_matrix_native, kernel[..., None], mode="same", method=jax_method ) - # 7) Pull back to slim (n_pix, n_src) - blurred_mapping_matrix = convolved_native[ys, xs, :] - - return blurred_mapping_matrix + # return slim form + return blurred_mapping_matrix_native[slim_to_native_tuple] From 7cd493116361a504d85df7f965856832b703ee4d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 14 Oct 2025 10:44:52 +0100 Subject: [PATCH 18/19] update FFT padding defaults to go odd x odd image if FFT not used --- autoarray/config/general.yaml | 2 ++ autoarray/dataset/imaging/dataset.py | 14 +++++++++++--- test_autoarray/conftest.py | 3 +++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/autoarray/config/general.yaml b/autoarray/config/general.yaml index 7b9112e81..d74fc3d67 100644 --- a/autoarray/config/general.yaml +++ b/autoarray/config/general.yaml @@ -2,6 +2,8 @@ jax: use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy. fits: flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9. +psf: + use_fft_default: true # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution. inversion: check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same. use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion. diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index 1422165aa..9642eba2a 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -4,6 +4,7 @@ from typing import Optional, Union from autoconf import cached_property +from autoconf import instance from autoarray.dataset.abstract.dataset import AbstractDataset from autoarray.dataset.grids import GridsDataset @@ -95,6 +96,10 @@ def __init__( if psf is not None and not disable_fft_pad and data.mask.shape != fft_shape: + # If using real-space convolution instead of FFT, enforce odd-odd shapes + if not psf.use_fft: + fft_shape = tuple(s + 1 if s % 2 == 0 else s for s in fft_shape) + logger.info( f"Imaging data has been trimmed or padded for FFT convolution.\n" f" - Original shape : {data.mask.shape}\n" @@ -345,10 +350,13 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging": mask The 2D mask that is applied to the image. """ - if not self.data.mask.is_all_false: + invalid = np.logical_and(self.data.mask, np.logical_not(mask)) + + if np.any(invalid): raise exc.DatasetException( - "The mask has already been applied to the dataset, therefore a new mask cannot be applied. " - "If you wish to apply a new mask, please reload the dataset from .fits files." + "The new mask overlaps with pixels that are already unmasked in the dataset. " + "You cannot apply a new mask on top of an existing one. " + "If you wish to apply a different mask, please reload the dataset from .fits files." ) data = Array2D(values=self.data.native, mask=mask) diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index b322cab3c..4dab798cc 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -14,6 +14,9 @@ def pytest_configure(): from autoconf import conf + + + class PlotPatch: def __init__(self): self.paths = [] From 4e5a6476d78ece5f915d240c39dc4011dac6894c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 14 Oct 2025 11:27:30 +0100 Subject: [PATCH 19/19] black --- autoarray/dataset/imaging/simulator.py | 8 ++++++-- test_autoarray/conftest.py | 3 --- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/autoarray/dataset/imaging/simulator.py b/autoarray/dataset/imaging/simulator.py index 536edfec6..081498701 100644 --- a/autoarray/dataset/imaging/simulator.py +++ b/autoarray/dataset/imaging/simulator.py @@ -169,12 +169,16 @@ def via_image_from( image = Array2D(values=image, mask=mask) dataset = Imaging( - data=image, psf=self.psf, noise_map=noise_map, check_noise_map=False + data=image, + psf=self.psf, + noise_map=noise_map, + check_noise_map=False, + disable_fft_pad=True, ) if over_sample_size is not None: dataset = dataset.apply_over_sampling( - over_sample_size_lp=over_sample_size.native + over_sample_size_lp=over_sample_size.native, disable_fft_pad=True ) return dataset diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index 4dab798cc..b322cab3c 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -14,9 +14,6 @@ def pytest_configure(): from autoconf import conf - - - class PlotPatch: def __init__(self): self.paths = []