Feature/remove convolver #166
Conversation
```python
self.psf = psf

if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
    raise exc.KernelException("Kernel2D dimensions must be odd")
```
I have seen this before but never really understood it: why do you want the PSF kernel to be odd?
For real-space convolution, an even x even kernel can induce a half-pixel shift depending on where you centre it. In the end, I think this meant that things like the lens model centres picked up a half-pixel shift.
Alternatively, you could spread the "peak" of the even x even PSF across the central 4 pixels, which means you don't fully sample its shape where it matters most.
This may not matter for an FFT? And it certainly won't matter once we do PSF over-sampling.
It's definitely a hack I am happy to see removed.
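To illustrate the half-pixel shift in a standalone sketch (toy arrays of my own, not code from this PR): convolving a single-pixel point source with an odd kernel leaves the peak on-pixel, while an even kernel spreads the flux over a 2x2 block, so the centroid lands between pixels.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# A 5x5 image with a single point source at the central pixel (2, 2).
image = jnp.zeros((5, 5)).at[2, 2].set(1.0)

# Odd (3x3) delta kernel: the peak stays exactly at (2, 2).
odd_kernel = jnp.zeros((3, 3)).at[1, 1].set(1.0)
odd_result = convolve2d(image, odd_kernel, mode="same")

# Even (2x2) uniform kernel: the flux is spread over a 2x2 block of pixels,
# so the centroid sits between pixel centres, i.e. a half-pixel shift.
even_kernel = jnp.full((2, 2), 0.25)
even_result = convolve2d(image, even_kernel, mode="same")

ys, xs = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0), indexing="ij")

def centroid(arr):
    total = arr.sum()
    return float((ys * arr).sum() / total), float((xs * arr).sum() / total)

print(centroid(odd_result))   # stays at (2.0, 2.0)
print(centroid(even_result))  # offset from (2.0, 2.0) by half a pixel in each axis
```

The half-pixel offset is there whichever way the `same`-mode cropping resolves the ambiguity; only its sign changes.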
```python
slim_to_native = jnp.nonzero(
    jnp.logical_not(image.mask.array), size=image.shape[0]
)
slim_to_native_blurring = jnp.nonzero(
    jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0]
)
```

```
Raises
------
KernelException if either Kernel2D psf dimension is even
```

```python
expanded_array_native = jnp.zeros(image.mask.shape)

expanded_array_native = expanded_array_native.at[slim_to_native].set(
    image.array
)
expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set(
    blurring_image.array
)
```
Might be worth checking if jnp.where ends up doing this faster after a jit? Even if it is the same speed, you can reduce this to two lines of code (just double-check the lines below to make sure they do the same thing):
```python
expanded_array_native = jnp.where(image.mask.array, 0, image.array)
expanded_array_native = jnp.where(blurring_image.mask.array, expanded_array_native, blurring_image.array)
```
I just double-checked this, and I think the index ordering when using jnp.where is not quite the same as what you have, so my suggestion needs to transpose the mask first:
```python
expanded_array_native = jnp.where(image.mask.array.T, 0, image.array)
expanded_array_native = jnp.where(blurring_image.mask.array.T, expanded_array_native, blurring_image.array)
```
I'll give it a go!
At the moment my approach is to get all source code supporting JAX (e.g. unit tests passing, successful grads of likelihood functions), and to focus on the actual optimization after. However, for bottleneck code (this function being one of the most important) I'm keeping a repo of profiling tests, including comparisons to the old numba code.
I'll work this suggestion into this PR before merging (assuming it passes the unit tests).
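For reference, a minimal pattern for profiling jitted functions (a sketch with hypothetical names, not the actual profiling repo): the key detail is `block_until_ready()`, since JAX dispatches asynchronously and naive timing measures only the dispatch.

```python
import timeit

import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Hypothetical bottleneck: a 2D PSF convolution standing in for the real function.
@jax.jit
def blur(image, kernel):
    return convolve2d(image, kernel, mode="same")

image = jnp.ones((128, 128))
kernel = jnp.ones((21, 21)) / 21**2

# Call once first so the timing below measures execution, not compilation.
blur(image, kernel).block_until_ready()

# block_until_ready() forces the computation to finish before the timer stops.
t = timeit.timeit(lambda: blur(image, kernel).block_until_ready(), number=100)
print(f"{t / 100 * 1e3:.3f} ms per call")
```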
And with more testing, I have found that because we have a 2D mask and a 1D slim array, the jnp.where approach does not work as intended. The special case of exactly one True item per row does work, but that is it. Keep the code as you have it.
Cool, we may be able to produce images in 2D, perform the convolution and then flatten.
Storing everything in 1D throughout the whole likelihood function was optimal for numba because we could store in memory a lot of arrays that did the mappings for us. But I don't think this is very JAX-y.
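A minimal sketch of the slim-to-native scatter the PR keeps (the toy mask and names are my own): `jnp.nonzero` with a static `size` gives the unmasked coordinates, and `.at[...].set` scatters the 1D slim values onto the 2D grid. A plain `jnp.where` cannot do this, because the 2D mask and the 1D slim array do not broadcast against each other.

```python
import jax.numpy as jnp

# Toy 3x3 mask: True = masked out, False = unmasked (4 unmasked pixels).
mask = jnp.array(
    [
        [True, False, True],
        [False, False, True],
        [True, True, False],
    ]
)

# 1D "slim" values for the unmasked pixels, in row-major order.
slim = jnp.array([1.0, 2.0, 3.0, 4.0])

# size= makes the output shape static, so this also works under jit.
slim_to_native = jnp.nonzero(jnp.logical_not(mask), size=slim.shape[0])

# Scatter the slim values onto the 2D grid; masked pixels stay zero.
native = jnp.zeros(mask.shape).at[slim_to_native].set(slim)
print(native)
# [[0. 1. 0.]
#  [2. 3. 0.]
#  [0. 0. 4.]]

# jnp.where(mask, 0.0, slim) would fail here: shapes (3, 3) and (4,)
# don't broadcast, which is the problem described above.
```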
```python
expanded_array_native = jnp.zeros(mask.shape)

expanded_array_native = expanded_array_native.at[slim_to_native].set(image)
```
As above, this can be done with the following (again, this should be checked to make sure it is doing the same thing):

```python
expanded_array_native = jnp.where(mask.array.T, 0, image)
```
The `Convolver` object mapped out the 2D convolution calculation for a 2D mask and 2D kernel, using the fact that for model-fitting both quantities are fixed and therefore the exact sequence of calculations required could be precomputed in memory. This object relies heavily on in-place memory manipulation and is therefore not suitable for JAX, so I have removed it.

All 2D convolutions are now performed using standard `jax.scipy` convolution methods on 2D arrays.

The `Convolver` object worked on the contents of masked arrays mapped to a 1D representation, whereas the current JAX implementation of convolutions requires mapping back and forth between 1D and 2D. This likely leads to a loss of performance and we may need to optimize these functions in the future.
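The new pipeline can be sketched end to end (a toy example under my own names, not the PR's actual API): expand the 1D slim image to its 2D native grid, convolve with `jax.scipy.signal.convolve2d`, then gather back to slim.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Toy setup: a 4x4 grid with 6 unmasked pixels (True = masked out).
mask = jnp.array(
    [
        [True, True, True, True],
        [True, False, False, True],
        [True, False, False, True],
        [True, False, False, True],
    ]
)
slim_image = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Identity PSF (odd 3x3 delta kernel) so the expected output is obvious.
kernel = jnp.zeros((3, 3)).at[1, 1].set(1.0)

@jax.jit
def convolve_slim(slim_image, mask, kernel):
    # 1D slim -> 2D native (static size= keeps this jit-compatible).
    idx = jnp.nonzero(jnp.logical_not(mask), size=slim_image.shape[0])
    native = jnp.zeros(mask.shape).at[idx].set(slim_image)
    # Standard 2D convolution, replacing the old Convolver.
    blurred = convolve2d(native, kernel, mode="same")
    # 2D native -> 1D slim.
    return blurred[idx]

out = convolve_slim(slim_image, mask, kernel)
print(out)  # identity kernel: slim_image comes back unchanged
```

The real function also scatters a separate blurring region before convolving, which this sketch omits; the 1D-2D-1D round trip is the part that the old `Convolver` avoided.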