-
Notifications
You must be signed in to change notification settings - Fork 7
Feature/jax and numba #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
b317010
fix mask tests by using ==
Jammy2211 e1899b4
fix more tests using ==
Jammy2211 cb5f13b
fixes to get basic func_grad to work
Jammy2211 15bf0db
progress stopped at convolver
Jammy2211 d3649ff
updated grid_2d_slim_via_mask_from to be JAX implementation
Jammy2211 adf5ead
remove numba from grid_2d_centre_from
Jammy2211 31cdd33
remove numba from pixel_coordinates_2d_from -> fixes is circular
Jammy2211 ff1e811
fixing grid_2d_slim_over_sampled_via_mask_from to use numba
Jammy2211 b322a3f
removed use of use_jax in one function
Jammy2211 9e3c76c
grid_pixels_2d_slim_from now uses native numpy, could support JAX
Jammy2211 ead617e
grid_pixel_centres_2d_slim_from, could support JAX
Jammy2211 2769aaf
grid_pixel_indexes_2d_slim_from, could support JAX
Jammy2211 b2ba6bd
grid_scaled_2d_slim_from, could support JAX
Jammy2211 0532104
grid_pixel_centres_2d_from, could support JAX
Jammy2211 d90ff2e
explciit separate imports
Jammy2211 59b21e9
fix unit test in test__transform_2d_grid_from_reference_frame
Jammy2211 c453a3c
use absolute tolerance to fix geomtry util unit tests
Jammy2211 0c4bb30
fix test__pixel_coordinates_2d_from
Jammy2211 d891947
cleaned up jax imports of array_2d_util to make more tests pass
Jammy2211 ea7aa9d
cleanup imports of grid_2d_util
Jammy2211 4014d03
convert methods in grid_2d_util assume ndarray
Jammy2211 075654f
more simlpifying of convert functions
Jammy2211 17817b8
mask derive fixed
Jammy2211 b76cc9a
another way to make hecks only use ndarray
Jammy2211 c9e275d
fixes which ensure grad works on real LH function
Jammy2211 70c0212
fix all uniform_2d unit tests
Jammy2211 c417511
fix all of kernel 2d
Jammy2211 db9cfb7
fix repr
Jammy2211 3cb3f76
remove relocate_to_radial_minimum test as all functionality is to be …
Jammy2211 467d1ea
fix Grid2D test_unifrom
Jammy2211 7751080
fix grid test_uniform_1d
Jammy2211 f4c3269
hammer hammer hammer
Jammy2211 8d2b338
fix over sampler test
Jammy2211 70843c0
mrge succcess
Jammy2211 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from typing import Tuple, Union | ||
| from autoarray.numpy_wrapper import np, use_jax | ||
|
|
||
|
|
||
| from autoarray import numba_util | ||
| from autoarray import type as ty | ||
|
|
@@ -179,8 +181,69 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float] | |
|
|
||
| return pixel_scales | ||
|
|
||
| @numba_util.jit() | ||
| def central_pixel_coordinates_2d_numba_from( | ||
| shape_native: Tuple[int, int], | ||
| ) -> Tuple[float, float]: | ||
| """ | ||
| Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) | ||
| from the shape of that data structure. | ||
|
|
||
| Examples of the central pixels are as follows: | ||
|
|
||
| - For a 3x3 image, the central pixel is pixel [1, 1]. | ||
| - For a 4x4 image, the central pixel is [1.5, 1.5]. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| shape_native | ||
| The dimensions of the data structure, which can be in 1D, 2D or higher dimensions. | ||
|
|
||
| Returns | ||
| ------- | ||
| The central pixel coordinates of the data structure. | ||
| """ | ||
| return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't need brackets here |
||
|
|
||
| @numba_util.jit() | ||
| def central_scaled_coordinate_2d_numba_from( | ||
| shape_native: Tuple[int, int], | ||
| pixel_scales: ty.PixelScales, | ||
| origin: Tuple[float, float] = (0.0, 0.0), | ||
| ) -> Tuple[float, float]: | ||
| """ | ||
| Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``) | ||
| from the shape of that data structure. | ||
|
|
||
| This is computed by using the data structure's shape and converting it to scaled units using an input | ||
| pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`. | ||
|
|
||
| The origin of the scaled grid can also be input and moved from (0.0, 0.0). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| shape_native | ||
| The 2D shape of the data structure whose central scaled coordinates are computed. | ||
| pixel_scales | ||
| The (y,x) scaled units to pixel units conversion factor of the 2D data structure. | ||
| origin | ||
| The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on. | ||
|
|
||
| Returns | ||
| ------- | ||
| The central coordinates of the 2D data structure in scaled units. | ||
| """ | ||
|
|
||
| central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( | ||
| shape_native=shape_native | ||
| ) | ||
|
|
||
| y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0]) | ||
| x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1]) | ||
|
|
||
| return (y_pixel, x_pixel) | ||
|
|
||
|
|
||
| def central_pixel_coordinates_2d_from( | ||
| shape_native: Tuple[int, int], | ||
| ) -> Tuple[float, float]: | ||
|
|
@@ -205,7 +268,6 @@ def central_pixel_coordinates_2d_from( | |
| return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2) | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def central_scaled_coordinate_2d_from( | ||
| shape_native: Tuple[int, int], | ||
| pixel_scales: ty.PixelScales, | ||
|
|
@@ -234,7 +296,7 @@ def central_scaled_coordinate_2d_from( | |
| The central coordinates of the 2D data structure in scaled units. | ||
| """ | ||
|
|
||
| central_pixel_coordinates = central_pixel_coordinates_2d_from( | ||
| central_pixel_coordinates = central_pixel_coordinates_2d_numba_from( | ||
| shape_native=shape_native | ||
| ) | ||
|
|
||
|
|
@@ -243,8 +305,6 @@ def central_scaled_coordinate_2d_from( | |
|
|
||
| return (y_pixel, x_pixel) | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def pixel_coordinates_2d_from( | ||
| scaled_coordinates_2d: Tuple[float, float], | ||
| shape_native: Tuple[int, int], | ||
|
|
@@ -352,7 +412,7 @@ def scaled_coordinates_2d_from( | |
| origin=(0.0, 0.0) | ||
| ) | ||
| """ | ||
| central_scaled_coordinates = central_scaled_coordinate_2d_from( | ||
| central_scaled_coordinates = central_scaled_coordinate_2d_numba_from( | ||
| shape_native=shape_native, pixel_scales=pixel_scales, origin=origins | ||
| ) | ||
|
|
||
|
|
@@ -382,18 +442,16 @@ def transform_grid_2d_to_reference_frame( | |
| grid | ||
| The 2d grid of (y, x) coordinates which are transformed to a new reference frame. | ||
| """ | ||
| if use_jax: | ||
| shifted_grid_2d = grid_2d.array - np.array(centre) | ||
| else: | ||
| shifted_grid_2d = grid_2d - np.array(centre) | ||
| radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1)) | ||
| theta_coordinate_to_profile = np.arctan2( | ||
| shifted_grid_2d = np.array(grid_2d) - jnp.array(centre) | ||
|
|
||
| radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1)) | ||
| theta_coordinate_to_profile = jnp.arctan2( | ||
| shifted_grid_2d[:, 0], shifted_grid_2d[:, 1] | ||
| ) - np.radians(angle) | ||
| return np.vstack( | ||
| ) - jnp.radians(angle) | ||
| return jnp.vstack( | ||
| [ | ||
| radius * np.sin(theta_coordinate_to_profile), | ||
| radius * np.cos(theta_coordinate_to_profile), | ||
| radius * jnp.sin(theta_coordinate_to_profile), | ||
| radius * jnp.cos(theta_coordinate_to_profile), | ||
| ] | ||
| ).T | ||
|
|
||
|
|
@@ -435,7 +493,6 @@ def transform_grid_2d_from_reference_frame( | |
| return np.vstack((y, x)).T | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def grid_pixels_2d_slim_from( | ||
| grid_scaled_2d_slim: np.ndarray, | ||
| shape_native: Tuple[int, int], | ||
|
|
@@ -476,33 +533,15 @@ def grid_pixels_2d_slim_from( | |
| grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2), | ||
| pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) | ||
| """ | ||
|
|
||
| centres_scaled = central_scaled_coordinate_2d_from( | ||
| shape_native=shape_native, pixel_scales=pixel_scales, origin=origin | ||
| ) | ||
| if use_jax: | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1, 1]) | ||
| return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 | ||
| else: | ||
| grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) | ||
| for slim_index in range(grid_scaled_2d_slim.shape[0]): | ||
| grid_pixels_2d_slim[slim_index, 0] = ( | ||
| (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) | ||
| + centres_scaled[0] | ||
| + 0.5 | ||
| ) | ||
| grid_pixels_2d_slim[slim_index, 1] = ( | ||
| (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) | ||
| + centres_scaled[1] | ||
| + 0.5 | ||
| ) | ||
|
|
||
| return grid_pixels_2d_slim | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1, 1]) | ||
| return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def grid_pixel_centres_2d_slim_from( | ||
| grid_scaled_2d_slim: np.ndarray, | ||
| shape_native: Tuple[int, int], | ||
|
|
@@ -547,32 +586,14 @@ def grid_pixel_centres_2d_slim_from( | |
| shape_native=shape_native, pixel_scales=pixel_scales, origin=origin | ||
| ) | ||
|
|
||
| if use_jax: | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1.0, 1.0]) | ||
| grid_pixels_2d_slim = ( | ||
| (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 | ||
| ).astype(int) | ||
| else: | ||
| grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2)) | ||
|
|
||
| for slim_index in range(grid_scaled_2d_slim.shape[0]): | ||
| grid_pixels_2d_slim[slim_index, 0] = int( | ||
| (-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0]) | ||
| + centres_scaled[0] | ||
| + 0.5 | ||
| ) | ||
| grid_pixels_2d_slim[slim_index, 1] = int( | ||
| (grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1]) | ||
| + centres_scaled[1] | ||
| + 0.5 | ||
| ) | ||
|
|
||
| return grid_pixels_2d_slim | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1.0, 1.0]) | ||
| return ( | ||
| (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5 | ||
| ).astype(int) | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def grid_pixel_indexes_2d_slim_from( | ||
| grid_scaled_2d_slim: np.ndarray, | ||
| shape_native: Tuple[int, int], | ||
|
|
@@ -625,25 +646,13 @@ def grid_pixel_indexes_2d_slim_from( | |
| origin=origin, | ||
| ) | ||
|
|
||
| if use_jax: | ||
| grid_pixel_indexes_2d_slim = ( | ||
| (grid_pixels_2d_slim * np.array([shape_native[1], 1])) | ||
| .sum(axis=1) | ||
| .astype(int) | ||
| ) | ||
| else: | ||
| grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0]) | ||
|
|
||
| for slim_index in range(grid_pixels_2d_slim.shape[0]): | ||
| grid_pixel_indexes_2d_slim[slim_index] = int( | ||
| grid_pixels_2d_slim[slim_index, 0] * shape_native[1] | ||
| + grid_pixels_2d_slim[slim_index, 1] | ||
| ) | ||
|
|
||
| return grid_pixel_indexes_2d_slim | ||
| return ( | ||
| (grid_pixels_2d_slim * np.array([shape_native[1], 1])) | ||
| .sum(axis=1) | ||
| .astype(int) | ||
| ) | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def grid_scaled_2d_slim_from( | ||
| grid_pixels_2d_slim: np.ndarray, | ||
| shape_native: Tuple[int, int], | ||
|
|
@@ -682,33 +691,18 @@ def grid_scaled_2d_slim_from( | |
| grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2), | ||
| pixel_scales=(0.5, 0.5), origin=(0.0, 0.0)) | ||
| """ | ||
|
|
||
| centres_scaled = central_scaled_coordinate_2d_from( | ||
| shape_native=shape_native, pixel_scales=pixel_scales, origin=origin | ||
| ) | ||
| if use_jax: | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1, 1]) | ||
| grid_scaled_2d_slim = ( | ||
| (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign | ||
| ) | ||
| else: | ||
| grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2)) | ||
|
|
||
| for slim_index in range(grid_scaled_2d_slim.shape[0]): | ||
| grid_scaled_2d_slim[slim_index, 0] = ( | ||
| -(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5) | ||
| * pixel_scales[0] | ||
| ) | ||
| grid_scaled_2d_slim[slim_index, 1] = ( | ||
| grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5 | ||
| ) * pixel_scales[1] | ||
|
|
||
| return grid_scaled_2d_slim | ||
|
|
||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1, 1]) | ||
| return ( | ||
| (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign | ||
| ) | ||
|
|
||
|
|
||
| @numba_util.jit() | ||
| def grid_pixel_centres_2d_from( | ||
| grid_scaled_2d: np.ndarray, | ||
| shape_native: Tuple[int, int], | ||
|
|
@@ -753,30 +747,12 @@ def grid_pixel_centres_2d_from( | |
| shape_native=shape_native, pixel_scales=pixel_scales, origin=origin | ||
| ) | ||
|
|
||
| if use_jax: | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1.0, 1.0]) | ||
| grid_pixels_2d = ( | ||
| (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 | ||
| ).astype(int) | ||
| else: | ||
| grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2)) | ||
|
|
||
| for y in range(grid_scaled_2d.shape[0]): | ||
| for x in range(grid_scaled_2d.shape[1]): | ||
| grid_pixels_2d[y, x, 0] = int( | ||
| (-grid_scaled_2d[y, x, 0] / pixel_scales[0]) | ||
| + centres_scaled[0] | ||
| + 0.5 | ||
| ) | ||
| grid_pixels_2d[y, x, 1] = int( | ||
| (grid_scaled_2d[y, x, 1] / pixel_scales[1]) | ||
| + centres_scaled[1] | ||
| + 0.5 | ||
| ) | ||
|
|
||
| return grid_pixels_2d | ||
| centres_scaled = np.array(centres_scaled) | ||
| pixel_scales = np.array(pixel_scales) | ||
| sign = np.array([-1.0, 1.0]) | ||
| return ( | ||
| (sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5 | ||
| ).astype(int) | ||
|
|
||
|
|
||
| def extent_symmetric_from( | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should have a public array property for accessing the private array attribute?
I think np.array(psf.native) should work but I guess maybe that fails because of a JAX conflict?