Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b317010
fix mask tests by using ==
Jammy2211 Mar 31, 2025
e1899b4
fix more tests using ==
Jammy2211 Mar 31, 2025
cb5f13b
fixes to get basic func_grad to work
Jammy2211 Mar 31, 2025
15bf0db
progress stopped at convolver
Jammy2211 Apr 1, 2025
d3649ff
updated grid_2d_slim_via_mask_from to be JAX implementation
Jammy2211 Apr 1, 2025
adf5ead
remove numba from grid_2d_centre_from
Jammy2211 Apr 1, 2025
31cdd33
remove numba from pixel_coordinates_2d_from -> fixes is circular
Jammy2211 Apr 1, 2025
ff1e811
fixing grid_2d_slim_over_sampled_via_mask_from to use numba
Jammy2211 Apr 1, 2025
b322a3f
removed use of use_jax in one function
Jammy2211 Apr 1, 2025
9e3c76c
grid_pixels_2d_slim_from now uses native numpy, could support JAX
Jammy2211 Apr 1, 2025
ead617e
grid_pixel_centres_2d_slim_from, could support JAX
Jammy2211 Apr 1, 2025
2769aaf
grid_pixel_indexes_2d_slim_from, could support JAX
Jammy2211 Apr 1, 2025
b2ba6bd
grid_scaled_2d_slim_from, could support JAX
Jammy2211 Apr 1, 2025
0532104
grid_pixel_centres_2d_from, could support JAX
Jammy2211 Apr 1, 2025
d90ff2e
explciit separate imports
Jammy2211 Apr 1, 2025
59b21e9
fix unit test in test__transform_2d_grid_from_reference_frame
Jammy2211 Apr 1, 2025
c453a3c
use absolute tolerance to fix geomtry util unit tests
Jammy2211 Apr 1, 2025
0c4bb30
fix test__pixel_coordinates_2d_from
Jammy2211 Apr 1, 2025
d891947
cleaned up jax imports of array_2d_util to make more tests pass
Jammy2211 Apr 1, 2025
ea7aa9d
cleanup imports of grid_2d_util
Jammy2211 Apr 1, 2025
4014d03
convert methods in grid_2d_util assume ndarray
Jammy2211 Apr 1, 2025
075654f
more simlpifying of convert functions
Jammy2211 Apr 1, 2025
17817b8
mask derive fixed
Jammy2211 Apr 1, 2025
b76cc9a
another way to make hecks only use ndarray
Jammy2211 Apr 1, 2025
c9e275d
fixes which ensure grad works on real LH function
Jammy2211 Apr 1, 2025
70c0212
fix all uniform_2d unit tests
Jammy2211 Apr 1, 2025
c417511
fix all of kernel 2d
Jammy2211 Apr 1, 2025
db9cfb7
fix repr
Jammy2211 Apr 1, 2025
3cb3f76
remove relocate_to_radial_minimum test as all functionality is to be …
Jammy2211 Apr 1, 2025
467d1ea
fix Grid2D test_unifrom
Jammy2211 Apr 1, 2025
7751080
fix grid test_uniform_1d
Jammy2211 Apr 1, 2025
f4c3269
hammer hammer hammer
Jammy2211 Apr 1, 2025
8d2b338
fix over sampler test
Jammy2211 Apr 2, 2025
70843c0
mrge succcess
Jammy2211 Apr 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(

if psf is not None and use_normalized_psf:
psf = Kernel2D.no_mask(
values=psf.native, pixel_scales=psf.pixel_scales, normalize=True
values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have a public array property for accessing the private array attribute?

I think np.array(psf.native) should work but I guess maybe that fails because of a JAX conflict?

)

self.psf = psf
Expand Down Expand Up @@ -193,7 +193,7 @@ def convolver(self):
The convolver given the masked imaging data's mask and PSF.
"""

return Convolver(mask=self.mask, kernel=self.psf)
return Convolver(mask=self.mask, kernel=Kernel2D(values=self.psf._array, mask=self.psf.mask, header=self.psf.header))

@cached_property
def w_tilde(self):
Expand Down
6 changes: 3 additions & 3 deletions autoarray/fit/fit_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import wraps
import jax.numpy as np

from autoarray.numpy_wrapper import np
from autoarray.mask.abstract_mask import Mask

from autoarray import type as ty
Expand Down Expand Up @@ -83,7 +83,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float:
chi_squared_map
The chi-squared-map of values of the model-data fit to the dataset.
"""
return np.sum(chi_squared_map)
return np.sum(chi_squared_map._array)


def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
Expand All @@ -97,7 +97,7 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
noise_map
The masked noise-map of the dataset.
"""
return np.sum(np.log(2 * np.pi * noise_map**2.0))
return np.sum(np.log(2 * np.pi * noise_map._array**2.0))


def normalized_residual_map_complex_from(
Expand Down
3 changes: 2 additions & 1 deletion autoarray/geometry/geometry_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ def scaled_coordinates_2d_from(
-------
A 2D (y,x) pixel-value coordinate.
"""

return geometry_util.scaled_coordinates_2d_from(
pixel_coordinates_2d=pixel_coordinates_2d,
pixel_coordinates_2d=np.array(pixel_coordinates_2d),
shape_native=self.shape_native,
pixel_scales=self.pixel_scales,
origins=self.origin,
Expand Down
228 changes: 102 additions & 126 deletions autoarray/geometry/geometry_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax.numpy as jnp
import numpy as np
from typing import Tuple, Union
from autoarray.numpy_wrapper import np, use_jax


from autoarray import numba_util
from autoarray import type as ty
Expand Down Expand Up @@ -179,8 +181,69 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float]

return pixel_scales

@numba_util.jit()
def central_pixel_coordinates_2d_numba_from(
shape_native: Tuple[int, int],
) -> Tuple[float, float]:
"""
Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
from the shape of that data structure.

Examples of the central pixels are as follows:

- For a 3x3 image, the central pixel is pixel [1, 1].
- For a 4x4 image, the central pixel is [1.5, 1.5].

Parameters
----------
shape_native
The dimensions of the data structure, which can be in 1D, 2D or higher dimensions.

Returns
-------
The central pixel coordinates of the data structure.
"""
return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need brackets here


@numba_util.jit()
def central_scaled_coordinate_2d_numba_from(
shape_native: Tuple[int, int],
pixel_scales: ty.PixelScales,
origin: Tuple[float, float] = (0.0, 0.0),
) -> Tuple[float, float]:
"""
Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
from the shape of that data structure.

This is computed by using the data structure's shape and converting it to scaled units using an input
pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`.

The origin of the scaled grid can also be input and moved from (0.0, 0.0).

Parameters
----------
shape_native
The 2D shape of the data structure whose central scaled coordinates are computed.
pixel_scales
The (y,x) scaled units to pixel units conversion factor of the 2D data structure.
origin
The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on.

Returns
-------
The central coordinates of the 2D data structure in scaled units.
"""

central_pixel_coordinates = central_pixel_coordinates_2d_numba_from(
shape_native=shape_native
)

y_pixel = central_pixel_coordinates[0] + (origin[0] / pixel_scales[0])
x_pixel = central_pixel_coordinates[1] - (origin[1] / pixel_scales[1])

return (y_pixel, x_pixel)


def central_pixel_coordinates_2d_from(
shape_native: Tuple[int, int],
) -> Tuple[float, float]:
Expand All @@ -205,7 +268,6 @@ def central_pixel_coordinates_2d_from(
return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2)


@numba_util.jit()
def central_scaled_coordinate_2d_from(
shape_native: Tuple[int, int],
pixel_scales: ty.PixelScales,
Expand Down Expand Up @@ -234,7 +296,7 @@ def central_scaled_coordinate_2d_from(
The central coordinates of the 2D data structure in scaled units.
"""

central_pixel_coordinates = central_pixel_coordinates_2d_from(
central_pixel_coordinates = central_pixel_coordinates_2d_numba_from(
shape_native=shape_native
)

Expand All @@ -243,8 +305,6 @@ def central_scaled_coordinate_2d_from(

return (y_pixel, x_pixel)


@numba_util.jit()
def pixel_coordinates_2d_from(
scaled_coordinates_2d: Tuple[float, float],
shape_native: Tuple[int, int],
Expand Down Expand Up @@ -352,7 +412,7 @@ def scaled_coordinates_2d_from(
origin=(0.0, 0.0)
)
"""
central_scaled_coordinates = central_scaled_coordinate_2d_from(
central_scaled_coordinates = central_scaled_coordinate_2d_numba_from(
shape_native=shape_native, pixel_scales=pixel_scales, origin=origins
)

Expand Down Expand Up @@ -382,18 +442,16 @@ def transform_grid_2d_to_reference_frame(
grid
The 2d grid of (y, x) coordinates which are transformed to a new reference frame.
"""
if use_jax:
shifted_grid_2d = grid_2d.array - np.array(centre)
else:
shifted_grid_2d = grid_2d - np.array(centre)
radius = np.sqrt(np.sum(shifted_grid_2d**2.0, axis=1))
theta_coordinate_to_profile = np.arctan2(
shifted_grid_2d = np.array(grid_2d) - jnp.array(centre)

radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1))
theta_coordinate_to_profile = jnp.arctan2(
shifted_grid_2d[:, 0], shifted_grid_2d[:, 1]
) - np.radians(angle)
return np.vstack(
) - jnp.radians(angle)
return jnp.vstack(
[
radius * np.sin(theta_coordinate_to_profile),
radius * np.cos(theta_coordinate_to_profile),
radius * jnp.sin(theta_coordinate_to_profile),
radius * jnp.cos(theta_coordinate_to_profile),
]
).T

Expand Down Expand Up @@ -435,7 +493,6 @@ def transform_grid_2d_from_reference_frame(
return np.vstack((y, x)).T


@numba_util.jit()
def grid_pixels_2d_slim_from(
grid_scaled_2d_slim: np.ndarray,
shape_native: Tuple[int, int],
Expand Down Expand Up @@ -476,33 +533,15 @@ def grid_pixels_2d_slim_from(
grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2),
pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
"""

centres_scaled = central_scaled_coordinate_2d_from(
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
)
if use_jax:
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1, 1])
return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
else:
grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2))
for slim_index in range(grid_scaled_2d_slim.shape[0]):
grid_pixels_2d_slim[slim_index, 0] = (
(-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0])
+ centres_scaled[0]
+ 0.5
)
grid_pixels_2d_slim[slim_index, 1] = (
(grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1])
+ centres_scaled[1]
+ 0.5
)

