Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 29 additions & 110 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import jax
import jax.numpy as jnp
import numpy as np
from scipy.linalg import block_diag
Expand Down Expand Up @@ -73,17 +74,6 @@ def __init__(
A dictionary which contains timing of certain functions calls which is used for profiling.
"""

# try:
# import numba
# except ModuleNotFoundError:
# raise exc.InversionException(
# "Inversion functionality (linear light profiles, pixelized reconstructions) is "
# "disabled if numba is not installed.\n\n"
# "This is because the run-times without numba are too slow.\n\n"
# "Please install numba, which is described at the following web page:\n\n"
# "https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
# )

self.dataset = dataset

self.linear_obj_list = linear_obj_list
Expand Down Expand Up @@ -317,7 +307,7 @@ def operated_mapping_matrix(self) -> np.ndarray:
If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous
linear equations are solved simultaneously.
"""
return np.hstack(self.operated_mapping_matrix_list)
return jnp.hstack(self.operated_mapping_matrix_list)

@cached_property
@profile_func
Expand Down Expand Up @@ -474,46 +464,50 @@ def reconstruction(self) -> np.ndarray:
And the data_vector = ZTx, so the corresponding row is also taken out.
"""

if self.settings.force_edge_pixels_to_zeros:
if self.settings.force_edge_image_pixels_to_zeros:
ids_zeros = np.unique(
np.append(
self.mapper_edge_pixel_list, self.mapper_zero_pixel_list
)
)
else:
ids_zeros = self.mapper_edge_pixel_list
if (
self.has(cls=AbstractMapper)
and self.settings.force_edge_pixels_to_zeros
):

values_to_solve = np.ones(
np.shape(self.curvature_reg_matrix)[0], dtype=bool
ids_zeros = jnp.array(self.mapper_edge_pixel_list, dtype=int)

values_to_solve = jnp.ones(
self.curvature_reg_matrix.shape[0], dtype=bool
)
values_to_solve[ids_zeros] = False
values_to_solve = values_to_solve.at[ids_zeros].set(False)

data_vector_input = self.data_vector[values_to_solve]

curvature_reg_matrix_input = self.curvature_reg_matrix[
values_to_solve, :
][:, values_to_solve]

solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0])

solutions[values_to_solve] = (
inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
)
# Get the values to assign (must be a JAX array)
reconstruction = inversion_util.reconstruction_positive_only_from(
data_vector=data_vector_input,
curvature_reg_matrix=curvature_reg_matrix_input,
settings=self.settings,
)

# Allocate JAX array
solutions = jnp.zeros(self.curvature_reg_matrix.shape[0])

# Get indices where True
indices = jnp.where(values_to_solve)[0]

# Set reconstruction values at those indices
solutions = solutions.at[indices].set(reconstruction)

return solutions

else:
solutions = inversion_util.reconstruction_positive_only_from(

return inversion_util.reconstruction_positive_only_from(
data_vector=self.data_vector,
curvature_reg_matrix=self.curvature_reg_matrix,
settings=self.settings,
)

return solutions

mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)

return inversion_util.reconstruction_positive_negative_from(
Expand All @@ -522,81 +516,6 @@ def reconstruction(self) -> np.ndarray:
mapper_param_range_list=mapper_param_range_list,
)

# @cached_property
# @profile_func
# def reconstruction(self) -> np.ndarray:
# """
# Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12)
# of https://arxiv.org/pdf/astro-ph/0302587.pdf (Positive-Negative solution)
#
# ============================================================================================
#
# Solve the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf (Non-negative solution)
# Find non-negative solution that minimizes |Z * S - x|^2.
#
# We use fnnls (https://github.com/jvendrow/fnnls) to optimize the quadratic value. Two commonly used
# variables in the code are defined as follows:
# ZTZ := np.dot(Z.T, Z)
# ZTx := np.dot(Z.T, x)
# """
# if self.settings.use_positive_only_solver:
# """
# For the new implementation, we now need to take out the cols and rows of
# the curvature_reg_matrix that corresponds to the parameters we force to be 0.
# Similar for the data vector.
#
# What we actually doing is that we have set the correspoding cols of the Z to be 0.
# As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
# And the data_vector = ZTx, so the corresponding row is also taken out.
# """
#
# if self.settings.force_edge_pixels_to_zeros:
# if self.settings.force_edge_image_pixels_to_zeros:
# ids_zeros = np.unique(
# np.append(
# self.mapper_edge_pixel_list, self.mapper_zero_pixel_list
# )
# )
# else:
# ids_zeros = self.mapper_edge_pixel_list
#
# values_to_solve = np.ones(
# np.shape(self.curvature_reg_matrix)[0], dtype=bool
# )
# values_to_solve[ids_zeros] = False
#
# data_vector_input = self.data_vector[values_to_solve]
#
# curvature_reg_matrix_input = self.curvature_reg_matrix[
# values_to_solve, :
# ][:, values_to_solve]
#
# solutions = inversion_util.reconstruction_positive_only_from(
# data_vector=data_vector_input,
# curvature_reg_matrix=curvature_reg_matrix_input,
# settings=self.settings,
# )
#
# mask = values_to_solve.astype(bool)
#
# return solutions[mask]
# else:
# solutions = inversion_util.reconstruction_positive_only_from(
# data_vector=self.data_vector,
# curvature_reg_matrix=self.curvature_reg_matrix,
# settings=self.settings,
# )
#
# return solutions
#
# mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
#
# return inversion_util.reconstruction_positive_negative_from(
# data_vector=self.data_vector,
# curvature_reg_matrix=self.curvature_reg_matrix,
# mapper_param_range_list=mapper_param_range_list,
# )

@cached_property
@profile_func
def reconstruction_reduced(self) -> np.ndarray:
Expand Down
1 change: 0 additions & 1 deletion autoarray/inversion/inversion/imaging/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def data_vector(self) -> np.ndarray:

The calculation is described in more detail in `inversion_util.data_vector_via_blurred_mapping_matrix_from`.
"""

return inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from(
blurred_mapping_matrix=self.operated_mapping_matrix,
image=self.data.array,
Expand Down
2 changes: 1 addition & 1 deletion autoarray/inversion/inversion/imaging/w_tilde.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
data_weights=mapper.unique_mappings.data_weights,
pix_lengths=mapper.unique_mappings.pix_lengths,
pix_pixels=mapper.params,
curvature_weights=curvature_weights,
curvature_weights=np.array(curvature_weights),
image_frame_1d_lengths=self.convolver.image_frame_1d_lengths,
image_frame_1d_indexes=self.convolver.image_frame_1d_indexes,
image_frame_1d_kernels=self.convolver.image_frame_1d_kernels,
Expand Down
8 changes: 5 additions & 3 deletions autoarray/inversion/inversion/interferometer/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def data_vector(self) -> np.ndarray:
"""

return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from(
transformed_mapping_matrix=self.operated_mapping_matrix,
transformed_mapping_matrix=np.array(self.operated_mapping_matrix),
visibilities=np.array(self.data),
noise_map=np.array(self.noise_map),
)
Expand Down Expand Up @@ -152,8 +152,10 @@ def mapped_reconstructed_data_dict(

visibilities = (
inversion_interferometer_util.mapped_reconstructed_visibilities_from(
transformed_mapping_matrix=operated_mapping_matrix_list[index],
reconstruction=reconstruction,
transformed_mapping_matrix=np.array(
operated_mapping_matrix_list[index]
),
reconstruction=np.array(reconstruction),
)
)

Expand Down
Loading
Loading