Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e819fa1
remove cached property
Nov 5, 2025
997be69
use .xp decorator on one inversion test, which still passes
Nov 5, 2025
8c5cf8c
all of inversion module passes with xp
Nov 5, 2025
4aae8f4
fix dataset tests
Nov 5, 2025
f8f005a
cnonditional on set
Nov 5, 2025
0d2ac11
xp on mask_2d_util
Nov 5, 2025
4f9e2c6
xp on mask_1d_util
Nov 5, 2025
6748208
first successful jit compile using .xp API
Nov 5, 2025
47d81d2
added xp name space to inversion_util
Nov 6, 2025
86608a7
xp used throughout Inversion classes
Nov 6, 2025
9d87b14
put xp throught some structure util, los of refactoring still to go
Nov 7, 2025
2d7b461
all array unitt ests pass
Nov 7, 2025
86c49e3
xp in grid
Nov 7, 2025
bc09850
clever removal of isnumpy
Nov 7, 2025
378cb2c
more xp unit test fixes
Nov 7, 2025
49a5e4a
xp put in border relcoator
Nov 7, 2025
c57a6af
xpo now through all of regularization and mapper
Nov 7, 2025
e3a428c
fix adaptive pixel signals
Nov 7, 2025
a947ee9
fix another regt est
Nov 7, 2025
e7efd73
most of mapper util uses xp
Nov 7, 2025
bb21afc
jax removed from mapper util
Nov 7, 2025
cf86118
mesh_util
Nov 7, 2025
00a5301
moved jnps remove
Nov 8, 2025
1def232
jnp to xp in Rectangfular mesh
Nov 8, 2025
ed3461c
jnp revmoed from vectors
Nov 8, 2025
7bd9e4d
lots of JAX stuff cleaned up
Nov 8, 2025
e99bfaa
fix numba import
Nov 8, 2025
08165cd
over sampler xp simplified
Nov 8, 2025
4ac872a
fix weeird unit ests
Nov 8, 2025
e69ff07
going to take detour removing grid project
Nov 8, 2025
d55304f
remove tests which use now unusued project code
Nov 8, 2025
f624774
fix numpy interp methods
Nov 9, 2025
0e0e429
xp refactor complete
Nov 9, 2025
c190e21
weight using adapt data added
Nov 9, 2025
afd0d41
split rectangular into RectangularMagnification and RectangularSource
Nov 10, 2025
5347dd5
fix unit tests
Nov 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from abc import abstractmethod
import jax.numpy as jnp
from jax._src.tree_util import register_pytree_node
from jax import Array

import numpy as np

from autoconf.fitsable import output_to_fits

Expand Down Expand Up @@ -64,7 +65,11 @@ def wrapper(self, other):


class AbstractNDArray(ABC):
def __init__(self, array):

__no_flatten__ = ()

def __init__(self, array, xp=np):

self._is_transformed = False

while isinstance(array, AbstractNDArray):
Expand All @@ -79,7 +84,7 @@ def __init__(self, array):
except ValueError:
pass

__no_flatten__ = ()
self._xp = xp

def invert(self):
new = self.copy()
Expand All @@ -102,12 +107,6 @@ def instance_flatten(cls, instance):
)
return values, keys

@staticmethod
def flip_hdu_for_ds9(values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return jnp.flipud(values)
return values

@classmethod
def instance_unflatten(cls, aux_data, children):
"""
Expand Down Expand Up @@ -138,6 +137,12 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
new_array._array = array
return new_array

@staticmethod
def flip_hdu_for_ds9(values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return jnp.flipud(values)
return values

def copy(self):
new = copy(self)
return new
Expand Down Expand Up @@ -336,6 +341,7 @@ def __getitem__(self, item):
return result

def __setitem__(self, key, value):
from jax import Array
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
self._array = jnp.where(key, value, self._array)
else:
Expand Down
2 changes: 1 addition & 1 deletion autoarray/config/general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ inversion:
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
numba:
use_numba: true
cache: false
cache: true
nopython: true
parallel: false
pixelization:
Expand Down
14 changes: 1 addition & 13 deletions autoarray/dataset/abstract/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
import warnings
from typing import Optional, Union

from autoconf import cached_property

from autoarray.dataset.grids import GridsDataset

from autoarray import exc
from autoarray.mask.mask_1d import Mask1D
from autoarray.mask.mask_2d import Mask2D
Expand Down Expand Up @@ -140,14 +136,6 @@ def __init__(
def grid(self):
return self.grids.lp

@cached_property
def grids(self):
return GridsDataset(
mask=self.data.mask,
over_sample_size_lp=self.over_sample_size_lp,
over_sample_size_pixelization=self.over_sample_size_pixelization,
)

@property
def shape_native(self):
return self.mask.shape_native
Expand Down Expand Up @@ -188,7 +176,7 @@ def signal_to_noise_max(self) -> float:
"""
return np.max(self.signal_to_noise_map)

@cached_property
@property
def noise_covariance_matrix_inv(self) -> np.ndarray:
"""
Returns the inverse of the noise covariance matrix, which is used when computing a chi-squared which accounts
Expand Down
5 changes: 1 addition & 4 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
from pathlib import Path
from typing import Optional, Union

from autoconf import cached_property
from autoconf import instance

from autoarray.dataset.abstract.dataset import AbstractDataset
from autoarray.dataset.grids import GridsDataset
from autoarray.dataset.imaging.w_tilde import WTildeImaging
Expand Down Expand Up @@ -194,7 +191,7 @@ def __init__(
psf=self.psf,
)

@cached_property
@property
def w_tilde(self):
"""
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
Expand Down
6 changes: 2 additions & 4 deletions autoarray/dataset/imaging/w_tilde.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
import numpy as np

from autoconf import cached_property

from autoarray.dataset.abstract.w_tilde import AbstractWTilde

from autoarray.inversion.inversion.imaging import inversion_imaging_util
Expand Down Expand Up @@ -55,7 +53,7 @@ def __init__(
self.psf = psf
self.mask = mask

@cached_property
@property
def w_matrix(self):
"""
The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF
Expand Down Expand Up @@ -93,7 +91,7 @@ def w_matrix(self):
).astype("int"),
)

@cached_property
@property
def psf_operator_matrix_dense(self):

return inversion_imaging_util.psf_operator_matrix_dense_from(
Expand Down
3 changes: 1 addition & 2 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
from pathlib import Path

from autoconf import cached_property
from autoconf.fitsable import ndarray_via_fits_from, output_to_fits

from autoarray.dataset.abstract.dataset import AbstractDataset
Expand Down Expand Up @@ -166,7 +165,7 @@ def w_tilde_preprocessing(self):

fits.writeto(filename, data=curvature_preload)

@cached_property
@property
def w_tilde(self):
"""
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
Expand Down
2 changes: 1 addition & 1 deletion autoarray/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class MeshException(Exception):
"""
Raises exceptions associated with the `inversion/mesh` modules and `Mesh` classes.

For example if a `Rectangular` mesh has dimensions below 3x3.
For example if a `RectangularMagnification` mesh has dimensions below 3x3.
"""

pass
Expand Down
20 changes: 10 additions & 10 deletions autoarray/fit/fit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import numpy as np

from autoconf import cached_property

from autoarray.dataset.grids import GridsInterface
from autoarray.dataset.dataset_model import DatasetModel
from autoarray.fit import fit_util
Expand Down Expand Up @@ -85,7 +83,7 @@ def chi_squared(self) -> float:
"""
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
"""
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array)
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array, xp=self._xp)

