Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer
from autoarray.dataset.grids import GridsDataset
from autoarray.operators.transformer import TransformerNUFFT

from autoarray.mask.mask_2d import Mask2D
from autoarray.structures.visibilities import Visibilities
from autoarray.structures.visibilities import VisibilitiesNoiseMap

Expand All @@ -25,8 +25,9 @@ def __init__(
data: Visibilities,
noise_map: VisibilitiesNoiseMap,
uv_wavelengths: np.ndarray,
real_space_mask,
real_space_mask: Mask2D,
transformer_class=TransformerNUFFT,
dft_preload_transform: bool = True,
preprocessing_directory=None,
):
"""
Expand Down Expand Up @@ -73,6 +74,9 @@ def __init__(
transformer_class
The class of the Fourier Transform which maps images from real space to Fourier space visibilities and
the uv-plane.
dft_preload_transform
If True, precomputes and stores the cosine and sine terms for the Fourier transform.
This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
"""
self.real_space_mask = real_space_mask

Expand All @@ -86,7 +90,9 @@ def __init__(
self.uv_wavelengths = uv_wavelengths

self.transformer = transformer_class(
uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask
uv_wavelengths=uv_wavelengths,
real_space_mask=real_space_mask,
preload_transform=dft_preload_transform,
)

self.preprocessing_directory = (
Expand Down Expand Up @@ -114,6 +120,7 @@ def from_fits(
noise_map_hdu=0,
uv_wavelengths_hdu=0,
transformer_class=TransformerNUFFT,
dft_preload_transform: bool = True,
):
"""
Factory for loading the interferometer data_type from .fits files, as well as computing properties like the
Expand All @@ -139,6 +146,7 @@ def from_fits(
noise_map=noise_map,
uv_wavelengths=uv_wavelengths,
transformer_class=transformer_class,
dft_preload_transform=dft_preload_transform,
)

def w_tilde_preprocessing(self):
Expand Down
2 changes: 1 addition & 1 deletion autoarray/fit/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def chi_squared(self) -> float:
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
"""
return fit_util.chi_squared_complex_from(
chi_squared_map=self.chi_squared_map,
chi_squared_map=self.chi_squared_map.array,
)

@property
Expand Down
4 changes: 2 additions & 2 deletions autoarray/fit/fit_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def chi_squared_complex_from(*, chi_squared_map: jnp.ndarray) -> float:
chi_squared_map
The chi-squared-map of values of the model-data fit to the dataset.
"""
chi_squared_real = jnp.sum(np.array(chi_squared_map.real))
chi_squared_imag = jnp.sum(np.array(chi_squared_map.imag))
chi_squared_real = jnp.sum(chi_squared_map.real)
chi_squared_imag = jnp.sum(chi_squared_map.imag)
return chi_squared_real + chi_squared_imag


Expand Down
2 changes: 1 addition & 1 deletion autoarray/fit/plot/fit_interferometer_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def figures_2d(
auto_labels=AutoLabels(
title="Model Visibilities", filename="model_data"
),
color_array=np.real(self.fit.model_data),
color_array=np.real(self.fit.model_data.array),
)

if residual_map_real:
Expand Down
2 changes: 1 addition & 1 deletion autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def mapping_matrix(self) -> np.ndarray:
If there are multiple linear objects, the mapping matrices are stacked such that their simultaneous linear
equations are solved simultaneously. This property returns the stacked mapping matrix.
"""
return np.hstack(
return jnp.hstack(
[linear_obj.mapping_matrix for linear_obj in self.linear_obj_list]
)

Expand Down
6 changes: 3 additions & 3 deletions autoarray/inversion/inversion/dataset_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(
noise_map
An array describing the RMS standard deviation error in each pixel used for computing quantities like the
chi-squared in a fit (in PyAutoGalaxy and PyAutoLens the recommended units are electrons per second).
grids
The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting
light profiles and calculations associated with a pixelization.
over_sampler
Performs over-sampling whereby the masked image pixels are split into sub-pixels, which are all
mapped via the mapper with sub-fractional values of flux.
Expand All @@ -50,9 +53,6 @@ def __init__(
w_tilde
The w_tilde matrix used by the w-tilde formalism to construct the data vector and
curvature matrix during an inversion efficiently..
grids
The grids of (y,x) Cartesian coordinates that the image data is paired with, which are used for evaluting
light profiles and calculations associated with a pixelization.
noise_covariance_matrix
A noise-map covariance matrix representing the covariance between noise in every `data` value, which
can be used via a bespoke fit to account for correlated noise in the data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index):
return w_tilde_via_preload


@numba_util.jit()
def data_vector_via_transformed_mapping_matrix_from(
transformed_mapping_matrix: np.ndarray,
visibilities: np.ndarray,
Expand All @@ -406,31 +405,24 @@ def data_vector_via_transformed_mapping_matrix_from(
noise_map
Flattened 1D array of the noise-map used by the inversion during the fit.
"""
# Extract components
vis_real = visibilities.real
vis_imag = visibilities.imag
f_real = transformed_mapping_matrix.real
f_imag = transformed_mapping_matrix.imag
noise_real = noise_map.real
noise_imag = noise_map.imag

data_vector = np.zeros(transformed_mapping_matrix.shape[1])

visibilities_real = visibilities.real
visibilities_imag = visibilities.imag
transformed_mapping_matrix_real = transformed_mapping_matrix.real
transformed_mapping_matrix_imag = transformed_mapping_matrix.imag
noise_map_real = noise_map.real
noise_map_imag = noise_map.imag

for vis_1d_index in range(transformed_mapping_matrix.shape[0]):
for pix_1d_index in range(transformed_mapping_matrix.shape[1]):
real_value = (
visibilities_real[vis_1d_index]
* transformed_mapping_matrix_real[vis_1d_index, pix_1d_index]
/ (noise_map_real[vis_1d_index] ** 2.0)
)
imag_value = (
visibilities_imag[vis_1d_index]
* transformed_mapping_matrix_imag[vis_1d_index, pix_1d_index]
/ (noise_map_imag[vis_1d_index] ** 2.0)
)
data_vector[pix_1d_index] += real_value + imag_value
# Square noise components
inv_var_real = 1.0 / (noise_real**2)
inv_var_imag = 1.0 / (noise_imag**2)

return data_vector
# Real and imaginary contributions
weighted_real = (vis_real * inv_var_real)[:, None] * f_real
weighted_imag = (vis_imag * inv_var_imag)[:, None] * f_imag

# Sum over visibilities
return np.sum(weighted_real + weighted_imag, axis=0)


@numba_util.jit()
Expand Down Expand Up @@ -512,7 +504,6 @@ def curvature_matrix_via_w_tilde_curvature_preload_interferometer_from(
return curvature_matrix


@numba_util.jit()
def mapped_reconstructed_visibilities_from(
transformed_mapping_matrix: np.ndarray, reconstruction: np.ndarray
) -> np.ndarray:
Expand All @@ -525,20 +516,7 @@ def mapped_reconstructed_visibilities_from(
The matrix representing the blurred mappings between sub-grid pixels and pixelization pixels.

"""
mapped_reconstructed_visibilities = (0.0 + 0.0j) * np.zeros(
transformed_mapping_matrix.shape[0]
)

transformed_mapping_matrix_real = transformed_mapping_matrix.real
transformed_mapping_matrix_imag = transformed_mapping_matrix.imag

for i in range(transformed_mapping_matrix.shape[0]):
for j in range(reconstruction.shape[0]):
mapped_reconstructed_visibilities[i] += (
reconstruction[j] * transformed_mapping_matrix_real[i, j]
) + 1.0j * (reconstruction[j] * transformed_mapping_matrix_imag[i, j])

return mapped_reconstructed_visibilities
return transformed_mapping_matrix @ reconstruction


"""
Expand Down
15 changes: 7 additions & 8 deletions autoarray/inversion/inversion/interferometer/mapping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np
from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -76,8 +77,8 @@ def data_vector(self) -> np.ndarray:
"""

return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from(
transformed_mapping_matrix=np.array(self.operated_mapping_matrix),
visibilities=np.array(self.data),
transformed_mapping_matrix=self.operated_mapping_matrix,
visibilities=self.data,
noise_map=np.array(self.noise_map),
)

Expand Down Expand Up @@ -106,13 +107,13 @@ def curvature_matrix(self) -> np.ndarray:
noise_map=self.noise_map.imag,
)

curvature_matrix = np.add(real_curvature_matrix, imag_curvature_matrix)
curvature_matrix = jnp.add(real_curvature_matrix, imag_curvature_matrix)

if len(self.no_regularization_index_list) > 0:
curvature_matrix = inversion_util.curvature_matrix_with_added_to_diag_from(
curvature_matrix=curvature_matrix,
no_regularization_index_list=self.no_regularization_index_list,
value=self.settings.no_regularization_add_to_curvature_diag_value,
no_regularization_index_list=self.no_regularization_index_list,
)

return curvature_matrix
Expand Down Expand Up @@ -152,10 +153,8 @@ def mapped_reconstructed_data_dict(

visibilities = (
inversion_interferometer_util.mapped_reconstructed_visibilities_from(
transformed_mapping_matrix=np.array(
operated_mapping_matrix_list[index]
),
reconstruction=np.array(reconstruction),
transformed_mapping_matrix=operated_mapping_matrix_list[index],
reconstruction=reconstruction,
)
)

Expand Down
4 changes: 1 addition & 3 deletions autoarray/inversion/inversion/interferometer/w_tilde.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def data_vector(self) -> np.ndarray:

The calculation is described in more detail in `inversion_util.w_tilde_data_interferometer_from`.
"""
return np.dot(
self.linear_obj_list[0].mapping_matrix.T, self.w_tilde.dirty_image
)
return np.dot(self.mapping_matrix.T, self.w_tilde.dirty_image)

@cached_property
@profile_func
Expand Down
Loading
Loading