return grid_pixels_2d_slim
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1, 1])
return (sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5


@numba_util.jit()
def grid_pixel_centres_2d_slim_from(
grid_scaled_2d_slim: np.ndarray,
shape_native: Tuple[int, int],
Expand Down Expand Up @@ -547,32 +586,14 @@ def grid_pixel_centres_2d_slim_from(
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
)

if use_jax:
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1.0, 1.0])
grid_pixels_2d_slim = (
(sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
).astype(int)
else:
grid_pixels_2d_slim = np.zeros((grid_scaled_2d_slim.shape[0], 2))

for slim_index in range(grid_scaled_2d_slim.shape[0]):
grid_pixels_2d_slim[slim_index, 0] = int(
(-grid_scaled_2d_slim[slim_index, 0] / pixel_scales[0])
+ centres_scaled[0]
+ 0.5
)
grid_pixels_2d_slim[slim_index, 1] = int(
(grid_scaled_2d_slim[slim_index, 1] / pixel_scales[1])
+ centres_scaled[1]
+ 0.5
)

return grid_pixels_2d_slim
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1.0, 1.0])
return (
(sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
).astype(int)


@numba_util.jit()
def grid_pixel_indexes_2d_slim_from(
grid_scaled_2d_slim: np.ndarray,
shape_native: Tuple[int, int],
Expand Down Expand Up @@ -625,25 +646,13 @@ def grid_pixel_indexes_2d_slim_from(
origin=origin,
)

if use_jax:
grid_pixel_indexes_2d_slim = (
(grid_pixels_2d_slim * np.array([shape_native[1], 1]))
.sum(axis=1)
.astype(int)
)
else:
grid_pixel_indexes_2d_slim = np.zeros(grid_pixels_2d_slim.shape[0])

for slim_index in range(grid_pixels_2d_slim.shape[0]):
grid_pixel_indexes_2d_slim[slim_index] = int(
grid_pixels_2d_slim[slim_index, 0] * shape_native[1]
+ grid_pixels_2d_slim[slim_index, 1]
)

return grid_pixel_indexes_2d_slim
return (
(grid_pixels_2d_slim * np.array([shape_native[1], 1]))
.sum(axis=1)
.astype(int)
)


@numba_util.jit()
def grid_scaled_2d_slim_from(
grid_pixels_2d_slim: np.ndarray,
shape_native: Tuple[int, int],
Expand Down Expand Up @@ -682,33 +691,18 @@ def grid_scaled_2d_slim_from(
grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2),
pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
"""

centres_scaled = central_scaled_coordinate_2d_from(
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
)
if use_jax:
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1, 1])
grid_scaled_2d_slim = (
(grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
)
else:
grid_scaled_2d_slim = np.zeros((grid_pixels_2d_slim.shape[0], 2))

for slim_index in range(grid_scaled_2d_slim.shape[0]):
grid_scaled_2d_slim[slim_index, 0] = (
-(grid_pixels_2d_slim[slim_index, 0] - centres_scaled[0] - 0.5)
* pixel_scales[0]
)
grid_scaled_2d_slim[slim_index, 1] = (
grid_pixels_2d_slim[slim_index, 1] - centres_scaled[1] - 0.5
) * pixel_scales[1]

return grid_scaled_2d_slim

centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1, 1])
return (
(grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
)


@numba_util.jit()
def grid_pixel_centres_2d_from(
grid_scaled_2d: np.ndarray,
shape_native: Tuple[int, int],
Expand Down Expand Up @@ -753,30 +747,12 @@ def grid_pixel_centres_2d_from(
shape_native=shape_native, pixel_scales=pixel_scales, origin=origin
)

if use_jax:
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1.0, 1.0])
grid_pixels_2d = (
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
).astype(int)
else:
grid_pixels_2d = np.zeros((grid_scaled_2d.shape[0], grid_scaled_2d.shape[1], 2))

for y in range(grid_scaled_2d.shape[0]):
for x in range(grid_scaled_2d.shape[1]):
grid_pixels_2d[y, x, 0] = int(
(-grid_scaled_2d[y, x, 0] / pixel_scales[0])
+ centres_scaled[0]
+ 0.5
)
grid_pixels_2d[y, x, 1] = int(
(grid_scaled_2d[y, x, 1] / pixel_scales[1])
+ centres_scaled[1]
+ 0.5
)

return grid_pixels_2d
centres_scaled = np.array(centres_scaled)
pixel_scales = np.array(pixel_scales)
sign = np.array([-1.0, 1.0])
return (
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
).astype(int)


def extent_symmetric_from(
Expand Down
Loading
Loading