Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
33a3ccb
fix some decorator unit tests
Jammy2211 Apr 3, 2025
8b8dc9e
removing numpy wrapper to do explicit impots
Jammy2211 Apr 3, 2025
7115f9c
move relocate radial
Jammy2211 Apr 3, 2025
6f02715
more removal of numpy wrapper nps
Jammy2211 Apr 3, 2025
aa4c9e6
remove all numpy wrappers
Jammy2211 Apr 3, 2025
3b6ab48
remove warning for now
Jammy2211 Apr 3, 2025
44a2808
fix structure plotters
Jammy2211 Apr 3, 2025
ea139fc
clean up vectors_yx
Jammy2211 Apr 3, 2025
ec1e81e
remove autofit imports
Jammy2211 Apr 3, 2025
37e81f1
fix voronoi unit test in structures
Jammy2211 Apr 3, 2025
b31c0fc
fix test_preprocess
Jammy2211 Apr 4, 2025
80fc8e8
fix test dataset abstract
Jammy2211 Apr 4, 2025
cd276cd
fix test imaging
Jammy2211 Apr 4, 2025
f6dfda5
fix layout
Jammy2211 Apr 4, 2025
5430647
fix plot unit tests
Jammy2211 Apr 4, 2025
083ed0b
over sampling unit tests
Jammy2211 Apr 4, 2025
72af86b
fix all fit tests
Jammy2211 Apr 4, 2025
340c34d
pylops removed
Jammy2211 Apr 4, 2025
39fbf72
unit tests fixed
Jammy2211 Apr 4, 2025
dd7ca45
inversion now uses Convoler again, all numba, factory tests pass
Jammy2211 Apr 5, 2025
119866d
fix inversion plotters
Jammy2211 Apr 5, 2025
580be1c
fix test overlay
Jammy2211 Apr 5, 2025
994e048
fixn interferometer conversion
Jammy2211 Apr 5, 2025
9c5354c
more plot unit tests pass
Jammy2211 Apr 5, 2025
2b6754c
tst coverage complete
Jammy2211 Apr 5, 2025
2049042
small changes to jax array handling
Jammy2211 Apr 5, 2025
0da7719
clean up grid contou
Jammy2211 Apr 5, 2025
5e31ce0
arrays now store maks in jax
Jammy2211 Apr 6, 2025
3357cbb
array 2d always stored as jax
Jammy2211 Apr 6, 2025
b9e4cb5
update array_1d_slim_from
Jammy2211 Apr 6, 2025
d132365
array 1D stuff all updated to support JAX
Jammy2211 Apr 6, 2025
bd88fee
grid_1d_slim_via_mask_from to JAX
Jammy2211 Apr 6, 2025
5f8877b
grids now use JAX
Jammy2211 Apr 6, 2025
6b7d433
fix vectors
Jammy2211 Apr 6, 2025
a6480a9
fix some visualizatoin
Jammy2211 Apr 6, 2025
1fde8b1
risky change to figure open
Jammy2211 Apr 6, 2025
875483c
fix structurre plotters now JAX arrays used
Jammy2211 Apr 6, 2025
1088baf
over sampler fix
Jammy2211 Apr 6, 2025
f72b15e
fix kmesh
Jammy2211 Apr 6, 2025
ac68a08
fix test repr
Jammy2211 Apr 6, 2025
678673e
fix test preprocess
Jammy2211 Apr 6, 2025
c611bb6
fix abstract dataset tests
Jammy2211 Apr 6, 2025
e1899a1
fix test imaigng data
Jammy2211 Apr 6, 2025
7d5ba26
fix simulator
Jammy2211 Apr 6, 2025
1feabff
fix interferometrer
Jammy2211 Apr 6, 2025
2c02e39
dataset tests pass
Jammy2211 Apr 6, 2025
87931b6
fix fit tests
Jammy2211 Apr 6, 2025
0fa9646
mask array 2d choose type based on if input is jax or numpy
Jammy2211 Apr 6, 2025
578bc4c
array 1d now use convert rules
Jammy2211 Apr 6, 2025
32f55f7
grid 2d casting
Jammy2211 Apr 6, 2025
b5601e8
grid 2d follows rules
Jammy2211 Apr 6, 2025
49b8dc5
grid 1d no conversion
Jammy2211 Apr 6, 2025
d3c5bbf
same rules for grids
Jammy2211 Apr 6, 2025
61a72b8
fix dataset unit tests
Jammy2211 Apr 6, 2025
15cd0b5
fix some unit tests
Jammy2211 Apr 6, 2025
d33ab44
mask now also follows strictire typing rules
Jammy2211 Apr 6, 2025
3aef08a
but of typing simplication
Jammy2211 Apr 6, 2025
5d96ba5
check conversion needed
Jammy2211 Apr 6, 2025
ccccf1a
add small non zero numerical value to make grid radials plot correctly
Jammy2211 Apr 7, 2025
7948e90
small shift in projected grid
Jammy2211 Apr 7, 2025
5fc2b63
fix some unit tests due to numerical addition on projeted grid
Jammy2211 Apr 8, 2025
135d5b8
black
Jammy2211 Apr 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .fit.fit_imaging import FitImaging
from .fit.fit_interferometer import FitInterferometer
from .geometry.geometry_2d import Geometry2D
from .inversion.convolver import Convolver
from .inversion.pixelization.mappers.abstract import AbstractMapper
from .inversion.pixelization import mesh
from .inversion.pixelization import image_mesh
Expand All @@ -44,7 +45,6 @@
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops
from .inversion.linear_obj.linear_obj import LinearObj
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
from .mask.derive.indexes_2d import DeriveIndexes2D
Expand Down
17 changes: 9 additions & 8 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from abc import ABC
from abc import abstractmethod
import jax.numpy as jnp

from autoconf.fitsable import output_to_fits

from autoarray.numpy_wrapper import np, register_pytree_node, Array
from autoarray.numpy_wrapper import register_pytree_node, Array

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(self, array):

def invert(self):
new = self.copy()
new._array = np.invert(new._array)
new._array = jnp.invert(new._array)
return new

@classmethod
Expand All @@ -104,7 +105,7 @@ def instance_flatten(cls, instance):
@staticmethod
def flip_hdu_for_ds9(values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return np.flipud(values)
return jnp.flipud(values)
return values

@classmethod
Expand All @@ -117,7 +118,7 @@ def instance_unflatten(cls, aux_data, children):
setattr(instance, key, value)
return instance

def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
"""
Copy this object but give it a new array.

Expand Down Expand Up @@ -164,7 +165,7 @@ def __iter__(self):

@to_new_array
def sqrt(self):
return np.sqrt(self._array)
return jnp.sqrt(self._array)

@property
def array(self):
Expand Down Expand Up @@ -330,13 +331,13 @@ def __getitem__(self, item):
result = self._array[item]
if isinstance(item, slice):
result = self.with_new_array(result)
if isinstance(result, np.ndarray):
if isinstance(result, jnp.ndarray):
result = self.with_new_array(result)
return result

def __setitem__(self, key, value):
if isinstance(key, (np.ndarray, AbstractNDArray, Array)):
self._array = np.where(key, value, self._array)
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
self._array = jnp.where(key, value, self._array)
else:
self._array[key] = value

Expand Down
3 changes: 0 additions & 3 deletions autoarray/config/grids.yaml

This file was deleted.

46 changes: 35 additions & 11 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def __init__(

self.psf = psf

if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")
if psf is not None:
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")

@cached_property
def grids(self):
Expand All @@ -178,6 +179,27 @@ def grids(self):
psf=self.psf,
)

@cached_property
def convolver(self):
"""
Returns a `Convolver` from a mask and 2D PSF kernel.

The `Convolver` stores in memory the array indexing between the mask and PSF, enabling efficient 2D PSF
convolution of images and matrices used for linear algebra calculations (see `operators.convolver`).

This uses lazy allocation such that the calculation is only performed when the convolver is used, ensuring
efficient set up of the `Imaging` class.

Returns
-------
Convolver
The convolver given the masked imaging data's mask and PSF.
"""

from autoarray.inversion.convolver import Convolver

return Convolver(mask=self.mask, kernel=self.psf)

@cached_property
def w_tilde(self):
"""
Expand All @@ -203,9 +225,11 @@ def w_tilde(self):
indexes,
lengths,
) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from(
noise_map_native=np.array(self.noise_map.native),
kernel_native=np.array(self.psf.native),
native_index_for_slim_index=self.mask.derive_indexes.native_for_slim,
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(
self.mask.derive_indexes.native_for_slim
).astype("int"),
)

