Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
from .operators.contour import Grid2DContour
from .layout.layout import Layout1D
from .layout.layout import Layout2D
from .preloads import Preloads
from .structures.arrays.uniform_1d import Array1D
from .structures.arrays.uniform_2d import Array2D
from .structures.arrays.rgb import Array2DRGB
Expand Down
1 change: 1 addition & 0 deletions autoarray/config/general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ psf:
inversion:
check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same.
use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion.
use_edge_zeroed_pixels : true # If True, the edge pixels of a pixelization are set to zero, which prevents unphysical values in the reconstructed solution at the edge of the pixelization.
no_regularization_add_to_curvature_diag_value : 1.0e-3 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents inversion's failing due to the matrix being singular.
use_border_relocator: false # If True, by default a pixelization's border is used to relocate all pixels outside its border to the border.
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
Expand Down
4 changes: 2 additions & 2 deletions autoarray/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def make_delaunay_mapper_9_3x3():
pixel_scales=1.0,
)

mesh = aa.mesh.Delaunay()
mesh = aa.mesh.Delaunay(pixels=9)

interpolator = mesh.interpolator_from(
source_plane_data_grid=make_grid_2d_sub_2_7x7(),
Expand Down Expand Up @@ -443,7 +443,7 @@ def make_knn_mapper_9_3x3():
pixel_scales=1.0,
)

mesh = aa.mesh.KNearestNeighbor(split_neighbor_division=1)
mesh = aa.mesh.KNearestNeighbor(pixels=9, split_neighbor_division=1)

interpolator = mesh.interpolator_from(
source_plane_data_grid=make_grid_2d_sub_2_7x7(),
Expand Down
125 changes: 108 additions & 17 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from autoarray.inversion.mappers.abstract import Mapper
from autoarray.inversion.regularization.abstract import AbstractRegularization
from autoarray.settings import Settings
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
from autoarray.structures.visibilities import Visibilities
Expand All @@ -27,7 +26,6 @@ def __init__(
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand Down Expand Up @@ -74,8 +72,6 @@ def __init__(

self.settings = settings or Settings()

self.preloads = preloads or Preloads()

self.use_jax = xp is not np

@property
Expand Down Expand Up @@ -234,9 +230,6 @@ def no_regularization_index_list(self) -> List[int]:
@property
def mapper_indices(self) -> np.ndarray:

if self.preloads.mapper_indices is not None:
return self.preloads.mapper_indices

mapper_indices = []

param_range_list = self.param_range_list_from(cls=Mapper)
Expand Down Expand Up @@ -386,6 +379,107 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
# Zero rows and columns in the matrix we want to ignore
return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]

@cached_property
def zeroed_ids_to_keep(self):
"""
Return the **positive global indices** of linear parameters that should be
kept (solved for) in the inversion, accounting for **zeroed pixel indices**
from one or more mappers.

---------------------------------------------------------------------------
Parameter vector layout
---------------------------------------------------------------------------
This method assumes the full linear parameter vector is ordered as:

[ non-pixel linear objects ][ mapper_0 pixels ][ mapper_1 pixels ] ... [ mapper_M pixels ]

where:

- *Non-pixel linear objects* include quantities such as analytic light
profiles, regularization amplitudes, etc.
- Each mapper contributes a contiguous block of pixel-based linear parameters.
- The concatenated pixel blocks occupy the **final** entries of the parameter
vector, with total length:

total_pixels = sum(mapper.mesh.pixels for mapper in mappers)

---------------------------------------------------------------------------
Zeroed pixel convention
---------------------------------------------------------------------------
For each mapper:

- `mapper.mesh.zeroed_pixels` must be a 1D array of **positive, mesh-local**
pixel indices in the range `[0, mapper.mesh.pixels - 1]`.
- These indices identify pixels that should be **excluded** from the linear
solve (e.g. edge pixels, masked regions, or padding pixels).
- Indexing is defined purely within the mapper’s own pixelization (e.g.
row-major flattening for rectangular meshes).

This method converts all mesh-local zeroed pixel indices into **global
parameter indices**, correctly offsetting for:
- the presence of non-pixel linear objects at the start of the vector
- the cumulative pixel counts of preceding mappers

---------------------------------------------------------------------------
Backend and implementation details
---------------------------------------------------------------------------
- The implementation is backend-agnostic and supports both NumPy and JAX via
`self._xp`.
- The returned indices are **positive global indices**, suitable for advanced
indexing of:
- `self.data_vector`
- `self.curvature_reg_matrix`
- When using JAX, this method avoids backend-incompatible operations and
preserves JIT compatibility under the same constraints as the rest of the
inversion pipeline.

Returns
-------
array-like
A 1D array of **positive global indices**, sorted in ascending order,
corresponding to linear parameters that should be kept in the inversion.
"""

mapper_list = self.cls_list_from(cls=Mapper)

n_total = int(self.total_params)

pixels_per_mapper = [int(m.mesh.pixels) for m in mapper_list]
total_pixels = int(sum(pixels_per_mapper))

# Global start index of concatenated pixel block
pixel_start = n_total - total_pixels

# Total number of zeroed pixels across all mappers (Python int => static)
total_zeroed = int(sum(len(m.mesh.zeroed_pixels) for m in mapper_list))
n_keep = int(n_total - total_zeroed)

# Build global indices-to-zero across all mappers
zeros_global_list = []
offset = 0
for m, n_pix in zip(mapper_list, pixels_per_mapper):
zeros_local = self._xp.asarray(m.mesh.zeroed_pixels, dtype=self._xp.int32)
zeros_global_list.append(pixel_start + offset + zeros_local)
offset += n_pix

zeros_global = (
self._xp.concatenate(zeros_global_list)
if len(zeros_global_list) > 0
else self._xp.asarray([], dtype=self._xp.int32)
)

keep = self._xp.ones((n_total,), dtype=bool)

if self._xp is np:
keep[zeros_global] = False
keep_ids = self._xp.nonzero(keep)[0]

else:
keep = keep.at[zeros_global].set(False)
keep_ids = self._xp.nonzero(keep, size=n_keep)[0]

return keep_ids

@cached_property
def reconstruction(self) -> np.ndarray:
"""
Expand All @@ -405,16 +499,13 @@ def reconstruction(self) -> np.ndarray:

if self.settings.use_positive_only_solver:

if self.preloads.source_pixel_zeroed_indices is not None:

# ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep
if self.settings.use_edge_zeroed_pixels and self.has(cls=Mapper):

# Use advanced indexing to select rows/columns
data_vector = self.data_vector[ids_to_keep]
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][
:, ids_to_keep
]
data_vector = self.data_vector[self.zeroed_ids_to_keep]
curvature_reg_matrix = self.curvature_reg_matrix[
self.zeroed_ids_to_keep
][:, self.zeroed_ids_to_keep]

# Perform reconstruction via fnnls
reconstruction_partial = (
Expand All @@ -431,11 +522,11 @@ def reconstruction(self) -> np.ndarray:

# Scatter the partial solution back to the full shape
if self._xp.__name__.startswith("jax"):
reconstruction = reconstruction.at[ids_to_keep].set(
reconstruction = reconstruction.at[self.zeroed_ids_to_keep].set(
reconstruction_partial
)
else:
reconstruction[ids_to_keep] = reconstruction_partial
reconstruction[self.zeroed_ids_to_keep] = reconstruction_partial

return reconstruction

Expand Down
7 changes: 0 additions & 7 deletions autoarray/inversion/inversion/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@
InversionImagingSparse,
)
from autoarray.settings import Settings
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D


def inversion_from(
dataset: Union[Imaging, Interferometer, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand Down Expand Up @@ -68,7 +66,6 @@ def inversion_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

Expand All @@ -81,7 +78,6 @@ def inversion_imaging_from(
dataset,
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand Down Expand Up @@ -133,23 +129,20 @@ def inversion_imaging_from(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

return InversionImagingSparse(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

return InversionImagingMapping(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

Expand Down
3 changes: 0 additions & 3 deletions autoarray/inversion/inversion/imaging/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from autoarray.inversion.inversion.abstract import AbstractInversion
from autoarray.inversion.linear_obj.linear_obj import LinearObj
from autoarray.settings import Settings
from autoarray.preloads import Preloads

from autoarray.inversion.inversion.imaging import inversion_imaging_util

Expand All @@ -19,7 +18,6 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand Down Expand Up @@ -67,7 +65,6 @@ def __init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

Expand Down
3 changes: 0 additions & 3 deletions autoarray/inversion/inversion/imaging/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from autoarray.inversion.linear_obj.linear_obj import LinearObj
from autoarray.inversion.mappers.abstract import Mapper
from autoarray.settings import Settings
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D

from autoarray.inversion.inversion import inversion_util
Expand All @@ -22,7 +21,6 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand All @@ -49,7 +47,6 @@ def __init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

Expand Down
3 changes: 0 additions & 3 deletions autoarray/inversion/inversion/imaging/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from autoarray.settings import Settings
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
from autoarray.inversion.mappers.abstract import Mapper
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D

from autoarray.inversion.inversion.imaging import inversion_imaging_util
Expand All @@ -22,7 +21,6 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand All @@ -49,7 +47,6 @@ def __init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

Expand Down
3 changes: 0 additions & 3 deletions autoarray/inversion/inversion/imaging_numba/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from autoarray.settings import Settings
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
from autoarray.inversion.mappers.abstract import Mapper
from autoarray.preloads import Preloads
from autoarray.structures.arrays.uniform_2d import Array2D

from autoarray.inversion.inversion.imaging_numba import inversion_imaging_numba_util
Expand All @@ -22,7 +21,6 @@ def __init__(
dataset: Union[Imaging, DatasetInterface],
linear_obj_list: List[LinearObj],
settings: Settings = None,
preloads: Preloads = None,
xp=np,
):
"""
Expand All @@ -49,7 +47,6 @@ def __init__(
dataset=dataset,
linear_obj_list=linear_obj_list,
settings=settings,
preloads=preloads,
xp=xp,
)

Expand Down
7 changes: 5 additions & 2 deletions autoarray/inversion/mappers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
regularization: Optional[AbstractRegularization] = None,
settings: Settings = None,
image_plane_mesh_grid=None,
preloads=None,
xp=np,
):
"""
Expand Down Expand Up @@ -96,7 +95,7 @@ def __init__(
self.interpolator = interpolator

self.image_plane_mesh_grid = image_plane_mesh_grid
self.preloads = preloads

self.settings = settings or Settings()

@property
Expand All @@ -111,6 +110,10 @@ def pixels(self) -> int:
def mask(self):
return self.source_plane_data_grid.mask

@property
def mesh(self):
return self.interpolator.mesh

@property
def mesh_geometry(self):
return self.interpolator.mesh_geometry
Expand Down
Loading
Loading