Merged
Conversation
rhayes777
approved these changes
Apr 2, 2025
| if psf is not None and use_normalized_psf: | ||
| psf = Kernel2D.no_mask( | ||
| values=psf.native, pixel_scales=psf.pixel_scales, normalize=True | ||
| values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True |
Collaborator
There was a problem hiding this comment.
Maybe we should have a public array property for accessing the private array attribute?
I think np.array(psf.native) should work but I guess maybe that fails because of a JAX conflict?
| ------- | ||
| The central pixel coordinates of the data structure. | ||
| """ | ||
| return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) |
Collaborator
There was a problem hiding this comment.
You don't need brackets here
Comment on lines
-754
to
+755
| if use_jax: | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1.0, 1.0]) | ||
| grid_pixels_2d = ( | ||
| (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 | ||
| ).astype(int) | ||
| else: | ||
| grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) | ||
|
|
||
| for y in range(grid_scaled_2d.shape[0]): | ||
| for x in range(grid_scaled_2d.shape[1]): | ||
| grid_pixels_2d[y, x, 0] = int( | ||
| (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) + centres_scaled[0] + 0.5 | ||
| ) | ||
| grid_pixels_2d[y, x, 1] = int( | ||
| (grid_scaled_2d[y, x, 1] / pixel_scales[1]) + centres_scaled[1] + 0.5 | ||
| ) | ||
|
|
||
| return grid_pixels_2d | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1.0, 1.0]) | ||
| return ( | ||
| (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 | ||
| ).astype(int) |
Collaborator
There was a problem hiding this comment.
Nice so much easier to look at without the dual functionality
| mask=np.array(mask), | ||
| mask_index_array=self.mask_index_array, | ||
| kernel_2d=np.array(self.kernel.native[:, :]), | ||
| kernel_2d=self.kernel.native, |
Comment on lines
+365
to
+368
| # if use_jax: | ||
| # # while this makes it run, it is very, very slow | ||
| # sub_size = sub_size.at[i].set(sub_size_list[j]) | ||
| # else: |
| # # while this makes it run, it is very, very slow | ||
| # grid_slim = grid_slim.at[sub_index, 0].set( | ||
| # -( | ||
| # y_scaled |
Collaborator
There was a problem hiding this comment.
Same here. I guess you're removing the disjuncts but haven't got jax working in these bits quite yet?
| def test__is_all_false(): | ||
| mask = aa.Mask1D(mask=[False, False, False, False], pixel_scales=1.0) | ||
|
|
||
| assert mask.is_all_false is True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
In this PR, I have begun to adapt the source code under the assumption that both
jaxandnumbaare installed and working.I have then been making whatever changes are necessary to make unit tests pass, with the majority of updates making calculations (which I believe are not used in a likelihood function) use normal ndarrays and often
numba.I have been doing this alongside the
jax.gradof an autogalaxy likelihood function, and along the way caught a few issues where I need to be careful which parts of the code assume only normalnumpyand which parts require careful treatment of whether an array is annp.ndarrayorjnp.ndarray.This has led to a few
if isinstance(array, jnp.ndarray)statements, especially in parts of the code where both annp.ndarrayor ajnp.ndarraycould go through.Overall, it seems to be going smoothly and makes me optimistic we can reach a point where all unit tests pass and normal science runs are working without too much more time and effort, albeit a lot of work will then need to go towards cleaning up the code down the line :).