Skip to content

Feature/convolve fft#189

Merged
Jammy2211 merged 20 commits intomainfrom
feature/convolve_fft
Oct 14, 2025
Merged

Feature/convolve fft#189
Jammy2211 merged 20 commits intomainfrom
feature/convolve_fft

Conversation

@Jammy2211
Copy link
Owner

FFT and Real-Space Convolution Refactor

This PR refactors Kernel2D to support both FFT- and real-space convolution paths using JAX FFTs for accelerated computation.

FFT Convolution (default)

  • Uses jax.numpy.fft.rfft2 / irfft2 for efficient 2D convolution.

  • Requires precomputed padding shapes (fft_shape, full_shape, mask_shape) to ensure linear convolution (avoiding wrap-around).

  • Raises clear errors if shapes are inconsistent, with diagnostics on expected vs actual shapes.

  • Much faster for large kernels (>~5×5).

Real-Space Convolution (fallback)

  • Uses jax.scipy.signal.convolve with explicit masking.

  • Slower but avoids padding, useful for tests and small kernels.

  • Available for both images and mapping matrices.

Mapping Matrices

  • Convolution of mapping matrices now supported via both FFT and real space.

  • Blurring matrices are handled consistently in both modes.

Other improvements

  • Clearer warnings when blurring contributions are omitted.

  • Improved error messages and logging with full shape diagnostics.

  • Comprehensive docstrings explaining FFT padding, trade-offs, and usage.

This unifies convolution logic under JAX, making operations GPU/TPU compatible and more performant while preserving correctness for masked images.

@Jammy2211 Jammy2211 merged commit 15e8cf2 into main Oct 14, 2025
0 of 8 checks passed
@Jammy2211 Jammy2211 deleted the feature/convolve_fft branch November 30, 2025 16:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant