Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2736e33
remove cached properties from BorderRelocator for easier JAx
Jun 25, 2025
cd50e74
removed for loop from border function
Jun 25, 2025
ea252a2
converted relocated_grid_from to use JAX
Jun 25, 2025
b7cb3c0
border relocator function converred to JAX
Jun 25, 2025
dec6b51
grid_2d_slim_via_shape_native_not_mask_from
Jun 25, 2025
629e7e7
Rectangular fidxes
Jun 25, 2025
0766edd
fix w tild eiwth some ndarray conversions
Jun 26, 2025
3d0233f
fix over sampling tests
Jun 26, 2025
9673b4c
fix plotting
Jun 26, 2025
7e3509c
fix some more tests due to numba jax
Jun 26, 2025
b35e32d
coment out test to get past it for now, think its just linear lagebra…
Jun 26, 2025
74bfcda
updated _Reducedd matrices to use zeroing
Jun 26, 2025
1d6d517
regularization_matrix_Reduced
Jun 26, 2025
e4219b0
fix test
Jun 26, 2025
1b5b64f
add preloading in order to pass mapper indexes
Jun 26, 2025
93157b8
full JAX success
Jun 26, 2025
0b424c4
adaptive_pixel_signals_from JAX-d
Jun 26, 2025
a92d828
convert mapped_to_source_via_mapping_matrix_from to numpy
Jun 26, 2025
5677589
update data_weight_total_for_pix_from
Jun 26, 2025
3dc49b0
moved sub_slim_indexes_for_pix_index to inversion_interferometer_util
Jun 26, 2025
101b704
convert soem regularization util functions from numba to numpy
Jul 13, 2025
2896a0c
remove minus ones in data_weight_total_for_pix_from
Jul 13, 2025
6ebd7a3
mapper index list returns ndarray
Jul 13, 2025
e758555
fix autoarray/inversion/inversion/interferometer/w_tilde.py
Jul 13, 2025
9aa26b5
fixed bug where regularization matrix was returned for curvature_reg_…
Jul 13, 2025
968a994
fix case where border relocator is off
Jul 13, 2025
8af880d
w tilde now default to false
Jul 13, 2025
bbef0e3
remove old source pixel zeroing functionality
Jul 13, 2025
415926f
docuemnet preloasds and mapper_index_list -> mapper_indices
Jul 13, 2025
12b2d72
fix last unit test
Jul 13, 2025
67a5830
black
Jul 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from .operators.contour import Grid2DContour
from .layout.layout import Layout1D
from .layout.layout import Layout2D
from .preloads import Preloads
from .structures.arrays.uniform_1d import Array1D
from .structures.arrays.uniform_2d import Array2D
from .structures.arrays.rgb import Array2DRGB
Expand Down
187 changes: 75 additions & 112 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
from autoarray.inversion.regularization.abstract import AbstractRegularization
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.visibilities import Visibilities

Expand All @@ -27,6 +28,7 @@ def __init__(
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
):
"""
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
Expand Down Expand Up @@ -66,23 +68,14 @@ def __init__(
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
"""

try:
import numba
except ModuleNotFoundError:
raise exc.InversionException(
"Inversion functionality (linear light profiles, pixelized reconstructions) is "
"disabled if numba is not installed.\n\n"
"This is because the run-times without numba are too slow.\n\n"
"Please install numba, which is described at the following web page:\n\n"
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
)

self.dataset = dataset

self.linear_obj_list = linear_obj_list

self.settings = settings

self.preloads = preloads or Preloads()

@property
def data(self):
return self.dataset.data
Expand Down Expand Up @@ -156,17 +149,9 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]:
-------
A list of the index range of the parameters of each linear object in the inversion of the input cls type.
"""
index_list = []

pixel_count = 0

for linear_obj in self.linear_obj_list:
if isinstance(linear_obj, cls):
index_list.append([pixel_count, pixel_count + linear_obj.params])

pixel_count += linear_obj.params

return index_list
return inversion_util.param_range_list_from(
cls=cls, linear_obj_list=self.linear_obj_list
)

def cls_list_from(self, cls: Type, cls_filtered: Optional[Type] = None) -> List:
"""
Expand Down Expand Up @@ -267,6 +252,22 @@ def no_regularization_index_list(self) -> List[int]:

return no_regularization_index_list

@property
def mapper_indices(self) -> np.ndarray:

if self.preloads.mapper_indices is not None:
return self.preloads.mapper_indices

mapper_indices = []

param_range_list = self.param_range_list_from(cls=AbstractMapper)

for param_range in param_range_list:

mapper_indices += range(param_range[0], param_range[1])

return np.array(mapper_indices)

@property
def mask(self) -> Array2D:
return self.data.mask
Expand Down Expand Up @@ -354,19 +355,14 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]:
regularization it is bypassed.
"""

regularization_matrix = self.regularization_matrix

if self.all_linear_obj_have_regularization:
return regularization_matrix
return self.regularization_matrix

regularization_matrix = np.delete(
regularization_matrix, self.no_regularization_index_list, 0
)
regularization_matrix = np.delete(
regularization_matrix, self.no_regularization_index_list, 1
)
# ids of values which are on edge so zero-d and not solved for.
ids_to_keep = self.mapper_indices

return regularization_matrix
# Zero rows and columns in the matrix we want to ignore
return self.regularization_matrix[ids_to_keep][:, ids_to_keep]

@cached_property
def curvature_reg_matrix(self) -> np.ndarray:
Expand All @@ -381,55 +377,31 @@ def curvature_reg_matrix(self) -> np.ndarray:
if not self.has(cls=AbstractRegularization):
return self.curvature_matrix

if len(self.regularization_list) == 1:
curvature_matrix = self.curvature_matrix
curvature_matrix += self.regularization_matrix

del self.__dict__["curvature_matrix"]

return curvature_matrix

return np.add(self.curvature_matrix, self.regularization_matrix)
return jnp.add(self.curvature_matrix, self.regularization_matrix)

@cached_property
def curvature_reg_matrix_reduced(self) -> np.ndarray:
def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
"""
The linear system of equations solves for F + regularization_coefficient*H, which is computed below.
The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the
linear algebra system we solve for using D and F above and is given by
equation (12) in https://arxiv.org/pdf/astro-ph/0302587.pdf.

This is the curvature reg matrix for only the mappers, which is necessary for computing the log det
term without the linear light profiles included.
A complete description of regularization is given in the `regularization.py` and `regularization_util.py`
modules.

For multiple mappers, the regularization matrix is computed as the block diagonal of each individual mapper.
The scipy function `block_diag` has an overhead associated with it and if there is only one mapper and
regularization it is bypassed.
"""

if self.all_linear_obj_have_regularization:
return self.curvature_reg_matrix

curvature_reg_matrix = self.curvature_reg_matrix
# ids of values which are on edge so zero-d and not solved for.
ids_to_keep = self.mapper_indices

curvature_reg_matrix = np.delete(
curvature_reg_matrix, self.no_regularization_index_list, 0
)
curvature_reg_matrix = np.delete(
curvature_reg_matrix, self.no_regularization_index_list, 1
)

return curvature_reg_matrix

@property
def mapper_zero_pixel_list(self) -> np.ndarray:
mapper_zero_pixel_list = []
param_range_list = self.param_range_list_from(cls=LinearObj)
for param_range, linear_obj in zip(param_range_list, self.linear_obj_list):
if isinstance(linear_obj, AbstractMapper):
mapping_matrix_for_image_pixels_source_zero = linear_obj.mapping_matrix[
self.settings.image_pixels_source_zero
]
source_pixels_zero = (
np.sum(mapping_matrix_for_image_pixels_source_zero != 0, axis=0)
!= 0
)
mapper_zero_pixel_list.append(
np.where(source_pixels_zero == True)[0] + param_range[0]
)
return mapper_zero_pixel_list
# Zero rows and columns in the matrix we want to ignore
return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]

@cached_property
def reconstruction(self) -> np.ndarray:
Expand All @@ -448,51 +420,36 @@ def reconstruction(self) -> np.ndarray:
ZTx := np.dot(Z.T, x)
"""
if self.settings.use_positive_only_solver:
"""
For the new implementation, we now need to take out the cols and rows of
the curvature_reg_matrix that corresponds to the parameters we force to be 0.
Similar for the data vector.

What we actually doing is that we have set the correspoding cols of the Z to be 0.
As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
And the data_vector = ZTx, so the corresponding row is also taken out.
"""
if self.preloads.source_pixel_zeroed_indices is not None:

if (
self.has(cls=AbstractMapper)
and self.settings.force_edge_pixels_to_zeros
):
# ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep

ids_zeros = jnp.array(self.mapper_edge_pixel_list, dtype=int)
# Use advanced indexing to select rows/columns
data_vector = self.data_vector[ids_to_keep]
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][
:, ids_to_keep
]

values_to_solve = jnp.ones(
self.curvature_reg_matrix.shape[0], dtype=bool
# Perform reconstruction via fnnls
reconstruction_partial = (
inversion_util.reconstruction_positive_only_from(
data_vector=data_vector,
curvature_reg_matrix=curvature_reg_matrix,
settings=self.settings,
)
)
values_to_solve = values_to_solve.at[ids_zeros].set(False)

data_vector_input = self.data_vector[values_to_solve]

curvature_reg_matrix_input = self.curvature_reg_matrix[
values_to_solve, :
][:, values_to_solve]
# Allocate full solution array
reconstruction = jnp.zeros(self.data_vector.shape[0])

# Get the values to assign (must be a JAX array)
reconstruction = inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
# Scatter the partial solution back to the full shape
reconstruction = reconstruction.at[ids_to_keep].set(
reconstruction_partial
)

# Allocate JAX array
solutions = jnp.zeros(self.curvature_reg_matrix.shape[0])

# Get indices where True
indices = jnp.where(values_to_solve)[0]

# Set reconstruction values at those indices
solutions = solutions.at[indices].set(reconstruction)

return solutions
return reconstruction

else:

Expand Down Expand Up @@ -522,7 +479,11 @@ def reconstruction_reduced(self) -> np.ndarray:
if self.all_linear_obj_have_regularization:
return self.reconstruction

return np.delete(self.reconstruction, self.no_regularization_index_list, axis=0)
# ids of values which are on edge so zero-d and not solved for.
ids_to_keep = self.mapper_indices

# Zero rows and columns in the matrix we want to ignore
return self.reconstruction[ids_to_keep]

@property
def reconstruction_dict(self) -> Dict[LinearObj, np.ndarray]:
Expand Down Expand Up @@ -665,9 +626,9 @@ def regularization_term(self) -> float:
if not self.has(cls=AbstractRegularization):
return 0.0

return np.matmul(
return jnp.matmul(
self.reconstruction_reduced.T,
np.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
jnp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
)

@cached_property
Expand All @@ -682,7 +643,9 @@ def log_det_curvature_reg_matrix_term(self) -> float:

try:
return 2.0 * np.sum(
np.log(np.diag(np.linalg.cholesky(self.curvature_reg_matrix_reduced)))
jnp.log(
jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced))
)
)
except np.linalg.LinAlgError as e:
raise exc.InversionException() from e
Expand Down
5 changes: 5 additions & 0 deletions autoarray/inversion/inversion/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D


def inversion_from(
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
):
"""
Factory which given an input dataset and list of linear objects, creates an `Inversion`.
Expand Down Expand Up @@ -55,6 +57,7 @@ def inversion_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
)

return inversion_interferometer_from(
Expand All @@ -68,6 +71,7 @@ def inversion_imaging_from(
dataset,
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
):
"""
Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`.
Expand Down Expand Up @@ -126,6 +130,7 @@ def inversion_imaging_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
)


Expand Down
5 changes: 4 additions & 1 deletion autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from typing import Dict, List, Optional, Union, Type
from typing import Dict, List, Union, Type

from autoconf import cached_property

Expand All @@ -10,6 +10,7 @@
from autoarray.inversion.inversion.abstract import AbstractInversion
from autoarray.inversion.linear_obj.linear_obj import LinearObj
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.preloads import Preloads

from autoarray.inversion.inversion.imaging import inversion_imaging_util

Expand All @@ -20,6 +21,7 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
):
"""
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
)

@property
Expand Down
3 changes: 3 additions & 0 deletions autoarray/inversion/inversion/imaging/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from autoarray.inversion.linear_obj.linear_obj import LinearObj
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
from autoarray.inversion.inversion.settings import SettingsInversion
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D

from autoarray.inversion.inversion import inversion_util
Expand All @@ -21,6 +22,7 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: SettingsInversion = SettingsInversion(),
preloads: Preloads = None,
):
"""
Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations
Expand All @@ -46,6 +48,7 @@ def __init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
)

@property
Expand Down
Loading
Loading