Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops
from .inversion.linear_obj.linear_obj import LinearObj
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
from .mask.derive.indexes_2d import DeriveIndexes2D
Expand Down
17 changes: 9 additions & 8 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from abc import ABC
from abc import abstractmethod
import jax.numpy as jnp

from autoconf.fitsable import output_to_fits

from autoarray.numpy_wrapper import np, register_pytree_node, Array
from autoarray.numpy_wrapper import register_pytree_node, Array

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(self, array):

def invert(self):
new = self.copy()
new._array = np.invert(new._array)
new._array = jnp.invert(new._array)
return new

@classmethod
Expand All @@ -104,7 +105,7 @@ def instance_flatten(cls, instance):
@staticmethod
def flip_hdu_for_ds9(values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return np.flipud(values)
return jnp.flipud(values)
return values

@classmethod
Expand All @@ -117,7 +118,7 @@ def instance_unflatten(cls, aux_data, children):
setattr(instance, key, value)
return instance

def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
"""
Copy this object but give it a new array.

Expand Down Expand Up @@ -164,7 +165,7 @@ def __iter__(self):

@to_new_array
def sqrt(self):
return np.sqrt(self._array)
return jnp.sqrt(self._array)

@property
def array(self):
Expand Down Expand Up @@ -330,13 +331,13 @@ def __getitem__(self, item):
result = self._array[item]
if isinstance(item, slice):
result = self.with_new_array(result)
if isinstance(result, np.ndarray):
if isinstance(result, jnp.ndarray):
result = self.with_new_array(result)
return result

def __setitem__(self, key, value):
if isinstance(key, (np.ndarray, AbstractNDArray, Array)):
self._array = np.where(key, value, self._array)
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
self._array = jnp.where(key, value, self._array)
else:
self._array[key] = value

Expand Down
3 changes: 0 additions & 3 deletions autoarray/config/grids.yaml

This file was deleted.

23 changes: 12 additions & 11 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def __init__(

self.psf = psf

if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")
if psf is not None:
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")

@cached_property
def grids(self):
Expand Down Expand Up @@ -203,9 +204,9 @@ def w_tilde(self):
indexes,
lengths,
) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from(
noise_map_native=np.array(self.noise_map.native),
kernel_native=np.array(self.psf.native),
native_index_for_slim_index=self.mask.derive_indexes.native_for_slim,
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"),
)

return WTildeImaging(
Expand Down Expand Up @@ -408,20 +409,20 @@ def apply_noise_scaling(
"""

if signal_to_noise_value is None:
noise_map = self.noise_map.native
noise_map[mask == False] = noise_value
noise_map = np.array(self.noise_map.native.array)
noise_map[mask.array == False] = noise_value
else:
noise_map = np.where(
mask == False,
np.median(self.data.native[mask.derive_mask.edge == False])
np.median(self.data.native.array[mask.derive_mask.edge == False])
/ signal_to_noise_value,
self.noise_map.native,
self.noise_map.native.array,
)

if should_zero_data:
data = np.where(np.invert(mask), 0.0, self.data.native)
data = np.where(np.invert(mask.array), 0.0, self.data.native.array)
else:
data = self.data.native
data = self.data.native.array

data_unmasked = Array2D.no_mask(
values=data,
Expand Down
12 changes: 6 additions & 6 deletions autoarray/dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ def edges_from(image, no_edges):
edges = []

for edge_no in range(no_edges):
top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no]
bottom_edge = image.native[
top_edge = image.native.array[edge_no, edge_no : image.shape_native[1] - edge_no]
bottom_edge = image.native.array[
image.shape_native[0] - 1 - edge_no,
edge_no : image.shape_native[1] - edge_no,
]
left_edge = image.native[
left_edge = image.native.array[
edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no
]
right_edge = image.native[
right_edge = image.native.array[
edge_no + 1 : image.shape_native[0] - 1 - edge_no,
image.shape_native[1] - 1 - edge_no,
]
Expand Down Expand Up @@ -517,8 +517,8 @@ def noise_map_with_signal_to_noise_limit_from(
noise_map_limit = np.where(
(signal_to_noise_map.native > signal_to_noise_limit)
& (noise_limit_mask == False),
np.abs(data.native) / signal_to_noise_limit,
noise_map.native,
np.abs(data.native.array) / signal_to_noise_limit,
noise_map.native.array,
)

mask = Mask2D.all_false(
Expand Down
87 changes: 36 additions & 51 deletions autoarray/fit/fit_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import wraps
import jax.numpy as np
import jax.numpy as jnp
import numpy as np

from autoarray.mask.abstract_mask import Mask

Expand Down Expand Up @@ -83,7 +84,7 @@ def chi_squared_from(*, chi_squared_map: ty.DataLike) -> float:
chi_squared_map
The chi-squared-map of values of the model-data fit to the dataset.
"""
return np.sum(chi_squared_map._array)
return jnp.sum(np.array(chi_squared_map))


def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
Expand All @@ -97,12 +98,12 @@ def noise_normalization_from(*, noise_map: ty.DataLike) -> float:
noise_map
The masked noise-map of the dataset.
"""
return np.sum(np.log(2 * np.pi * noise_map._array**2.0))
return jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map)**2.0))


def normalized_residual_map_complex_from(
*, residual_map: np.ndarray, noise_map: np.ndarray
) -> np.ndarray:
*, residual_map: jnp.ndarray, noise_map: jnp.ndarray
) -> jnp.ndarray:
"""
Returns the normalized residual-map of the fit of complex model-data to a dataset, where:

Expand All @@ -126,8 +127,8 @@ def normalized_residual_map_complex_from(


def chi_squared_map_complex_from(
*, residual_map: np.ndarray, noise_map: np.ndarray
) -> np.ndarray:
*, residual_map: jnp.ndarray, noise_map: jnp.ndarray
) -> jnp.ndarray:
"""
Returnss the chi-squared-map of the fit of complex model-data to a dataset, where:

Expand All @@ -145,7 +146,7 @@ def chi_squared_map_complex_from(
return chi_squared_map_real + 1j * chi_squared_map_imag


def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float:
def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float:
"""
Returns the chi-squared terms of each complex model data's fit to a masked dataset, by summing the masked
chi-squared-map of the fit.
Expand All @@ -157,12 +158,12 @@ def chi_squared_complex_from(*, chi_squared_map: np.ndarray) -> float:
chi_squared_map
The chi-squared-map of values of the model-data fit to the dataset.
"""
chi_squared_real = np.sum(chi_squared_map.real)
chi_squared_imag = np.sum(chi_squared_map.imag)
chi_squared_real = jnp.sum(np.array(chi_squared_map.real))
chi_squared_imag = jnp.sum(np.array(chi_squared_map.imag))
return chi_squared_real + chi_squared_imag


def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float:
def noise_normalization_complex_from(*, noise_map: jnp.ndarray) -> float:
"""
Returns the noise-map normalization terms of a complex noise-map, summing the noise_map value in every pixel as:

Expand All @@ -173,8 +174,8 @@ def noise_normalization_complex_from(*, noise_map: np.ndarray) -> float:
noise_map
The masked noise-map of the dataset.
"""
noise_normalization_real = np.sum(np.log(2 * np.pi * noise_map.real**2.0))
noise_normalization_imag = np.sum(np.log(2 * np.pi * noise_map.imag**2.0))
noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map).real**2.0))
noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * np.array(noise_map).imag**2.0))
return noise_normalization_real + noise_normalization_imag


Expand All @@ -198,9 +199,7 @@ def residual_map_with_mask_from(
model_data
The model data used to fit the data.
"""
return np.subtract(
data, model_data, out=np.zeros_like(data), where=np.asarray(mask) == 0
)
return jnp.where(jnp.asarray(mask) == 0, jnp.subtract(data, model_data), 0)


@to_new_array
Expand All @@ -223,13 +222,7 @@ def normalized_residual_map_with_mask_from(
mask
The mask applied to the residual-map, where `False` entries are included in the calculation.
"""
return np.divide(
residual_map,
noise_map,
out=np.zeros_like(residual_map),
where=np.asarray(mask) == 0,
)

return jnp.where(jnp.asarray(mask) == 0, jnp.divide(residual_map, noise_map), 0)

@to_new_array
def chi_squared_map_with_mask_from(
Expand All @@ -251,13 +244,10 @@ def chi_squared_map_with_mask_from(
mask
The mask applied to the residual-map, where `False` entries are included in the calculation.
"""
return np.square(
np.divide(
residual_map,
noise_map,
out=np.zeros_like(residual_map),
where=np.asarray(mask) == 0,
)
return jnp.where(
jnp.asarray(mask) == 0,
jnp.square(residual_map / noise_map),
0
)


Expand All @@ -275,7 +265,7 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask) -> f
mask
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
"""
return float(np.sum(chi_squared_map[np.asarray(mask) == 0]))
return float(jnp.sum(chi_squared_map[jnp.asarray(mask) == 0]))


def chi_squared_with_mask_fast_from(
Expand All @@ -302,14 +292,14 @@ def chi_squared_with_mask_fast_from(
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
"""
return float(
np.sum(
np.square(
np.divide(
np.subtract(
jnp.sum(
jnp.square(
jnp.divide(
jnp.subtract(
data,
model_data,
)[np.asarray(mask) == 0],
noise_map[np.asarray(mask) == 0],
)[jnp.asarray(mask) == 0],
noise_map[jnp.asarray(mask) == 0],
)
)
)
Expand All @@ -331,11 +321,11 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask) ->
mask
The mask applied to the noise-map, where `False` entries are included in the calculation.
"""
return float(np.sum(np.log(2 * np.pi * noise_map[np.asarray(mask) == 0] ** 2.0)))
return float(jnp.sum(jnp.log(2 * jnp.pi * noise_map[jnp.asarray(mask) == 0] ** 2.0)))


def chi_squared_with_noise_covariance_from(
*, residual_map: ty.DataLike, noise_covariance_matrix_inv: np.ndarray
*, residual_map: ty.DataLike, noise_covariance_matrix_inv: jnp.ndarray
) -> float:
"""
Returns the chi-squared value of the fit of model-data to a masked dataset, where
Expand All @@ -351,7 +341,7 @@ def chi_squared_with_noise_covariance_from(
The inverse of the noise covariance matrix.
"""

return residual_map @ noise_covariance_matrix_inv @ residual_map
return residual_map.array @ noise_covariance_matrix_inv @ residual_map.array


def log_likelihood_from(*, chi_squared: float, noise_normalization: float) -> float:
Expand Down Expand Up @@ -431,8 +421,8 @@ def log_evidence_from(


def residual_flux_fraction_map_from(
*, residual_map: np.ndarray, data: np.ndarray
) -> np.ndarray:
*, residual_map: jnp.ndarray, data: jnp.ndarray
) -> jnp.ndarray:
"""
Returns the residual flux fraction map of the fit of model-data to a masked dataset, where:

Expand All @@ -445,12 +435,12 @@ def residual_flux_fraction_map_from(
data
The data of the dataset.
"""
return np.divide(residual_map, data, out=np.zeros_like(residual_map))
return jnp.where(data != 0, residual_map / data, 0)


def residual_flux_fraction_map_with_mask_from(
*, residual_map: np.ndarray, data: np.ndarray, mask: Mask
) -> np.ndarray:
*, residual_map: jnp.ndarray, data: jnp.ndarray, mask: Mask
) -> jnp.ndarray:
"""
Returnss the residual flux fraction map of the fit of model-data to a masked dataset, where:

Expand All @@ -467,9 +457,4 @@ def residual_flux_fraction_map_with_mask_from(
mask
The mask applied to the residual-map, where `False` entries are included in the calculation.
"""
return np.divide(
residual_map,
data,
out=np.zeros_like(residual_map),
where=np.asarray(mask) == 0,
)
return jnp.where(mask == 0, residual_map / data, 0)
2 changes: 0 additions & 2 deletions autoarray/geometry/geometry_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from autoarray import type as ty
from autoarray.geometry import geometry_util

from autofit.jax_wrapper import use_jax

logging.basicConfig()
logger = logging.getLogger(__name__)

Expand Down
Loading
Loading