return WTildeImaging(
Expand Down Expand Up @@ -408,20 +432,20 @@ def apply_noise_scaling(
"""

if signal_to_noise_value is None:
noise_map = self.noise_map.native
noise_map[mask == False] = noise_value
noise_map = np.array(self.noise_map.native.array)
noise_map[mask.array == False] = noise_value
else:
noise_map = np.where(
mask == False,
np.median(self.data.native[mask.derive_mask.edge == False])
np.median(self.data.native.array[mask.derive_mask.edge == False])
/ signal_to_noise_value,
self.noise_map.native,
self.noise_map.native.array,
)

if should_zero_data:
data = np.where(np.invert(mask), 0.0, self.data.native)
data = np.where(np.invert(mask.array), 0.0, self.data.native.array)
else:
data = self.data.native
data = self.data.native.array

data_unmasked = Array2D.no_mask(
values=data,
Expand Down
6 changes: 4 additions & 2 deletions autoarray/dataset/imaging/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def via_image_from(
pixel_scales=image.pixel_scales,
)

if np.isnan(noise_map).any():
if np.isnan(noise_map.array).any():
raise exc.DatasetException(
"The noise-map has NaN values in it. This suggests your exposure time and / or"
"background sky levels are too low, creating signal counts at or close to 0.0."
Expand All @@ -161,7 +161,9 @@ def via_image_from(
image = image - background_sky_map

mask = Mask2D.all_false(
shape_native=image.shape_native, pixel_scales=image.pixel_scales
shape_native=image.shape_native,
pixel_scales=image.pixel_scales,
origin=image.origin,
)

image = Array2D(values=image, mask=mask)
Expand Down
6 changes: 4 additions & 2 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def w_tilde(self):

w_matrix = inversion_interferometer_util.w_tilde_via_preload_from(
w_tilde_preload=curvature_preload,
native_index_for_slim_index=self.real_space_mask.derive_indexes.native_for_slim,
native_index_for_slim_index=np.array(
self.real_space_mask.derive_indexes.native_for_slim
).astype("int"),
)

dirty_image = self.transformer.image_from(
Expand All @@ -205,7 +207,7 @@ def w_tilde(self):
return WTildeInterferometer(
w_matrix=w_matrix,
curvature_preload=curvature_preload,
dirty_image=dirty_image,
dirty_image=np.array(dirty_image.array),
real_space_mask=self.real_space_mask,
noise_map_value=self.noise_map[0],
)
Expand Down
28 changes: 16 additions & 12 deletions autoarray/dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def noise_map_via_data_eps_and_exposure_time_map_from(data_eps, exposure_time_ma
The exposure time at every data-point of the data.
"""
return data_eps.with_new_array(
np.abs(data_eps * exposure_time_map) ** 0.5 / exposure_time_map
np.abs(data_eps.array * exposure_time_map.array) ** 0.5
/ exposure_time_map.array
)


Expand Down Expand Up @@ -263,15 +264,17 @@ def edges_from(image, no_edges):
edges = []

for edge_no in range(no_edges):
top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no]
bottom_edge = image.native[
top_edge = image.native.array[
edge_no, edge_no : image.shape_native[1] - edge_no
]
bottom_edge = image.native.array[
image.shape_native[0] - 1 - edge_no,
edge_no : image.shape_native[1] - edge_no,
]
left_edge = image.native[
left_edge = image.native.array[
edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no
]
right_edge = image.native[
right_edge = image.native.array[
edge_no + 1 : image.shape_native[0] - 1 - edge_no,
image.shape_native[1] - 1 - edge_no,
]
Expand Down Expand Up @@ -406,9 +409,10 @@ def poisson_noise_via_data_eps_from(data_eps, exposure_time_map, seed=-1):
An array describing simulated poisson noise_maps
"""
setup_random_seed(seed)
image_counts = np.multiply(data_eps, exposure_time_map)

image_counts = np.multiply(data_eps.array, exposure_time_map.array)
return data_eps - np.divide(
np.random.poisson(image_counts, data_eps.shape), exposure_time_map
np.random.poisson(image_counts, data_eps.shape), exposure_time_map.array
)


Expand Down Expand Up @@ -506,8 +510,6 @@ def noise_map_with_signal_to_noise_limit_from(
from autoarray.structures.arrays.uniform_1d import Array1D
from autoarray.structures.arrays.uniform_2d import Array2D

# TODO : Refacotr into a util

signal_to_noise_map = data / noise_map
signal_to_noise_map[signal_to_noise_map < 0] = 0

Expand All @@ -517,12 +519,14 @@ def noise_map_with_signal_to_noise_limit_from(
noise_map_limit = np.where(
(signal_to_noise_map.native > signal_to_noise_limit)
& (noise_limit_mask == False),
np.abs(data.native) / signal_to_noise_limit,
noise_map.native,
np.abs(data.native.array) / signal_to_noise_limit,
noise_map.native.array,
)

mask = Mask2D.all_false(
shape_native=data.shape_native, pixel_scales=data.pixel_scales
shape_native=data.shape_native,
pixel_scales=data.pixel_scales,
origin=data.origin,
)

if len(noise_map.native) == 1:
Expand Down
4 changes: 2 additions & 2 deletions autoarray/fit/fit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def chi_squared(self) -> float:
"""
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
"""
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map)
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array)

@property
def noise_normalization(self) -> float:
Expand All @@ -95,7 +95,7 @@ def noise_normalization(self) -> float:

[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
"""
return fit_util.noise_normalization_from(noise_map=self.noise_map)
return fit_util.noise_normalization_from(noise_map=self.noise_map.array)

@property
def log_likelihood(self) -> float:
Expand Down
Loading
Loading