diff --git a/autoarray/dataset/imaging/dataset.py b/autoarray/dataset/imaging/dataset.py index cc278b67a..970cae502 100644 --- a/autoarray/dataset/imaging/dataset.py +++ b/autoarray/dataset/imaging/dataset.py @@ -189,8 +189,9 @@ def __init__( self.psf = psf if psf is not None: - if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: - raise exc.KernelException("Kernel2D Kernel2D must be odd") + if not psf.use_fft: + if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0: + raise exc.KernelException("Kernel2D Kernel2D must be odd") self.grids = GridsDataset( mask=self.data.mask, diff --git a/autoarray/mask/derive/mask_2d.py b/autoarray/mask/derive/mask_2d.py index aa28068e1..c9792436f 100644 --- a/autoarray/mask/derive/mask_2d.py +++ b/autoarray/mask/derive/mask_2d.py @@ -198,9 +198,6 @@ def blurring_from(self, kernel_shape_native: Tuple[int, int]) -> Mask2D: from autoarray.mask.mask_2d import Mask2D - if kernel_shape_native[0] % 2 == 0 or kernel_shape_native[1] % 2 == 0: - raise exc.MaskException("psf_size of exterior region must be odd") - blurring_mask = mask_2d_util.blurring_mask_2d_from( mask_2d=self.mask, kernel_shape_native=kernel_shape_native, diff --git a/autoarray/structures/arrays/kernel_2d.py b/autoarray/structures/arrays/kernel_2d.py index 11aeaf927..b9e55266f 100644 --- a/autoarray/structures/arrays/kernel_2d.py +++ b/autoarray/structures/arrays/kernel_2d.py @@ -723,14 +723,18 @@ def convolved_image_from( blurred_image_full = xp.fft.irfft2( fft_psf * fft_image_native, s=fft_shape, axes=(0, 1) ) + ky, kx = self.native.array.shape # (21, 21) + off_y = (ky - 1) // 2 + off_x = (kx - 1) // 2 - # Crop back to mask_shape - start_indices = tuple( - (full_size - out_size) // 2 - for full_size, out_size in zip(full_shape, mask_shape) + blurred_image_full = xp.roll( + blurred_image_full, shift=(-off_y, -off_x), axis=(0, 1) ) + + start_indices = (off_y, off_x) + blurred_image_native = jax.lax.dynamic_slice( - blurred_image_full, start_indices, mask_shape + blurred_image_full, start_indices, image.mask.shape ) # Return slim form; optionally cast for downstream stability @@ -806,6 +810,10 @@ def convolved_mapping_matrix_from( ndarray of shape (N_pix, N_src) Convolved mapping matrix in slim form. """ + # ------------------------------------------------------------------------- + # NumPy path unchanged + # ------------------------------------------------------------------------- + # ------------------------------------------------------------------------- # NumPy path unchanged # ------------------------------------------------------------------------- @@ -835,34 +843,24 @@ def convolved_mapping_matrix_from( import jax.numpy as jnp # ------------------------------------------------------------------------- - # Validate cached FFT shapes / state + # Cached FFT shapes/state (REQUIRED) # ------------------------------------------------------------------------- if self.fft_shape is None: - full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask) raise ValueError( - f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n" - f"Expected mapping matrix padded to match FFT shape of PSF.\n" - f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, " - f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}." + "FFT convolution requires precomputed FFT shapes on the PSF." ) - else: - fft_shape = self.fft_shape - full_shape = self.full_shape - mask_shape = self.mask_shape - fft_psf_mapping = self.fft_psf_mapping + + fft_shape = self.fft_shape + fft_psf_mapping = self.fft_psf_mapping # ------------------------------------------------------------------------- - # Mixed precision dtypes (JAX only) + # Mixed precision handling # ------------------------------------------------------------------------- fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 - - # Ensure PSF FFT dtype matches the FFT path fft_psf_mapping = jnp.asarray(fft_psf_mapping, dtype=fft_complex_dtype) # ------------------------------------------------------------------------- - # Build native cube in the FFT dtype (THIS IS THE KEY) - # This relies on mapping_matrix_native_from honoring the use_mixed_precision - # kwarg when constructing the native mapping matrix. + # Build native cube on the *native mask grid* # ------------------------------------------------------------------------- mapping_matrix_native = self.mapping_matrix_native_from( mapping_matrix=mapping_matrix, @@ -872,6 +870,7 @@ def convolved_mapping_matrix_from( use_mixed_precision=use_mixed_precision, xp=xp, ) + # shape: (ny_native, nx_native, n_src) # ------------------------------------------------------------------------- # FFT convolution @@ -879,6 +878,7 @@ def convolved_mapping_matrix_from( fft_mapping_matrix_native = xp.fft.rfft2( mapping_matrix_native, s=fft_shape, axes=(0, 1) ) + blurred_mapping_matrix_full = xp.fft.irfft2( fft_psf_mapping * fft_mapping_matrix_native, s=fft_shape, @@ -886,21 +886,35 @@ def convolved_mapping_matrix_from( ) # ------------------------------------------------------------------------- - # Crop back to mask-shape + # APPLY SAME FIX AS convolved_image_from # ------------------------------------------------------------------------- - start_indices = tuple( - (full_size - out_size) // 2 - for full_size, out_size in zip(full_shape, mask_shape) - ) + (0,) - out_shape_full = mask_shape + (blurred_mapping_matrix_full.shape[2],) + ky, kx = self.native.array.shape + off_y = (ky - 1) // 2 + off_x = (kx - 1) // 2 + + blurred_mapping_matrix_full = xp.roll( + blurred_mapping_matrix_full, + shift=(-off_y, -off_x), + axis=(0, 1), + ) + + # ------------------------------------------------------------------------- + # Extract native grid (same as image path) + # ------------------------------------------------------------------------- + native_shape = mask.shape + start_indices = (off_y, off_x, 0) + + out_shape = native_shape + (blurred_mapping_matrix_full.shape[2],) blurred_mapping_matrix_native = jax.lax.dynamic_slice( blurred_mapping_matrix_full, start_indices, - out_shape_full, + out_shape, ) - # Return slim form + # ------------------------------------------------------------------------- + # Slim using ORIGINAL mask indices (same grid) + # ------------------------------------------------------------------------- blurred_slim = blurred_mapping_matrix_native[mask.slim_to_native_tuple] return blurred_slim