Conversation
There was a problem hiding this comment.
Pull request overview
This pull request addresses FFT convolution alignment issues in the autoarray package, specifically fixing how regions are extracted after FFT convolution to ensure proper centering and alignment with the native mask grid. The changes update both image and mapping matrix convolution methods to use a consistent roll-and-extract approach for proper PSF centering, and relax PSF dimension validation to allow even-sized PSFs when using FFT-based convolution.
Changes:
- Fixed FFT convolution cropping logic by implementing explicit offset calculation with
xp.rollfor proper centering, followed byjax.lax.dynamic_sliceextraction - Relaxed PSF shape validation to only require odd dimensions for non-FFT PSFs
- Removed redundant kernel shape validation from blurring mask computation
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| autoarray/structures/arrays/kernel_2d.py | Updated convolved_image_from and convolved_mapping_matrix_from to use roll-based centering and corrected extraction sizes; improved code comments |
| autoarray/mask/derive/mask_2d.py | Removed odd-dimension validation for kernel shapes in blurring mask computation |
| autoarray/dataset/imaging/dataset.py | Modified PSF validation to only enforce odd dimensions when psf.use_fft is False |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| blurred_image_full = xp.fft.irfft2( | ||
| fft_psf * fft_image_native, s=fft_shape, axes=(0, 1) | ||
| ) | ||
| ky, kx = self.native.array.shape # (21, 21) |
There was a problem hiding this comment.
The inline comment # (21, 21) appears to be an example value for demonstration purposes, but it may be unclear to future readers that this is not a constant value. Consider clarifying this comment to indicate it's an example, e.g., # e.g., (21, 21) or removing it if it's not needed.
| ky, kx = self.native.array.shape # (21, 21) | |
| ky, kx = self.native.array.shape # e.g., (21, 21) |
| # ------------------------------------------------------------------------- | ||
| # NumPy path unchanged | ||
| # ------------------------------------------------------------------------- |
There was a problem hiding this comment.
Duplicate comment section detected. Lines 813-815 contain the same comment "NumPy path unchanged" as lines 817-819. One of these duplicate comment blocks should be removed to improve code clarity.
| # ------------------------------------------------------------------------- | |
| # NumPy path unchanged | |
| # ------------------------------------------------------------------------- |
| blurred_image_native = jax.lax.dynamic_slice( | ||
| blurred_image_full, start_indices, mask_shape | ||
| blurred_image_full, start_indices, image.mask.shape | ||
| ) |
There was a problem hiding this comment.
Potential bug when self.fft_shape is None: After resizing the image to fft_shape at line 686, image.mask.shape becomes fft_shape. At line 737, the code tries to extract a region of size image.mask.shape (which is fft_shape) from blurred_image_full starting at (off_y, off_x). Since blurred_image_full has shape fft_shape and off_y, off_x > 0 for kernels with size > 1, this extraction would exceed the array bounds. The old code used mask_shape (computed at line 675 before the resize) for the extraction size, which was correct. Consider saving mask_shape before the resize and using it for extraction instead of image.mask.shape.
This pull request addresses improvements and bug fixes related to FFT convolution and cropping logic for image and mapping matrix operations in the
autoarraypackage. The main focus is on correcting the alignment and extraction of regions after FFT convolution, ensuring consistency between image and mapping matrix processing, and clarifying the handling of PSF and mask shapes.This bug may impact existing PyAutoLens-JAX fits, albeit it does not crop up for all fits and depends on the shapes of the input data and PSF.
FFT Convolution and Cropping Logic Fixes
convolved_image_fromandconvolved_mapping_matrix_frommethods inkernel_2d.pyto use explicit offset calculation andxp.rollfor proper centering after FFT convolution, followed by extraction usingjax.lax.dynamic_slice. This ensures the output is correctly aligned with the native mask grid. [1] [2]PSF and Mask Shape Validation
dataset.pyto only require odd dimensions whenpsf.use_fftis False, allowing more flexibility for FFT-based PSFs.mask_2d.py, as this is now handled elsewhere or only required for non-FFT PSFs.Code Clarity and Robustness
convolved_mapping_matrix_from, simplifying error messages and ensuring precomputed shapes are required for FFT convolution.convolved_mapping_matrix_fromto highlight mixed precision handling and the construction of native cubes on the mask grid. [1] [2]These changes collectively improve the reliability and correctness of FFT-based convolution operations in the codebase.