@property
def noise_normalization(self) -> float:
Expand All @@ -94,7 +92,7 @@ def noise_normalization(self) -> float:

[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
"""
return fit_util.noise_normalization_from(noise_map=self.noise_map.array)
return fit_util.noise_normalization_from(noise_map=self.noise_map.array, xp=self._xp)

@property
def log_likelihood(self) -> float:
Expand All @@ -115,6 +113,7 @@ def __init__(
dataset,
use_mask_in_fit: bool = False,
dataset_model: DatasetModel = None,
xp=np
):
"""Class to fit a masked dataset where the dataset's data structures are any dimension.

Expand Down Expand Up @@ -147,12 +146,13 @@ def __init__(
self.dataset = dataset
self.use_mask_in_fit = use_mask_in_fit
self.dataset_model = dataset_model or DatasetModel()
self._xp = xp

@property
def mask(self) -> Mask2D:
return self.dataset.mask

@cached_property
@property
def grids(self) -> GridsInterface:

def subtracted_from(grid, offset):
Expand Down Expand Up @@ -196,7 +196,7 @@ def residual_map(self) -> ty.DataLike:

if self.use_mask_in_fit:
return fit_util.residual_map_with_mask_from(
data=self.data, model_data=self.model_data, mask=self.mask
data=self.data, model_data=self.model_data, mask=self.mask, xp=self._xp
)
return super().residual_map

Expand All @@ -209,7 +209,7 @@ def normalized_residual_map(self) -> ty.DataLike:
"""
if self.use_mask_in_fit:
return fit_util.normalized_residual_map_with_mask_from(
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
)
return super().normalized_residual_map

Expand All @@ -222,7 +222,7 @@ def chi_squared_map(self) -> ty.DataLike:
"""
if self.use_mask_in_fit:
return fit_util.chi_squared_map_with_mask_from(
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
)
return super().chi_squared_map

Expand All @@ -243,7 +243,7 @@ def chi_squared(self) -> float:

if self.use_mask_in_fit:
return fit_util.chi_squared_with_mask_from(
chi_squared_map=self.chi_squared_map, mask=self.mask
chi_squared_map=self.chi_squared_map, mask=self.mask, xp=self._xp
)
return super().chi_squared

Expand All @@ -256,7 +256,7 @@ def noise_normalization(self) -> float:
"""
if self.use_mask_in_fit:
return fit_util.noise_normalization_with_mask_from(
noise_map=self.noise_map, mask=self.mask
noise_map=self.noise_map, mask=self.mask, xp=self._xp
)
return super().noise_normalization

Expand Down
4 changes: 3 additions & 1 deletion autoarray/fit/fit_imaging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
import numpy as np

from autoarray.dataset.imaging.dataset import Imaging
from autoarray.dataset.dataset_model import DatasetModel
Expand All @@ -14,6 +14,7 @@ def __init__(
dataset: Imaging,
use_mask_in_fit: bool = False,
dataset_model: DatasetModel = None,
xp=np
):
"""
Class to fit a masked imaging dataset.
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(
dataset=dataset,
use_mask_in_fit=use_mask_in_fit,
dataset_model=dataset_model,
xp=xp
)

@property
Expand Down
3 changes: 2 additions & 1 deletion autoarray/fit/fit_interferometer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from typing import Dict, Optional

from autoarray.dataset.interferometer.dataset import Interferometer

Expand All @@ -18,6 +17,7 @@ def __init__(
dataset: Interferometer,
dataset_model: DatasetModel = None,
use_mask_in_fit: bool = False,
xp=np
):
"""
Class to fit a masked interferometer dataset.
Expand Down Expand Up @@ -58,6 +58,7 @@ def __init__(
dataset=dataset,
dataset_model=dataset_model,
use_mask_in_fit=use_mask_in_fit,
xp=xp
)

@property
Expand Down
Loading
Loading