Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .fit.fit_imaging import FitImaging
from .fit.fit_interferometer import FitInterferometer
from .geometry.geometry_2d import Geometry2D
from .inversion.convolver import Convolver
from .inversion.pixelization.mappers.abstract import AbstractMapper
from .inversion.pixelization import mesh
from .inversion.pixelization import image_mesh
Expand All @@ -44,7 +45,6 @@
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
from .inversion.inversion.interferometer.lop import InversionInterferometerMappingPyLops
from .inversion.linear_obj.linear_obj import LinearObj
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
from .mask.derive.indexes_2d import DeriveIndexes2D
Expand Down
17 changes: 9 additions & 8 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from abc import ABC
from abc import abstractmethod
import jax.numpy as jnp

from autoconf.fitsable import output_to_fits

from autoarray.numpy_wrapper import np, register_pytree_node, Array
from autoarray.numpy_wrapper import register_pytree_node, Array

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(self, array):

def invert(self):
new = self.copy()
new._array = np.invert(new._array)
new._array = jnp.invert(new._array)
return new

@classmethod
Expand All @@ -104,7 +105,7 @@ def instance_flatten(cls, instance):
@staticmethod
def flip_hdu_for_ds9(values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return np.flipud(values)
return jnp.flipud(values)
return values

@classmethod
Expand All @@ -117,7 +118,7 @@ def instance_unflatten(cls, aux_data, children):
setattr(instance, key, value)
return instance

def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
"""
Copy this object but give it a new array.

Expand Down Expand Up @@ -164,7 +165,7 @@ def __iter__(self):

@to_new_array
def sqrt(self):
return np.sqrt(self._array)
return jnp.sqrt(self._array)

@property
def array(self):
Expand Down Expand Up @@ -330,13 +331,13 @@ def __getitem__(self, item):
result = self._array[item]
if isinstance(item, slice):
result = self.with_new_array(result)
if isinstance(result, np.ndarray):
if isinstance(result, jnp.ndarray):
result = self.with_new_array(result)
return result

def __setitem__(self, key, value):
if isinstance(key, (np.ndarray, AbstractNDArray, Array)):
self._array = np.where(key, value, self._array)
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
self._array = jnp.where(key, value, self._array)
else:
self._array[key] = value

Expand Down
3 changes: 0 additions & 3 deletions autoarray/config/grids.yaml

This file was deleted.

44 changes: 33 additions & 11 deletions autoarray/dataset/imaging/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def __init__(

self.psf = psf

if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")
if psf is not None:
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
raise exc.KernelException("Kernel2D Kernel2D must be odd")

@cached_property
def grids(self):
Expand All @@ -178,6 +179,27 @@ def grids(self):
psf=self.psf,
)

@cached_property
def convolver(self):
"""
Returns a `Convolver` from a mask and 2D PSF kernel.

The `Convolver` stores in memory the array indexing between the mask and PSF, enabling efficient 2D PSF
convolution of images and matrices used for linear algebra calculations (see `operators.convolver`).

This uses lazy allocation such that the calculation is only performed when the convolver is used, ensuring
efficient set up of the `Imaging` class.

Returns
-------
Convolver
The convolver given the masked imaging data's mask and PSF.
"""

from autoarray.inversion.convolver import Convolver

return Convolver(mask=self.mask, kernel=self.psf)

@cached_property
def w_tilde(self):
"""
Expand All @@ -203,9 +225,9 @@ def w_tilde(self):
indexes,
lengths,
) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from(
noise_map_native=np.array(self.noise_map.native),
kernel_native=np.array(self.psf.native),
native_index_for_slim_index=self.mask.derive_indexes.native_for_slim,
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
kernel_native=np.array(self.psf.native.array).astype("float64"),
native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"),
)

return WTildeImaging(
Expand Down Expand Up @@ -408,20 +430,20 @@ def apply_noise_scaling(
"""

if signal_to_noise_value is None:
noise_map = self.noise_map.native
noise_map[mask == False] = noise_value
noise_map = np.array(self.noise_map.native.array)
noise_map[mask.array == False] = noise_value
else:
noise_map = np.where(
mask == False,
np.median(self.data.native[mask.derive_mask.edge == False])
np.median(self.data.native.array[mask.derive_mask.edge == False])
/ signal_to_noise_value,
self.noise_map.native,
self.noise_map.native.array,
)

if should_zero_data:
data = np.where(np.invert(mask), 0.0, self.data.native)
data = np.where(np.invert(mask.array), 0.0, self.data.native.array)
else:
data = self.data.native
data = self.data.native.array

data_unmasked = Array2D.no_mask(
values=data,
Expand Down
2 changes: 1 addition & 1 deletion autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def w_tilde(self):

w_matrix = inversion_interferometer_util.w_tilde_via_preload_from(
w_tilde_preload=curvature_preload,
native_index_for_slim_index=self.real_space_mask.derive_indexes.native_for_slim,
native_index_for_slim_index=np.array(self.real_space_mask.derive_indexes.native_for_slim).astype("int"),
)

dirty_image = self.transformer.image_from(
Expand Down
12 changes: 6 additions & 6 deletions autoarray/dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ def edges_from(image, no_edges):
edges = []

for edge_no in range(no_edges):
top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no]
bottom_edge = image.native[
top_edge = image.native.array[edge_no, edge_no : image.shape_native[1] - edge_no]
bottom_edge = image.native.array[
image.shape_native[0] - 1 - edge_no,
edge_no : image.shape_native[1] - edge_no,
]
left_edge = image.native[
left_edge = image.native.array[
edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no
]
right_edge = image.native[
right_edge = image.native.array[
edge_no + 1 : image.shape_native[0] - 1 - edge_no,
image.shape_native[1] - 1 - edge_no,
]
Expand Down Expand Up @@ -517,8 +517,8 @@ def noise_map_with_signal_to_noise_limit_from(
noise_map_limit = np.where(
(signal_to_noise_map.native > signal_to_noise_limit)
& (noise_limit_mask == False),
np.abs(data.native) / signal_to_noise_limit,
noise_map.native,
np.abs(data.native.array) / signal_to_noise_limit,
noise_map.native.array,
)

mask = Mask2D.all_false(
Expand Down
Loading
Loading