Skip to content

Feature/jax and numba#165

Merged
Jammy2211 merged 34 commits intofeature/jax_wrapperfrom
feature/jax_and_numba
Apr 2, 2025
Merged

Feature/jax and numba#165
Jammy2211 merged 34 commits intofeature/jax_wrapperfrom
feature/jax_and_numba

Conversation

@Jammy2211
Copy link
Owner

In this PR, I have begun to adapt the source code under the assumption that both jax and numba are installed and working.

I have then been making whatever changes are necessary to make unit tests pass, with the majority of updates making calculations (which I believe are not used in a likelihood function) use normal ndarrays and often numba.

I have been doing this alongside the jax.grad of an autogalaxy likelihood function, and along the way caught a few issues where I need to be careful which parts of the code assume only normal numpy and which parts require careful treatment of whether an array is an np.ndarray or jnp.ndarray.

This has led to a few if isinstance(array, jnp.ndarray) statements, especially in parts of the code where both an np.ndarray or a jnp.ndarray could go through.

Overall, it seems to be going smoothly and makes me optimistic we can reach a point where all unit tests pass and normal science runs are working without too much more time and effort, albeit a lot of work will then need to go towards cleaning up the code down the line :).

@Jammy2211 Jammy2211 requested review from CKrawczyk and rhayes777 April 1, 2025 20:01
if psf is not None and use_normalized_psf:
psf = Kernel2D.no_mask(
values=psf.native, pixel_scales=psf.pixel_scales, normalize=True
values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have a public array property for accessing the private array attribute?

I think np.array(psf.native) should work but I guess maybe that fails because of a JAX conflict?

-------
The central pixel coordinates of the data structure.
"""
return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need brackets here

Comment on lines -754 to +755
if use_jax:
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1.0, 1.0])
grid_pixels_2d = (
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
).astype(int)
else:
grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2))

for y in range(grid_scaled_2d.shape[0]):
for x in range(grid_scaled_2d.shape[1]):
grid_pixels_2d[y, x, 0] = int(
(-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + centres_scaled[0] + 0.5
)
grid_pixels_2d[y, x, 1] = int(
(grid_scaled_2d[y, x, 1] / pixel_scales[1]) + centres_scaled[1] + 0.5
)

return grid_pixels_2d
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1.0, 1.0])
return (
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
).astype(int)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice so much easier to look at without the dual functionality

mask=np.array(mask),
mask_index_array=self.mask_index_array,
kernel_2d=np.array(self.kernel.native[:, :]),
kernel_2d=self.kernel.native,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha that's suprising?

Comment on lines +365 to +368
# if use_jax:
# # while this makes it run, it is very, very slow
# sub_size = sub_size.at[i].set(sub_size_list[j])
# else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented code?

# # while this makes it run, it is very, very slow
# grid_slim = grid_slim.at[sub_index, 0].set(
# -(
# y_scaled
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. I guess you're removing the disjuncts but haven't got jax working in these bits quite yet?

def test__is_all_false():
mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0)

assert mask.is_all_false is True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is is correct

@Jammy2211 Jammy2211 merged commit d0c324b into feature/jax_wrapper Apr 2, 2025
0 of 8 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax_and_numba branch June 24, 2025 13:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants