Skip to content
Merged
99 changes: 50 additions & 49 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import logging
import numpy as np
from pathlib import Path
from typing import Optional

from autoconf import cached_property

from autoconf.fitsable import ndarray_via_fits_from, output_to_fits

from autoarray.dataset.abstract.dataset import AbstractDataset
from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer
from autoarray.dataset.grids import GridsDataset
from autoarray.operators.transformer import TransformerDFT
from autoarray.operators.transformer import TransformerNUFFT
from autoarray.mask.mask_2d import Mask2D
from autoarray.structures.visibilities import Visibilities
from autoarray.structures.visibilities import VisibilitiesNoiseMap

from autoarray.inversion.inversion.interferometer import inversion_interferometer_util

from autoarray import exc
from autoarray.inversion.inversion.interferometer import (
inversion_interferometer_util,
)

logger = logging.getLogger(__name__)

Expand All @@ -30,8 +29,8 @@ def __init__(
uv_wavelengths: np.ndarray,
real_space_mask: Mask2D,
transformer_class=TransformerNUFFT,
dft_preload_transform: bool = True,
w_tilde: Optional[WTildeInterferometer] = None,
raise_error_dft_visibilities_limit: bool = True,
):
"""
An interferometer dataset, containing the visibilities data, noise-map, real-space mask, Fourier transformer and
Expand Down Expand Up @@ -77,9 +76,6 @@ def __init__(
transformer_class
The class of the Fourier Transform which maps images from real space to Fourier space visibilities and
the uv-plane.
dft_preload_transform
If True, precomputes and stores the cosine and sine terms for the Fourier transform.
This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
"""
self.real_space_mask = real_space_mask

Expand All @@ -95,11 +91,8 @@ def __init__(
self.transformer = transformer_class(
uv_wavelengths=uv_wavelengths,
real_space_mask=real_space_mask,
preload_transform=dft_preload_transform,
)

self.dft_preload_transform = dft_preload_transform

use_w_tilde = True if w_tilde is not None else False

self.grids = GridsDataset(
Expand All @@ -111,6 +104,22 @@ def __init__(

self.w_tilde = w_tilde

if raise_error_dft_visibilities_limit:
if (
self.uv_wavelengths.shape[0] > 10000
and transformer_class == TransformerDFT
):
raise exc.DatasetException(
"""
Interferometer datasets with more than 10,000 visibilities should use the TransformerNUFFT class for
efficient Fourier transforms between real and uv-space. The DFT (Discrete Fourier Transform) is too slow for
large datasets.

If you are certain you want to use the TransformerDFT class, you can disable this error by passing
the input `raise_error_dft_visibilities_limit=False` when loading the Interferometer dataset.
"""
)

@classmethod
def from_fits(
cls,
Expand All @@ -122,7 +131,6 @@ def from_fits(
noise_map_hdu=0,
uv_wavelengths_hdu=0,
transformer_class=TransformerNUFFT,
dft_preload_transform: bool = True,
):
"""
Factory for loading the interferometer data_type from .fits files, as well as computing properties like the
Expand All @@ -148,10 +156,15 @@ def from_fits(
noise_map=noise_map,
uv_wavelengths=uv_wavelengths,
transformer_class=transformer_class,
dft_preload_transform=dft_preload_transform,
)

def apply_w_tilde(self):
def apply_w_tilde(
self,
curvature_preload=None,
batch_size: int = 128,
show_progress: bool = False,
show_memory: bool = False,
):
"""
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
given the `uv_wavelengths` (see `inversion.inversion_util`).
Expand All @@ -162,44 +175,33 @@ def apply_w_tilde(self):
This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
ensuring efficient set up of the `Interferometer` class.

Parameters
----------
curvature_preload
An already computed curvature preload matrix for this dataset (e.g. loaded from hard-disk), to prevent
long recalculations of this matrix for large datasets.
batch_size
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
which can be reduced to produce lower memory usage at the cost of speed.

Returns
-------
WTildeInterferometer
Precomputed values used for the w tilde formalism of linear algebra calculations.
"""

logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")

try:
import numba
except ModuleNotFoundError:
raise exc.InversionException(
"Inversion w-tilde functionality (pixelized reconstructions) is "
"disabled if numba is not installed.\n\n"
"This is because the run-times without numba are too slow.\n\n"
"Please install numba, which is described at the following web page:\n\n"
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
)
if curvature_preload is None:

curvature_preload = (
inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
noise_map_real=np.array(self.noise_map.real),
uv_wavelengths=np.array(self.uv_wavelengths),
shape_masked_pixels_2d=np.array(
self.transformer.grid.mask.shape_native_masked_pixels
),
grid_radians_2d=np.array(
self.transformer.grid.mask.derive_grid.all_false.in_radians.native
),
)
)
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")

w_matrix = inversion_interferometer_util.w_tilde_via_preload_from(
w_tilde_preload=curvature_preload,
native_index_for_slim_index=np.array(
self.real_space_mask.derive_indexes.native_for_slim
).astype("int"),
)
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
noise_map_real=self.noise_map.array.real,
uv_wavelengths=self.uv_wavelengths,
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
show_memory=show_memory,
show_progress=show_progress,
)

dirty_image = self.transformer.image_from(
visibilities=self.data.real * self.noise_map.real**-2.0
Expand All @@ -208,19 +210,18 @@ def apply_w_tilde(self):
)

w_tilde = WTildeInterferometer(
w_matrix=w_matrix,
curvature_preload=curvature_preload,
dirty_image=np.array(dirty_image.array),
dirty_image=dirty_image.array,
real_space_mask=self.real_space_mask,
batch_size=batch_size,
)

return Interferometer(
real_space_mask=self.real_space_mask,
data=self.data,
noise_map=self.noise_map,
uv_wavelengths=self.uv_wavelengths,
transformer_class=lambda uv_wavelengths, real_space_mask, preload_transform: self.transformer,
dft_preload_transform=self.dft_preload_transform,
transformer_class=lambda uv_wavelengths, real_space_mask: self.transformer,
w_tilde=w_tilde,
)

Expand Down
83 changes: 81 additions & 2 deletions autoarray/dataset/interferometer/w_tilde.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
class WTildeInterferometer(AbstractWTilde):
def __init__(
self,
w_matrix: np.ndarray,
curvature_preload: np.ndarray,
dirty_image: np.ndarray,
real_space_mask: Mask2D,
batch_size: int = 128,
):
"""
Packages together all derived data quantities necessary to fit `Interferometer` data using an `Inversion` via
Expand All @@ -34,6 +34,9 @@ def __init__(
real_space_mask
The 2D mask in real-space defining the area where the interferometer data's visibilities are observing
a signal.
batch_size
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
which can be reduced to produce lower memory usage at the cost of speed.
"""
super().__init__(
curvature_preload=curvature_preload,
Expand All @@ -42,4 +45,80 @@ def __init__(
self.dirty_image = dirty_image
self.real_space_mask = real_space_mask

self.w_matrix = w_matrix
from autoarray.inversion.inversion.interferometer import (
inversion_interferometer_util,
)

self.fft_state = inversion_interferometer_util.w_tilde_fft_state_from(
curvature_preload=self.curvature_preload, batch_size=batch_size
)

@property
def mask_rectangular_w_tilde(self) -> np.ndarray:
    """
    Build the tightest rectangular mask enclosing every unmasked pixel of the real-space mask.

    The FFT-based convolution used for the W-tilde curvature matrix needs a full rectangular
    grid, so the (possibly irregular) unmasked region is replaced by its bounding box.

    Returns
    -------
    np.ndarray
        Boolean array with the same shape as `real_space_mask`. Entries inside the bounding
        box of the unmasked pixels are False (unmasked); all other entries are True (masked).
    """
    real_space_mask = self.real_space_mask

    # Row / column coordinates of every unmasked (False) pixel.
    unmasked_rows, unmasked_cols = np.where(~real_space_mask)

    # Inclusive bounding box of the unmasked pixels, expressed as slices.
    row_span = slice(unmasked_rows.min(), unmasked_rows.max() + 1)
    col_span = slice(unmasked_cols.min(), unmasked_cols.max() + 1)

    # Start fully masked, then open up the bounding box.
    bounding_mask = np.ones(real_space_mask.shape, dtype=bool)
    bounding_mask[row_span, col_span] = False

    return bounding_mask

@property
def rect_index_for_mask_index(self) -> np.ndarray:
    """
    Map every unmasked pixel of the real-space mask to its flat index on the rectangular grid.

    The rectangular grid is the unmasked region of `mask_rectangular_w_tilde`, flattened in
    row-major order. A curvature matrix computed on that rectangular grid can be reduced to
    the original masked grid with this mapping: with `idx = rect_index_for_mask_index`,

        C_mask = C_rect[idx[:, None], idx[None, :]]

    Returns
    -------
    np.ndarray
        1D `int32` array with one entry per unmasked pixel of `real_space_mask` (in the
        order returned by `np.where`), giving that pixel's row-major index on the
        rectangular grid.
    """
    real_space_mask = self.real_space_mask
    bounding_mask = self.mask_rectangular_w_tilde

    # Top-left corner and width of the rectangular (bounding-box) region.
    box_rows, box_cols = np.where(~bounding_mask)
    row_origin = box_rows.min()
    col_origin = box_cols.min()
    box_width = box_cols.max() - col_origin + 1

    # Unmasked pixels of the original mask, expressed relative to the box corner.
    pixel_rows, pixel_cols = np.where(~real_space_mask)
    row_offsets = pixel_rows - row_origin
    col_offsets = pixel_cols - col_origin

    # (row, col) offset -> row-major flat index inside the box.
    flat_indices = row_offsets * box_width + col_offsets

    return flat_indices.astype(np.int32)
32 changes: 32 additions & 0 deletions autoarray/fit/fit_interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,38 @@ def noise_normalization(self) -> float:
noise_map=self.noise_map.array,
)

@property
def log_evidence(self) -> float:
    """
    Returns the log evidence of the inversion's fit to the dataset, where the log evidence includes a number of
    terms which quantify the complexity of an inversion's reconstruction (see the `Inversion` module):

    Log Evidence = -0.5*[Chi_Squared_Term + Regularization_Term + Log(Covariance_Regularization_Term) -
                         Log(Regularization_Matrix_Term) + Noise_Term]

    The individual terms passed to `fit_util.log_evidence_from` are read from the fit and its inversion:

    - chi_squared: the inversion's `fast_chi_squared`.
    - regularization_term: the inversion's `regularization_term`.
    - log_curvature_regularization_term: the inversion's `log_det_curvature_reg_matrix_term`.
    - log_regularization_term: the inversion's `log_det_regularization_matrix_term`.
    - noise_normalization: this fit's `noise_normalization`.

    Returns
    -------
    The log evidence, or None if the fit has no inversion.
    """
    inversion = self.inversion

    # Without an inversion there is no evidence to compute.
    if inversion is None:
        return None

    return fit_util.log_evidence_from(
        chi_squared=inversion.fast_chi_squared,
        regularization_term=inversion.regularization_term,
        log_curvature_regularization_term=inversion.log_det_curvature_reg_matrix_term,
        log_regularization_term=inversion.log_det_regularization_matrix_term,
        noise_normalization=self.noise_normalization,
    )

@property
def dirty_image(self) -> Array2D:
    """
    The image obtained by passing the fit's `data` (the visibilities) to the transformer's `image_from` method.
    """
    transformer = self.transformer
    return transformer.image_from(visibilities=self.data)
Expand Down
28 changes: 28 additions & 0 deletions autoarray/inversion/inversion/interferometer/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,31 @@ def mapped_reconstructed_image_dict(
mapped_reconstructed_image_dict[linear_obj] = mapped_reconstructed_image

return mapped_reconstructed_image_dict

@property
def fast_chi_squared(self):
    """
    Chi-squared of the reconstruction, evaluated directly from the inversion's linear algebra quantities
    rather than from residuals.

    With `s` the reconstruction, `F` the curvature matrix, `D` the data vector, `d` the data and `sigma`
    the noise-map, the value returned is:

        s^T F s  -  2 s^T D  +  sum(d_real^2 / sigma_real^2)  +  sum(d_imag^2 / sigma_imag^2)

    All array operations go through `self._xp`, the array module used by this inversion.

    Returns
    -------
    The sum of the three terms above.
    """
    xp = self._xp
    solution = self.reconstruction

    # Quadratic term: s^T F s.
    quadratic_term = xp.linalg.multi_dot(
        [solution.T, self.curvature_matrix, solution]
    )

    # Cross term: -2 s^T D.
    cross_term = -2.0 * xp.linalg.multi_dot([solution.T, self.data_vector])

    # Data-only term: noise-weighted squared data, real and imaginary parts summed separately.
    visibilities = self.dataset.data.array
    noise = self.dataset.noise_map.array
    real_part = xp.sum(visibilities.real**2.0 / noise.real**2.0)
    imag_part = xp.sum(visibilities.imag**2.0 / noise.imag**2.0)
    data_term = real_part + imag_part

    return quadratic_term + cross_term + data_term
Loading
Loading