Feature/remove convolver#166

Merged
Jammy2211 merged 11 commits into feature/jax_wrapper from feature/remove_convolver
Apr 3, 2025
Conversation

@Jammy2211
Owner

The Convolver object mapped out the 2D convolution calculation for a 2D mask and 2D kernel, using the fact that for model-fitting both quantities were fixed and therefore the exact sequence of calculations required could be precomputed in memory.

This object relies heavily on in-place memory manipulation and is therefore not suitable for JAX, so I have removed it.

All 2D convolutions are now performed using standard jax.scipy convolution methods on 2D arrays.

The Convolver object worked on the contents of masked arrays mapped to a 1D representation, whereas the current JAX implementation of convolutions requires mapping back and forth between 1D and 2D. This likely leads to a loss of performance, and we may need to optimize these functions in the future.
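The new approach can be sketched as follows. This is a minimal toy sketch, not the actual PyAutoArray code: `mask`, `image_slim` and `kernel` here are hypothetical stand-ins for the real objects.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Hypothetical 2D mask (True = masked) and the 1D "slim" values of its
# unmasked pixels, plus a small uniform kernel.
mask = jnp.array([[True, False, True],
                  [False, False, False],
                  [True, False, True]])
image_slim = jnp.arange(1.0, 6.0)  # 5 unmasked pixels
kernel = jnp.ones((3, 3)) / 9.0

# Map the 1D slim representation back to a 2D native grid ...
slim_to_native = jnp.nonzero(jnp.logical_not(mask), size=image_slim.shape[0])
image_native = jnp.zeros(mask.shape).at[slim_to_native].set(image_slim)

# ... convolve in 2D with a standard jax.scipy routine ...
blurred_native = convolve2d(image_native, kernel, mode="same")

# ... and flatten back to 1D for the rest of the calculation.
blurred_slim = blurred_native[slim_to_native]
```

The `size` argument to `jnp.nonzero` keeps the output shape static, which is what makes this mapping jit-compatible.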

@Jammy2211 Jammy2211 requested review from CKrawczyk and rhayes777 April 2, 2025 20:12
self.psf = psf

if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D dimensions must be odd")
Collaborator

I have seen this before but never really understood it: why do you want the PSF kernel to be odd?

Owner Author

For real space convolution, even x even kernels can induce a half pixel shift depending on where you centre the kernel. In the end, I think, this meant that things like the lens model centres picked up a half pixel shift.

Alternatively, you could sample the "peak" of the even x even PSF in the central 4 pixels, which means you don't fully sample its shape where it matters most.

This may not matter for an FFT? And it certainly won't matter once we do PSF over-sampling?

It's definitely a hack I am happy to see removed.
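The half pixel shift can be seen in a toy example (not code from this PR; the grid size and kernels are arbitrary): convolving a point source with an even-sized uniform kernel moves its centroid by half a pixel, while an odd-sized kernel leaves it centred.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Point source at the centre of a 5x5 grid.
delta = jnp.zeros((5, 5)).at[2, 2].set(1.0)

# Uniform odd (3x3) and even (2x2) kernels.
blurred_odd = convolve2d(delta, jnp.ones((3, 3)) / 9.0, mode="same")
blurred_even = convolve2d(delta, jnp.ones((2, 2)) / 4.0, mode="same")

# First-moment centroid of each blurred image.
ys, xs = jnp.mgrid[0:5, 0:5]

def centroid(img):
    return (
        float((ys * img).sum() / img.sum()),
        float((xs * img).sum() / img.sum()),
    )

# The odd kernel keeps the centroid at (2, 2); the even kernel
# shifts it by half a pixel along each axis.
print(centroid(blurred_odd))
print(centroid(blurred_even))
```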

Comment on lines +505 to +519
slim_to_native = jnp.nonzero(
jnp.logical_not(image.mask.array), size=image.shape[0]
)
slim_to_native_blurring = jnp.nonzero(
jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0]
)

Raises
------
KernelException if either Kernel2D psf dimension is even
expanded_array_native = jnp.zeros(image.mask.shape)

expanded_array_native = expanded_array_native.at[slim_to_native].set(
image.array
)
expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set(
blurring_image.array
)
Collaborator

Might be worth checking if jnp.where ends up doing this faster after a jit? Or if it is the same speed, you can reduce this to two lines of code (just double-check the lines below to make sure they do the same thing):

expanded_array_native = jnp.where(image.mask.array, 0, image.array)
expanded_array_native = jnp.where(blurring_image.mask.array, expanded_array_native, blurring_image.array)

Collaborator

@CKrawczyk Apr 3, 2025

I just double-checked this, and I think the index ordering when using jnp.where is not quite the same as what you have, so my suggestion needs to transpose the mask first:

expanded_array_native = jnp.where(image.mask.array.T, 0, image.array)
expanded_array_native = jnp.where(blurring_image.mask.array.T, expanded_array_native, blurring_image.array)

Owner Author

I'll give it a go!

At the moment my approach is to get all source code supporting JAX (e.g. unit tests passing, successful grad of likelihood functions) and to focus on the actual optimization afterwards. However, for bottleneck code (this function is one of the most important) I'm keeping a repo of profiling tests, including comparisons to the old numba code.

I'll work this suggestion into this PR before merging (assuming it passes the unit tests).

Collaborator

@CKrawczyk Apr 3, 2025

And with more testing, I have found that since we have a 2D mask and a 1D slim array, the jnp.where approach does not work as intended. The special case of one item being True per row does work, but that is it. Keep the code as you have it.

Owner Author

Cool, we may be able to produce images in 2D, perform convolution and then flatten.

Storing everything in 1D throughout the whole likelihood function was optimal for numba, because we could precompute and store in memory a lot of arrays that did the mappings for us. But I don't think this approach is very JAX-y.
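The 2D-throughout idea can be sketched like this (hypothetical names, not PyAutoArray code): if the image stays native (2D) until the end, masking becomes a plain shape-matched jnp.where with no index bookkeeping, and the slim 1D array is produced once by a single gather at the end.

```python
import jax.numpy as jnp

# Hypothetical 2D mask (True = masked) and native 2D image.
mask = jnp.array([[True, False],
                  [False, True]])
image_native = jnp.array([[1.0, 2.0],
                          [3.0, 4.0]])

# In 2D, zeroing masked pixels is an element-wise where (shapes match).
masked_native = jnp.where(mask, 0.0, image_native)

# Flatten to the slim 1D representation once, at the end.
slim_to_native = jnp.nonzero(jnp.logical_not(mask), size=2)
image_slim = image_native[slim_to_native]
```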

Comment on lines +550 to +552
expanded_array_native = jnp.zeros(mask.shape)

expanded_array_native = expanded_array_native.at[slim_to_native].set(image)
Collaborator

As above, this can be done with the following (again, this should be checked to make sure it is doing the same thing):

expanded_array_native = jnp.where(mask.array.T, 0, image)

@Jammy2211 Jammy2211 merged commit c9268b2 into feature/jax_wrapper Apr 3, 2025
0 of 8 checks passed
@Jammy2211 Jammy2211 deleted the feature/remove_convolver branch April 30, 2025 18:13
@Jammy2211 Jammy2211 restored the feature/remove_convolver branch April 30, 2025 18:14
@Jammy2211 Jammy2211 deleted the feature/remove_convolver branch June 24, 2025 13:46