Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import jax.numpy as jnp
from jax.scipy.linalg import block_diag
import numpy as np

from typing import Dict, List, Optional, Type, Union
Expand Down Expand Up @@ -334,8 +335,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion
are regularized so high their value is forced to zero.
"""
from scipy.linalg import block_diag

return block_diag(
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
)
Expand Down Expand Up @@ -664,30 +663,17 @@ def log_det_regularization_matrix_term(self) -> float:
float
The log determinant of the regularization matrix.
"""
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import splu

if not self.has(cls=AbstractRegularization):
return 0.0

try:
lu = splu(csc_matrix(self.regularization_matrix_reduced))
diagL = lu.L.diagonal()
diagU = lu.U.diagonal()
diagL = diagL.astype(np.complex128)
diagU = diagU.astype(np.complex128)

return np.real(np.log(diagL).sum() + np.log(diagU).sum())

except RuntimeError:
try:
return 2.0 * np.sum(
np.log(
np.diag(np.linalg.cholesky(self.regularization_matrix_reduced))
)
return 2.0 * np.sum(
jnp.log(
jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced))
)
except np.linalg.LinAlgError as e:
raise exc.InversionException() from e
)
except np.linalg.LinAlgError as e:
raise exc.InversionException() from e

@property
def reconstruction_noise_map_with_covariance(self) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion autoarray/inversion/pixelization/mappers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray:
pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index,
pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index,
slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim,
adapt_data=np.array(self.adapt_data),
adapt_data=self.adapt_data.array,
)

def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]:
Expand Down
147 changes: 81 additions & 66 deletions autoarray/inversion/pixelization/mesh/mesh_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from autoarray import numba_util


@numba_util.jit()
def rectangular_neighbors_from(
shape_native: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -68,7 +67,6 @@ def rectangular_neighbors_from(
return neighbors, neighbors_sizes


@numba_util.jit()
def rectangular_corner_neighbors(
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -113,7 +111,6 @@ def rectangular_corner_neighbors(
return neighbors, neighbors_sizes


@numba_util.jit()
def rectangular_top_edge_neighbors(
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -136,17 +133,20 @@ def rectangular_top_edge_neighbors(
-------
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
"""
for pix in range(1, shape_native[1] - 1):
pixel_index = pix
neighbors[pixel_index, 0:3] = np.array(
[pixel_index - 1, pixel_index + 1, pixel_index + shape_native[1]]
)
neighbors_sizes[pixel_index] = 3
"""
Vectorized version of the top edge neighbor update using NumPy arithmetic.
"""
# Pixels along the top edge, excluding corners
top_edge_pixels = np.arange(1, shape_native[1] - 1)

neighbors[top_edge_pixels, 0] = top_edge_pixels - 1
neighbors[top_edge_pixels, 1] = top_edge_pixels + 1
neighbors[top_edge_pixels, 2] = top_edge_pixels + shape_native[1]
neighbors_sizes[top_edge_pixels] = 3

return neighbors, neighbors_sizes


@numba_util.jit()
def rectangular_left_edge_neighbors(
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -169,21 +169,20 @@ def rectangular_left_edge_neighbors(
-------
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
"""
for pix in range(1, shape_native[0] - 1):
pixel_index = pix * shape_native[1]
neighbors[pixel_index, 0:3] = np.array(
[
pixel_index - shape_native[1],
pixel_index + 1,
pixel_index + shape_native[1],
]
)
neighbors_sizes[pixel_index] = 3
# Row indices (excluding top and bottom corners)
rows = np.arange(1, shape_native[0] - 1)

# Convert to flat pixel indices for the left edge (first column)
pixel_indices = rows * shape_native[1]

neighbors[pixel_indices, 0] = pixel_indices - shape_native[1]
neighbors[pixel_indices, 1] = pixel_indices + 1
neighbors[pixel_indices, 2] = pixel_indices + shape_native[1]
neighbors_sizes[pixel_indices] = 3

return neighbors, neighbors_sizes


@numba_util.jit()
def rectangular_right_edge_neighbors(
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -206,21 +205,20 @@ def rectangular_right_edge_neighbors(
-------
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
"""
for pix in range(1, shape_native[0] - 1):
pixel_index = pix * shape_native[1] + shape_native[1] - 1
neighbors[pixel_index, 0:3] = np.array(
[
pixel_index - shape_native[1],
pixel_index - 1,
pixel_index + shape_native[1],
]
)
neighbors_sizes[pixel_index] = 3
# Rows excluding the top and bottom corners
rows = np.arange(1, shape_native[0] - 1)

# Flat indices for the right edge pixels
pixel_indices = rows * shape_native[1] + shape_native[1] - 1

neighbors[pixel_indices, 0] = pixel_indices - shape_native[1]
neighbors[pixel_indices, 1] = pixel_indices - 1
neighbors[pixel_indices, 2] = pixel_indices + shape_native[1]
neighbors_sizes[pixel_indices] = 3

return neighbors, neighbors_sizes


@numba_util.jit()
def rectangular_bottom_edge_neighbors(
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -243,19 +241,21 @@ def rectangular_bottom_edge_neighbors(
-------
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
"""
pixels = int(shape_native[0] * shape_native[1])
n_rows, n_cols = shape_native
pixels = n_rows * n_cols

for pix in range(1, shape_native[1] - 1):
pixel_index = pixels - pix - 1
neighbors[pixel_index, 0:3] = np.array(
[pixel_index - shape_native[1], pixel_index - 1, pixel_index + 1]
)
neighbors_sizes[pixel_index] = 3
# Horizontal pixel positions along bottom row, excluding corners
cols = np.arange(1, n_cols - 1)
pixel_indices = pixels - cols - 1 # Reverse order from right to left

neighbors[pixel_indices, 0] = pixel_indices - n_cols
neighbors[pixel_indices, 1] = pixel_indices - 1
neighbors[pixel_indices, 2] = pixel_indices + 1
neighbors_sizes[pixel_indices] = 3

return neighbors, neighbors_sizes


@numba_util.jit()
def rectangular_central_neighbors(
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -279,46 +279,61 @@ def rectangular_central_neighbors(
-------
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
"""
for x in range(1, shape_native[0] - 1):
for y in range(1, shape_native[1] - 1):
pixel_index = x * shape_native[1] + y
neighbors[pixel_index, 0:4] = np.array(
[
pixel_index - shape_native[1],
pixel_index - 1,
pixel_index + 1,
pixel_index + shape_native[1],
]
)
neighbors_sizes[pixel_index] = 4
n_rows, n_cols = shape_native

# Grid coordinates excluding edges
xs = np.arange(1, n_rows - 1)
ys = np.arange(1, n_cols - 1)

# 2D grid of central pixel indices
grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")
pixel_indices = grid_x * n_cols + grid_y
pixel_indices = pixel_indices.ravel()

# Compute neighbor indices
neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up
neighbors[pixel_indices, 1] = pixel_indices - 1 # Left
neighbors[pixel_indices, 2] = pixel_indices + 1 # Right
neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down

neighbors_sizes[pixel_indices] = 4

return neighbors, neighbors_sizes


def rectangular_edge_pixel_list_from(neighbors: np.ndarray) -> List:
def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]:
"""
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization.

This is computed by searching the `neighbors` array for pixels that have a neighbor with index -1, meaning there
is at least one neighbor from the 4 expected missing.
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization,
based on its 2D shape.

Parameters
----------
neighbors
An array of dimensions [total_pixels, 4] which provides the index of all neighbors of every pixel in the
rectangular pixelization (entries of -1 correspond to no neighbor).
shape_native
The (rows, cols) shape of the rectangular 2D pixel grid.

Returns
-------
A list of the 1D indices of all pixels on the edge of a rectangular pixelization.
A list of the 1D indices of all edge pixels.
"""
edge_pixel_list = []
rows, cols = shape_native

# Top row
top = np.arange(0, cols)

# Bottom row
bottom = np.arange((rows - 1) * cols, rows * cols)

# Left column (excluding corners)
left = np.arange(1, rows - 1) * cols

# Right column (excluding corners)
right = (np.arange(1, rows - 1) + 1) * cols - 1

for i, neighbors in enumerate(neighbors):
if -1 in neighbors:
edge_pixel_list.append(i)
# Concatenate all edge indices
edge_pixel_indices = np.concatenate([top, left, right, bottom])

return edge_pixel_list
# Sort and return
return np.sort(edge_pixel_indices).tolist()


@numba_util.jit()
Expand Down
Loading
Loading