Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def __getitem__(self, item):

def __setitem__(self, key, value):
from jax import Array

if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
self._array = jnp.where(key, value, self._array)
else:
Expand Down
5 changes: 4 additions & 1 deletion autoarray/dataset/imaging/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def __init__(
self.noise_seed = noise_seed

def via_image_from(
self, image: Array2D, over_sample_size: Optional[Union[int, np.ndarray]] = None, xp=np
self,
image: Array2D,
over_sample_size: Optional[Union[int, np.ndarray]] = None,
xp=np,
) -> Imaging:
"""
Simulate an `Imaging` dataset from an input image.
Expand Down
20 changes: 15 additions & 5 deletions autoarray/fit/fit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def chi_squared(self) -> float:
"""
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
"""
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array, xp=self._xp)
return fit_util.chi_squared_from(
chi_squared_map=self.chi_squared_map.array, xp=self._xp
)

@property
def noise_normalization(self) -> float:
Expand All @@ -92,7 +94,9 @@ def noise_normalization(self) -> float:

[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
"""
return fit_util.noise_normalization_from(noise_map=self.noise_map.array, xp=self._xp)
return fit_util.noise_normalization_from(
noise_map=self.noise_map.array, xp=self._xp
)

@property
def log_likelihood(self) -> float:
Expand All @@ -113,7 +117,7 @@ def __init__(
dataset,
use_mask_in_fit: bool = False,
dataset_model: DatasetModel = None,
xp=np
xp=np,
):
"""Class to fit a masked dataset where the dataset's data structures are any dimension.

Expand Down Expand Up @@ -209,7 +213,10 @@ def normalized_residual_map(self) -> ty.DataLike:
"""
if self.use_mask_in_fit:
return fit_util.normalized_residual_map_with_mask_from(
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
residual_map=self.residual_map,
noise_map=self.noise_map,
mask=self.mask,
xp=self._xp,
)
return super().normalized_residual_map

Expand All @@ -222,7 +229,10 @@ def chi_squared_map(self) -> ty.DataLike:
"""
if self.use_mask_in_fit:
return fit_util.chi_squared_map_with_mask_from(
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
residual_map=self.residual_map,
noise_map=self.noise_map,
mask=self.mask,
xp=self._xp,
)
return super().chi_squared_map

Expand Down
4 changes: 2 additions & 2 deletions autoarray/fit/fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
dataset: Imaging,
use_mask_in_fit: bool = False,
dataset_model: DatasetModel = None,
xp=np
xp=np,
):
"""
Class to fit a masked imaging dataset.
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
dataset=dataset,
use_mask_in_fit=use_mask_in_fit,
dataset_model=dataset_model,
xp=xp
xp=xp,
)

@property
Expand Down
4 changes: 2 additions & 2 deletions autoarray/fit/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
dataset: Interferometer,
dataset_model: DatasetModel = None,
use_mask_in_fit: bool = False,
xp=np
xp=np,
):
"""
Class to fit a masked interferometer dataset.
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
dataset=dataset,
dataset_model=dataset_model,
use_mask_in_fit=use_mask_in_fit,
xp=xp
xp=xp,
)

@property
Expand Down
19 changes: 13 additions & 6 deletions autoarray/fit/fit_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def chi_squared_map_with_mask_from(
return xp.where(xp.asarray(mask) == 0, xp.square(residual_map / noise_map), 0)


def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask, xp=np) -> float:
def chi_squared_with_mask_from(
*, chi_squared_map: ty.DataLike, mask: Mask, xp=np
) -> float:
"""
Returns the chi-squared terms of each model data's fit to a masked dataset, by summing the masked
chi-squared-map of the fit.
Expand All @@ -265,7 +267,12 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask, xp=n


def chi_squared_with_mask_fast_from(
*, data: ty.DataLike, mask: Mask, model_data: ty.DataLike, noise_map: ty.DataLike, xp=np
*,
data: ty.DataLike,
mask: Mask,
model_data: ty.DataLike,
noise_map: ty.DataLike,
xp=np,
) -> float:
"""
Returns the chi-squared terms of each model data's fit to a masked dataset, by summing the masked
Expand Down Expand Up @@ -302,7 +309,9 @@ def chi_squared_with_mask_fast_from(
)


def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask, xp=np) -> float:
def noise_normalization_with_mask_from(
*, noise_map: ty.DataLike, mask: Mask, xp=np
) -> float:
"""
Returns the noise-map normalization terms of masked noise-map, summing the noise_map value in every pixel as:

Expand All @@ -317,9 +326,7 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask, xp
mask
The mask applied to the noise-map, where `False` entries are included in the calculation.
"""
return float(
xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0))
)
return float(xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0)))


def chi_squared_with_noise_covariance_from(
Expand Down
38 changes: 27 additions & 11 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from autoarray.util import misc_util
from autoarray.inversion.inversion import inversion_util


class AbstractInversion:
def __init__(
self,
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
xp=np
xp=np,
):
"""
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
Expand Down Expand Up @@ -75,8 +76,6 @@ def __init__(

self._xp = xp



@property
def data(self):
return self.dataset.data
Expand Down Expand Up @@ -333,10 +332,15 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
"""
if self._xp.__name__.startswith("jax"):
from jax.scipy.linalg import block_diag

return block_diag(
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
*[
linear_obj.regularization_matrix
for linear_obj in self.linear_obj_list
]
)
from scipy.linalg import block_diag

return block_diag(
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
)
Expand Down Expand Up @@ -448,7 +452,7 @@ def reconstruction(self) -> np.ndarray:
data_vector=data_vector,
curvature_reg_matrix=curvature_reg_matrix,
settings=self.settings,
xp=self._xp
xp=self._xp,
)
)

Expand All @@ -471,13 +475,13 @@ def reconstruction(self) -> np.ndarray:
data_vector=self.data_vector,
curvature_reg_matrix=self.curvature_reg_matrix,
settings=self.settings,
xp=self._xp
xp=self._xp,
)

return inversion_util.reconstruction_positive_negative_from(
data_vector=self.data_vector,
curvature_reg_matrix=self.curvature_reg_matrix,
xp=self._xp
xp=self._xp,
)

@property
Expand Down Expand Up @@ -640,7 +644,9 @@ def regularization_term(self) -> float:

return self._xp.matmul(
self.reconstruction_reduced.T,
self._xp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
self._xp.matmul(
self.regularization_matrix_reduced, self.reconstruction_reduced
),
)

@property
Expand All @@ -654,7 +660,11 @@ def log_det_curvature_reg_matrix_term(self) -> float:
return 0.0

return 2.0 * self._xp.sum(
self._xp.log(self._xp.diag(self._xp.linalg.cholesky(self.curvature_reg_matrix_reduced)))
self._xp.log(
self._xp.diag(
self._xp.linalg.cholesky(self.curvature_reg_matrix_reduced)
)
)
)

@property
Expand All @@ -675,7 +685,11 @@ def log_det_regularization_matrix_term(self) -> float:
return 0.0

return 2.0 * self._xp.sum(
self._xp.log(self._xp.diag(self._xp.linalg.cholesky(self.regularization_matrix_reduced)))
self._xp.log(
self._xp.diag(
self._xp.linalg.cholesky(self.regularization_matrix_reduced)
)
)
)

@property
Expand Down Expand Up @@ -738,7 +752,9 @@ def regularization_weights_from(self, index: int) -> np.ndarray:

return np.zeros((pixels,))

return regularization.regularization_weights_from(linear_obj=linear_obj, xp=self._xp)
return regularization.regularization_weights_from(
linear_obj=linear_obj, xp=self._xp
)

@property
def regularization_weights_mapper_dict(self) -> Dict[LinearObj, np.ndarray]:
Expand Down
21 changes: 9 additions & 12 deletions autoarray/inversion/inversion/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def inversion_from(
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
xp=np
xp=np,
):
"""
Factory which given an input dataset and list of linear objects, creates an `Inversion`.
Expand Down Expand Up @@ -60,14 +60,11 @@ def inversion_from(
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp
xp=xp,
)

return inversion_interferometer_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
xp=xp
dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp
)


Expand All @@ -76,7 +73,7 @@ def inversion_imaging_from(
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
xp=np
xp=np,
):
"""
Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`.
Expand Down Expand Up @@ -129,23 +126,23 @@ def inversion_imaging_from(
w_tilde=w_tilde,
linear_obj_list=linear_obj_list,
settings=settings,
xp=xp
xp=xp,
)

return InversionImagingMapping(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp
xp=xp,
)


def inversion_interferometer_from(
dataset: Union[Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
xp=np
xp=np,
):
"""
Factory which given an input `Interferometer` dataset and list of linear objects, creates
Expand Down Expand Up @@ -199,13 +196,13 @@ def inversion_interferometer_from(
w_tilde=w_tilde,
linear_obj_list=linear_obj_list,
settings=settings,
xp=xp
xp=xp,
)

else:
return InversionInterferometerMapping(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
xp=xp
xp=xp,
)
12 changes: 6 additions & 6 deletions autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
xp=np
xp=np,
):
"""
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
Expand Down Expand Up @@ -93,7 +93,9 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
return [
(
self.psf.convolved_mapping_matrix_from(
mapping_matrix=linear_obj.mapping_matrix, mask=self.mask, xp=self._xp
mapping_matrix=linear_obj.mapping_matrix,
mask=self.mask,
xp=self._xp,
)
if linear_obj.operated_mapping_matrix_override is None
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
Expand Down Expand Up @@ -137,7 +139,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
mapping_matrix=linear_func.mapping_matrix,
mask=self.mask,
xp=self._xp
xp=self._xp,
)

linear_func_operated_mapping_matrix_dict[linear_func] = (
Expand Down Expand Up @@ -217,9 +219,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict:

for mapper in self.cls_list_from(cls=AbstractMapper):
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
mapping_matrix=mapper.mapping_matrix,
mask=self.mask,
xp=self._xp
mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp
)

mapper_operated_mapping_matrix_dict[mapper] = operated_mapping_matrix
Expand Down
Loading
Loading